Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 164 additions & 78 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import hashlib
import itertools
import os
import pickle
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, List, Optional
Expand Down Expand Up @@ -27,32 +29,24 @@


@dataclass
class ReqMeta:
"""
Request Blocks layout:
----------------------------------------------------------------------------------------------------
| local_computed_block(HBM hit) | external_computed_block(external hit) | new_block(need to dump) |
----------------------------------------------------------------------------------------------------
| hbm_hit_block_num | LOAD | DUMP |
----------------------------------------------------------------------------------------------------
| total_hit_block_num |
----------------------------------------------------------------------------------------------------
| scheduled_block_num |
"""

class RequestMeta:
ucm_block_ids: list[str] = field(default_factory=list)
# vLLM block ids
vllm_block_ids: list[int] = field(default_factory=list)
hbm_hit_block_num: int = 0
# local_computed_block + external_computed_block
total_hit_block_num: int = 0
# local_computed_block + external_computed_block + new_block
scheduled_block_num: int = 0


@dataclass
class RequestDispatchMeta:
    """Block ids one request has to load and to dump.

    Both fields are pairs laid out as ``(ucm_block_ids, vllm_block_ids)``.
    """

    # [0] is the list of ucm_block_ids, [1] is the list of vllm_block_ids
    load_block_ids: tuple[list[str], list[int]]
    # same (ucm_block_ids, vllm_block_ids) layout as load_block_ids
    dump_block_ids: tuple[list[str], list[int]]


@dataclass
class UCMConnectorMetadata(KVConnectorMetadata):
request_meta: dict[str, ReqMeta] = field(default_factory=dict)
request_meta: dict[str, RequestDispatchMeta] = field(default_factory=dict)


class RequestHasher:
Expand Down Expand Up @@ -118,8 +112,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):

self.request_hasher = RequestHasher()

# save block info, avoid hash request twice
self.request_meta: dict[str, ReqMeta] = {}
# save block info to avoid hashing a request twice, and track it until the request finishes
self.requests_meta: dict[str, RequestMeta] = {}

# TODO use yaml
if (
Expand Down Expand Up @@ -207,7 +201,7 @@ def get_num_new_matched_tokens(
if external_hit_tokens == request.num_prompt_tokens:
external_hit_tokens -= 1

self.request_meta[request.request_id] = ReqMeta(
self.requests_meta[request.request_id] = RequestMeta(
ucm_block_ids=ucm_block_ids,
hbm_hit_block_num=hbm_hit_block_num,
total_hit_block_num=total_hit_block_num,
Expand All @@ -220,52 +214,106 @@ def update_state_after_alloc(
):
pass

def _generate_dispatch_meta(
    self,
    req_meta: RequestMeta,
    new_tokens: int,
    vllm_block_ids: list[int],
    need_load: bool = True,
) -> RequestDispatchMeta:
    """
    Split a request's blocks into the ranges to load and to dump this round.

    Block layout of a request (indices are block counts):

        [0, hbm_hit_block_num)                    local HBM hit, nothing to do
        [hbm_hit_block_num, total_hit_block_num)  external hit  -> LOAD
        [total_hit_block_num, scheduled_block_num) new blocks   -> DUMP

    where scheduled_block_num = total_hit_block_num
                                + new_tokens // self.block_size.

    Args:
        req_meta: tracked state of the request; its ``total_hit_block_num``
            is advanced to ``scheduled_block_num`` by this call.
        new_tokens: number of tokens scheduled for the request this round.
        vllm_block_ids: vLLM block ids supplied by the caller for the request.
        need_load: when true, ``vllm_block_ids`` is sliced with the same
            ranges as the ucm block ids; when false, nothing is loaded and
            ``vllm_block_ids`` is used unsliced as the dump target list.

    Returns:
        A RequestDispatchMeta holding the (ucm, vllm) id lists to load and
        the (ucm, vllm) id lists to dump.
    """
    load_start = req_meta.hbm_hit_block_num
    load_stop = req_meta.total_hit_block_num
    dump_stop = load_stop + new_tokens // self.block_size
    all_ucm_ids = req_meta.ucm_block_ids

    dump_range = slice(load_stop, dump_stop)
    load_range = slice(load_start, load_stop)

    ucm_to_dump = all_ucm_ids[dump_range]
    vllm_to_dump = vllm_block_ids[dump_range] if need_load else vllm_block_ids

    # The blocks dumped in this round count as hit blocks from the next round on.
    req_meta.total_hit_block_num = dump_stop

    if need_load:
        to_load = (all_ucm_ids[load_range], vllm_block_ids[load_range])
    else:
        to_load = ([], [])

    return RequestDispatchMeta(to_load, (ucm_to_dump, vllm_to_dump))

def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
requests_dispatch_meta = {}
# for new request, we need to load and dump
for request in scheduler_output.scheduled_new_reqs:
request_id, vllm_block_ids = request.req_id, request.block_ids[0]
req_meta = self.requests_meta.get(request_id)
if req_meta:
requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
req_meta,
scheduler_output.num_scheduled_tokens[request_id],
vllm_block_ids,
)

scheduled_request = [
(req.req_id, req.block_ids) for req in scheduler_output.scheduled_new_reqs
]

def get_requests(cached_request_data):
# 0.9.1
if isinstance(cached_request_data, list):
return [
(
request_data.req_id,
request_data.new_block_ids,
)
for request_data in cached_request_data
]
# for cached requests, there are 3 situations:
# 1. chunked prefill: we only need to dump
# 2. resumed: we need to handle it like a new request
# 3. TODO decode stage: nothing happens
scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs
if not isinstance(scheduled_cached_reqs, list):
# >= 0.9.2
else:
return [
(
req_id,
cached_request_data.new_block_ids[i],
for i, request_id in enumerate(scheduled_cached_reqs.req_ids):
if scheduler_output.num_scheduled_tokens[request_id] == 1:
# decode stage
continue
req_meta = self.requests_meta.get(request_id)
if req_meta:
requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
req_meta,
scheduler_output.num_scheduled_tokens[request_id],
scheduled_cached_reqs.new_block_ids[i][0],
scheduled_cached_reqs.resumed_from_preemption[i],
)
else:
for request in scheduled_cached_reqs:
request_id = request.request_id
if scheduler_output.num_scheduled_tokens[request_id] == 1:
# decode stage
continue
req_meta = self.requests_meta.get(request_id)
if req_meta:
requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
req_meta,
scheduler_output.num_scheduled_tokens[request_id],
request.new_block_ids[0],
request.resumed_from_preemption,
)
for i, req_id in enumerate(cached_request_data.req_ids)
]

scheduled_request.extend(get_requests(scheduler_output.scheduled_cached_reqs))

for request_id, vllm_block_ids in scheduled_request:
req_meta = self.request_meta.get(request_id)
if req_meta:
# we only save scheduled tokens in this step
new_tokens = scheduler_output.num_scheduled_tokens[request_id]
new_blocks_num = new_tokens // self.block_size
req_meta.scheduled_block_num = (
req_meta.total_hit_block_num + new_blocks_num
)
req_meta.vllm_block_ids = vllm_block_ids[0]

# we need to clear self.request_meta
request_meta = self.request_meta
self.request_meta = {}
# clear finished request
for request_id in scheduler_output.finished_req_ids:
self.requests_meta.pop(request_id, None)

return UCMConnectorMetadata(request_meta)
return UCMConnectorMetadata(requests_dispatch_meta)

def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"):
if len(self.kv_caches) > 0:
Expand Down Expand Up @@ -426,16 +474,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
request_to_task: dict[str, Optional[Task]] = {}
req_to_layer = {}
for request_id, request in metadata.request_meta.items():
hbm_hit_block_num = request.hbm_hit_block_num
total_hit_block_num = request.total_hit_block_num
if hbm_hit_block_num == total_hit_block_num:
# no external hit blocks
if len(request.load_block_ids[0]) == 0:
continue

vllm_block_ids = request.vllm_block_ids[
hbm_hit_block_num:total_hit_block_num
]
ucm_block_ids = request.ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
ucm_block_ids, vllm_block_ids = request.load_block_ids
if self.load_only_first_rank:
can_load = self.rank == 0
task, layer_to_tensors, total_block_num = (
Expand Down Expand Up @@ -483,15 +525,12 @@ def wait_for_save(self) -> None:
request_to_task: dict[str, Task] = {}
request_to_blocks: dict[str, list[str]] = {}
for request_id, request in metadata.request_meta.items():
total_hit_block_num = request.total_hit_block_num
scheduled_block_num = request.scheduled_block_num
if scheduled_block_num == total_hit_block_num:
# no need to save block
if len(request.dump_block_ids[0]) == 0:
continue

ucm_block_ids = request.ucm_block_ids[total_hit_block_num:]
ucm_block_ids, vllm_block_ids = request.dump_block_ids
rets = self.store.create(ucm_block_ids)
end = total_hit_block_num
end = 0
for i, ret in enumerate(rets):
if ret != 0:
logger.error(
Expand All @@ -500,10 +539,8 @@ def wait_for_save(self) -> None:
break
end += 1

if end == total_hit_block_num:
continue
ucm_block_ids = ucm_block_ids[total_hit_block_num:end]
vllm_block_ids = request.vllm_block_ids[total_hit_block_num:end]
ucm_block_ids = ucm_block_ids[:end]
vllm_block_ids = vllm_block_ids[:end]
request_to_task[request_id] = self._generate_task(
vllm_block_ids, ucm_block_ids, self.store.dump
)
Expand Down Expand Up @@ -585,11 +622,60 @@ def get_finished(
raise NotImplementedError


class UCMMockConnector(UCMDirectConnector):
    """
    Connector that caps the reported hit tokens at a configured ratio.

    "hit_ratio" is read from kv_transfer_config.kv_connector_extra_config.
    Whenever the parent connector reports more hit tokens than
    hit_ratio * num_prompt_tokens, get_num_new_matched_tokens() lowers the
    result (and the tracked request meta) to that limit, rounded down to
    whole blocks.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config, role)
        extra_config = self._vllm_config.kv_transfer_config.kv_connector_extra_config
        self._hit_ratio = float(extra_config["hit_ratio"])
        logger.info(f"hit_ratio: {self._hit_ratio}")

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """Return the parent's hit tokens, capped at the configured hit ratio.

        The second element of the returned tuple is always False.
        """
        actual_hit_tokens, _ = super().get_num_new_matched_tokens(
            request, num_computed_tokens
        )
        token_limit = int(self._hit_ratio * request.num_prompt_tokens)
        if actual_hit_tokens <= token_limit:
            # Already at or below the target ratio: report the real hit untouched.
            return actual_hit_tokens, False

        # Too many hits: shrink the tracked hit ranges to whole blocks under the limit.
        block_limit = token_limit // self.block_size
        request_meta = self.requests_meta[request.request_id]
        request_meta.total_hit_block_num = block_limit
        request_meta.hbm_hit_block_num = min(
            block_limit, request_meta.hbm_hit_block_num
        )

        logger.info(
            "Hijacked By MockConnector,"
            f"request_id: {request.request_id}, "
            f"total_blocks_num: {len(request_meta.ucm_block_ids)}, "
            f"hit hbm: {request_meta.hbm_hit_block_num}, "
            f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}"
        )

        return block_limit * self.block_size, False


class UCMConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self.connector: KVConnectorBase_V1
# TODO new conn by config
self.connector = UCMDirectConnector(vllm_config, role)
if (
self._vllm_config.kv_transfer_config is not None
and "hit_ratio"
in self._vllm_config.kv_transfer_config.kv_connector_extra_config
):
self.connector = UCMMockConnector(vllm_config, role)
else:
self.connector = UCMDirectConnector(vllm_config, role)

def get_num_new_matched_tokens(
self,
Expand Down