Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xllm/core/framework/kv_cache/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ cc_library(
NAME
kv_cache
HDRS
embedding_allocator.h
token_cache_allocator.h
$<$<BOOL:${USE_NPU}>:hccl_kv_cache_transfer.h>
kv_cache.h
kv_cache_event.h
Expand All @@ -16,7 +16,7 @@ cc_library(
$<$<BOOL:${USE_NPU}>:spec_kv_cache_transfer.h>
kv_cache_store.h
SRCS
embedding_allocator.cpp
token_cache_allocator.cpp
$<$<BOOL:${USE_NPU}>:hccl_kv_cache_transfer.cpp>
kv_cache.cpp
kv_cache_transfer.cpp
Expand Down
105 changes: 0 additions & 105 deletions xllm/core/framework/kv_cache/embedding_allocator.cpp

This file was deleted.

74 changes: 0 additions & 74 deletions xllm/core/framework/kv_cache/embedding_allocator.h

This file was deleted.

2 changes: 0 additions & 2 deletions xllm/core/framework/kv_cache/kv_cache_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class KVCacheTransfer {
int64_t dst_v_cache_id;
std::vector<uint64_t> src_blocks;
std::vector<uint64_t> dst_blocks;
std::vector<uint64_t> src_embed_ids;
std::vector<uint64_t> dst_embed_ids;
};

KVCacheTransfer() = default;
Expand Down
72 changes: 0 additions & 72 deletions xllm/core/framework/kv_cache/spec_kv_cache_transfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,33 +111,11 @@ void SpecKVCacheTransfer::_allocate_kv_cache(
}
}

void SpecKVCacheTransfer::allocate_embedding(
std::shared_ptr<EmbeddingAllocator> embedding_allocator,
const std::vector<int64_t>& embedding_shape,
torch::ScalarType dtype,
torch::Device device) {
const auto& it = kScalarTypeToDtype.find(dtype);
CHECK(it != kScalarTypeToDtype.cend()) << "Unsupport data type : " << dtype;
auto ge_dtype = it->second;
CacheDesc embed_cache_desc;
embed_cache_desc.num_tensors = 1;
embed_cache_desc.data_type = ge_dtype;
embed_cache_desc.shape = embedding_shape;
CHECK_LDD_RET(llm_data_dist_->AllocateCache(embed_cache_desc, embed_cache_));

embed_host_cache_.cache_desc = embed_cache_.cache_desc;
embed_host_cache_.cache_desc.placement = CachePlacement::kHost;
CHECK_EQ(embed_host_cache_.cache_desc.num_tensors, 1);
embed_host_cache_.tensor_addrs.emplace_back(reinterpret_cast<uint64_t>(
embedding_allocator->get_embeddings_cache_ptr()));
}

void SpecKVCacheTransfer::free_kv_cache() {
llm_data_dist_->DeallocateCache(k_cache_.cache_id);
llm_data_dist_->DeallocateCache(v_cache_.cache_id);
llm_data_dist_->DeallocateCache(spec_k_cache_.cache_id);
llm_data_dist_->DeallocateCache(spec_v_cache_.cache_id);
llm_data_dist_->DeallocateCache(embed_cache_.cache_id);
}

bool SpecKVCacheTransfer::pull_kv_blocks(
Expand All @@ -160,23 +138,13 @@ bool SpecKVCacheTransfer::pull_kv_blocks(
CacheIndex spec_v_cache_index{src_cluster_id, spec_v_cache_.cache_id};
CHECK_LDD_RET(llm_data_dist_->PullKvBlocks(
spec_v_cache_index, spec_v_cache_, src_blocks, dst_blocks));

CacheIndex embed_cache_index{src_cluster_id, embed_cache_.cache_id};
CHECK_LDD_RET(llm_data_dist_->PullKvBlocks(embed_cache_index,
embed_cache_,
{src_blocks.back()},
{dst_blocks.back()}));
return true;
}

bool SpecKVCacheTransfer::push_kv_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer,
bool is_spec_draft) {
if (!layer_synchronizer) {
return push_embed_blocks(merged_kv_infos);
}

if (is_spec_draft) {
return push_kv_blocks_spec(merged_kv_infos, layer_synchronizer);
}
Expand Down Expand Up @@ -244,24 +212,6 @@ bool SpecKVCacheTransfer::push_kv_blocks_spec(
return true;
}

bool SpecKVCacheTransfer::push_embed_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos) {
for (const auto& pair : merged_kv_infos) {
const KVCacheInfo& kv_info = pair.second;
CacheIndex cache_index{kv_info.dst_cluster_id, embed_cache_.cache_id};
KvCacheExtParam ext_param{};
ext_param.src_layer_range = std::pair<int32_t, int32_t>(0, 0);
ext_param.dst_layer_range = std::pair<int32_t, int32_t>(0, 0);
ext_param.tensor_num_per_layer = 1;
CHECK_LDD_RET(llm_data_dist_->PushKvBlocks(embed_cache_,
cache_index,
kv_info.src_embed_ids,
kv_info.dst_embed_ids,
ext_param));
}
return true;
}

folly::SemiFuture<bool> SpecKVCacheTransfer::push_kv_blocks_async(
const std::vector<TransferKVInfo>& transfer_kv_infos,
const ParallelArgs& parallel_args,
Expand Down Expand Up @@ -341,8 +291,6 @@ void SpecKVCacheTransfer::merge_kv_blocks(
kv_info.dst_blocks.insert(kv_info.dst_blocks.end(),
info.remote_blocks_ids.begin(),
info.remote_blocks_ids.end());
kv_info.src_embed_ids.push_back(kv_info.src_blocks.back());
kv_info.dst_embed_ids.push_back(kv_info.dst_blocks.back());
merged_kv_infos[key] = std::move(kv_info);
} else {
merged_kv_infos[key].src_blocks.insert(
Expand All @@ -353,28 +301,8 @@ void SpecKVCacheTransfer::merge_kv_blocks(
merged_kv_infos[key].dst_blocks.end(),
info.remote_blocks_ids.begin(),
info.remote_blocks_ids.end());
merged_kv_infos[key].src_embed_ids.push_back(
merged_kv_infos[key].src_blocks.back());
merged_kv_infos[key].dst_embed_ids.push_back(
merged_kv_infos[key].dst_blocks.back());
}
}
}
}

void SpecKVCacheTransfer::copy_blocks(const std::vector<int>& blocks,
bool h2d) {
std::vector<uint64_t> _blocks;
_blocks.reserve(blocks.size());
for (const auto& block : blocks) {
_blocks.push_back(static_cast<uint64_t>(block));
}
if (h2d) {
CHECK_LDD_RET(llm_data_dist_->CopyKvBlocks(
embed_host_cache_, embed_cache_, _blocks, {_blocks}));
} else {
CHECK_LDD_RET(llm_data_dist_->CopyKvBlocks(
embed_cache_, embed_host_cache_, _blocks, {_blocks}));
}
}
} // namespace xllm
14 changes: 0 additions & 14 deletions xllm/core/framework/kv_cache/spec_kv_cache_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License.

#pragma once

#include "embedding_allocator.h"
#include "framework/parallel_state/parallel_args.h"
#include "llm_data_dist_transfer.h"

Expand Down Expand Up @@ -52,12 +51,6 @@ class SpecKVCacheTransfer : public LlmDataDistTransfer {
Cache& k_cache,
Cache& v_cache);

void allocate_embedding(
std::shared_ptr<EmbeddingAllocator> embedding_allocator,
const std::vector<int64_t>& embedding_shape,
torch::ScalarType dtype,
torch::Device device);

void free_kv_cache() override;

bool pull_kv_blocks(const uint64_t src_cluster_id,
Expand All @@ -82,23 +75,16 @@ class SpecKVCacheTransfer : public LlmDataDistTransfer {
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
std::shared_ptr<NPULayerSynchronizerImpl>& layer_synchronizer);

bool push_embed_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos);

void merge_kv_blocks(
std::unordered_map<std::string, KVCacheInfo>& merged_kv_infos,
const std::vector<TransferKVInfo>& transfer_kv_infos,
const ParallelArgs& parallel_args) override;

void copy_blocks(const std::vector<int>& blocks, bool h2d);

private:
int64_t spec_num_layers_;

Cache spec_k_cache_;
Cache spec_v_cache_;
Cache embed_cache_;
Cache embed_host_cache_;

Cache host_cache;
std::vector<std::vector<uint16_t>> buffers;
Expand Down
Loading