diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 3c1ade22..31fc6ac9 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -1,4 +1,6 @@ import hashlib +import itertools +import os import pickle from dataclasses import dataclass, field from typing import TYPE_CHECKING, Callable, List, Optional @@ -27,32 +29,24 @@ @dataclass -class ReqMeta: - """ - Request Blocks layout: - ---------------------------------------------------------------------------------------------------- - | local_computed_block(HBM hit) | external_computed_block(external hit) | new_block(need to dump) | - ---------------------------------------------------------------------------------------------------- - | hbm_hit_block_num | LOAD | DUMP | - ---------------------------------------------------------------------------------------------------- - | total_hit_block_num | - ---------------------------------------------------------------------------------------------------- - | scheduled_block_num | - """ - +class RequestMeta: ucm_block_ids: list[str] = field(default_factory=list) - # vLLM block ids - vllm_block_ids: list[int] = field(default_factory=list) hbm_hit_block_num: int = 0 # local_computed_block + external_computed_block total_hit_block_num: int = 0 - # local_computed_block + external_computed_block + new_block - scheduled_block_num: int = 0 + + +@dataclass +class RequestDispatchMeta: + load_block_ids: tuple[ + list[str], list[int] + ] # [0] mean ucm_block_ids, [1] means vllm_block_ids + dump_block_ids: tuple[list[str], list[int]] @dataclass class UCMConnectorMetadata(KVConnectorMetadata): - request_meta: dict[str, ReqMeta] = field(default_factory=dict) + request_meta: dict[str, RequestDispatchMeta] = field(default_factory=dict) class RequestHasher: @@ -118,8 +112,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.request_hasher = RequestHasher() - # save block info, avoid 
hash request twice - self.request_meta: dict[str, ReqMeta] = {} + # save block info, avoid hash request twice, and track them until request finished + self.requests_meta: dict[str, RequestMeta] = {} # TODO use yaml if ( @@ -207,7 +201,7 @@ def get_num_new_matched_tokens( if external_hit_tokens == request.num_prompt_tokens: external_hit_tokens -= 1 - self.request_meta[request.request_id] = ReqMeta( + self.requests_meta[request.request_id] = RequestMeta( ucm_block_ids=ucm_block_ids, hbm_hit_block_num=hbm_hit_block_num, total_hit_block_num=total_hit_block_num, @@ -220,52 +214,106 @@ def update_state_after_alloc( ): pass + def _generate_dispatch_meta( + self, + req_meta: RequestMeta, + new_tokens: int, + vllm_block_ids: list[int], + need_load: bool = True, + ) -> RequestDispatchMeta: + """ + Request Blocks layout: + ---------------------------------------------------------------------------------------------------- + | local_computed_block(HBM hit) | external_computed_block(external hit) | new_block(need to dump) | + ---------------------------------------------------------------------------------------------------- + | hbm_hit_block_num | LOAD | new_blocks_num | + ---------------------------------------------------------------------------------------------------- + | total_hit_block_num | + ---------------------------------------------------------------------------------------------------- + | scheduled_block_num | + """ + + new_blocks_num = new_tokens // self.block_size + hbm_hit_block_num = req_meta.hbm_hit_block_num + total_hit_block_num = req_meta.total_hit_block_num + scheduled_block_num = total_hit_block_num + new_blocks_num + ucm_block_ids = req_meta.ucm_block_ids + + dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num] + if need_load: + dump_vllm_block_ids = vllm_block_ids[ + total_hit_block_num:scheduled_block_num + ] + else: + dump_vllm_block_ids = vllm_block_ids + + # after this round, req_meta will be updated + 
req_meta.total_hit_block_num = scheduled_block_num + + load_ucm_block_ids, load_vllm_block_ids = [], [] + if need_load: + load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num] + load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num] + + return RequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + ) + def build_connector_meta( self, scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: + requests_dispatch_meta = {} + # for new request, we need to load and dump + for request in scheduler_output.scheduled_new_reqs: + request_id, vllm_block_ids = request.req_id, request.block_ids[0] + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + vllm_block_ids, + ) - scheduled_request = [ - (req.req_id, req.block_ids) for req in scheduler_output.scheduled_new_reqs - ] - - def get_requests(cached_request_data): - # 0.9.1 - if isinstance(cached_request_data, list): - return [ - ( - request_data.req_id, - request_data.new_block_ids, - ) - for request_data in cached_request_data - ] + # for cached request, there are 3 situations: + # 1. chunked prefill: we only need dump + # 2. resumed: we need to handle like new request + # 3. 
TODO decode stage: nothing happened + scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs + if not isinstance(scheduled_cached_reqs, list): # >= 0.9.2 - else: - return [ - ( - req_id, - cached_request_data.new_block_ids[i], + for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + scheduled_cached_reqs.new_block_ids[i][0], + scheduled_cached_reqs.resumed_from_preemption[i], + ) + else: # 0.9.1: list of per-request cached data, keyed by req_id + for request in scheduled_cached_reqs: + request_id = request.req_id + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.new_block_ids[0], + request.resumed_from_preemption, ) - for i, req_id in enumerate(cached_request_data.req_ids) - ] - - scheduled_request.extend(get_requests(scheduler_output.scheduled_cached_reqs)) - - for request_id, vllm_block_ids in scheduled_request: - req_meta = self.request_meta.get(request_id) - if req_meta: - # we only save scheduled tokens in this step - new_tokens = scheduler_output.num_scheduled_tokens[request_id] - new_blocks_num = new_tokens // self.block_size - req_meta.scheduled_block_num = ( - req_meta.total_hit_block_num + new_blocks_num - ) - req_meta.vllm_block_ids = vllm_block_ids[0] - # we need to clear self.request_meta - request_meta = self.request_meta - self.request_meta = {} + # clear finished request + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) - return UCMConnectorMetadata(request_meta) + return UCMConnectorMetadata(requests_dispatch_meta) 
def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"): if len(self.kv_caches) > 0: @@ -426,16 +474,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: request_to_task: dict[str, Optional[Task]] = {} req_to_layer = {} for request_id, request in metadata.request_meta.items(): - hbm_hit_block_num = request.hbm_hit_block_num - total_hit_block_num = request.total_hit_block_num - if hbm_hit_block_num == total_hit_block_num: - # no external hit blocks + if len(request.load_block_ids[0]) == 0: continue - vllm_block_ids = request.vllm_block_ids[ - hbm_hit_block_num:total_hit_block_num - ] - ucm_block_ids = request.ucm_block_ids[hbm_hit_block_num:total_hit_block_num] + ucm_block_ids, vllm_block_ids = request.load_block_ids if self.load_only_first_rank: can_load = self.rank == 0 task, layer_to_tensors, total_block_num = ( @@ -483,15 +525,12 @@ def wait_for_save(self) -> None: request_to_task: dict[str, Task] = {} request_to_blocks: dict[str, list[str]] = {} for request_id, request in metadata.request_meta.items(): - total_hit_block_num = request.total_hit_block_num - scheduled_block_num = request.scheduled_block_num - if scheduled_block_num == total_hit_block_num: - # no need to save block + if len(request.dump_block_ids[0]) == 0: continue - ucm_block_ids = request.ucm_block_ids[total_hit_block_num:] + ucm_block_ids, vllm_block_ids = request.dump_block_ids rets = self.store.create(ucm_block_ids) - end = total_hit_block_num + end = 0 for i, ret in enumerate(rets): if ret != 0: logger.error( @@ -500,10 +539,8 @@ def wait_for_save(self) -> None: break end += 1 - if end == total_hit_block_num: - continue - ucm_block_ids = ucm_block_ids[total_hit_block_num:end] - vllm_block_ids = request.vllm_block_ids[total_hit_block_num:end] + ucm_block_ids = ucm_block_ids[:end] + vllm_block_ids = vllm_block_ids[:end] request_to_task[request_id] = self._generate_task( vllm_block_ids, ucm_block_ids, self.store.dump ) @@ -585,11 +622,60 
@@ def get_finished( raise NotImplementedError +class UCMMockConnector(UCMDirectConnector): + """ + Test connector that caps the reported hit ratio: set "hit_ratio" in + kv_connector_extra_config and get_num_new_matched_tokens() trims the hit + tokens to at most int(hit_ratio * num_prompt_tokens), rounded down to whole blocks. + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + self._hit_ratio = float( + vllm_config.kv_transfer_config.kv_connector_extra_config["hit_ratio"] + ) + logger.info(f"hit_ratio: {self._hit_ratio}") + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + hit_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens) + expect_hit_tokens = int(self._hit_ratio * request.num_prompt_tokens) + if hit_tokens <= expect_hit_tokens: + return hit_tokens, False + expect_hit_block_num = expect_hit_tokens // self.block_size + request_meta = self.requests_meta[request.request_id] + request_meta.total_hit_block_num = expect_hit_block_num + request_meta.hbm_hit_block_num = min( + expect_hit_block_num, request_meta.hbm_hit_block_num + ) + + logger.info( + "Hijacked By MockConnector," + f"request_id: {request.request_id}, " + f"total_blocks_num: {len(request_meta.ucm_block_ids)}, " + f"hit hbm: {request_meta.hbm_hit_block_num}, " + f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}" + ) + + return expect_hit_block_num * self.block_size, False + + class UCMConnector(KVConnectorBase_V1): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) + self.connector: KVConnectorBase_V1 # TODO new conn by config - self.connector = UCMDirectConnector(vllm_config, role) + if ( + vllm_config.kv_transfer_config is not None + and "hit_ratio" + in vllm_config.kv_transfer_config.kv_connector_extra_config + ): + self.connector 
= UCMMockConnector(vllm_config, role) + else: + self.connector = UCMDirectConnector(vllm_config, role) def get_num_new_matched_tokens( self,