diff --git a/CMakeLists.txt b/CMakeLists.txt
index d539453..9ff166f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -111,7 +111,7 @@ else()
   FetchContent_Declare(
     repo-ft
     GIT_REPOSITORY https://github.com/neevaco/FasterTransformer.git
-    GIT_TAG        main
+    GIT_TAG        affa1ef1c175d03db8ff5b14824cc58dd2c52c2b
     GIT_SHALLOW    ON
   )
 endif()
diff --git a/src/libfastertransformer.cc b/src/libfastertransformer.cc
index f2bbc0e..f478c22 100644
--- a/src/libfastertransformer.cc
+++ b/src/libfastertransformer.cc
@@ -54,6 +54,7 @@
 #include "src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h"
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModel.h"
@@ -328,6 +329,21 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
     } else if (data_type == "bf16") {
       ft_model = std::make_shared<ParallelGptTritonModel<__nv_bfloat16>>(
           tp, pp, custom_ar, model_dir, int8_mode, is_sparse, remove_padding);
+#endif
+    }
+  } else if (model_type == "llama") {
+    const int int8_mode = param_get_int(param, "int8_mode");
+
+    if (data_type == "fp16") {
+      ft_model = std::make_shared<LlamaTritonModel<half>>(
+          tp, pp, custom_ar, model_dir, int8_mode);
+    } else if (data_type == "fp32") {
+      ft_model = std::make_shared<LlamaTritonModel<float>>(
+          tp, pp, custom_ar, model_dir, int8_mode);
+#ifdef ENABLE_BF16
+    } else if (data_type == "bf16") {
+      ft_model = std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
+          tp, pp, custom_ar, model_dir, int8_mode);
 #endif
     }
   } else if (model_type == "deberta") {
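
Usage note: the new factory branch is selected entirely by the model's Triton parameters; this diff reads "model_type", "data_type", and "int8_mode" (the last via param_get_int). A minimal config.pbtxt sketch that would route a model through the new LlamaTritonModel path could look like the snippet below. Only the three keys this diff actually reads are shown; a real config also needs the backend's usual inputs, outputs, checkpoint path, and parallelism settings, which are omitted here.

    parameters {
      key: "model_type"
      value: { string_value: "llama" }   # dispatches to the new llama branch
    }
    parameters {
      key: "data_type"
      value: { string_value: "fp16" }    # selects LlamaTritonModel<half>
    }
    parameters {
      key: "int8_mode"
      value: { string_value: "0" }       # read with param_get_int
    }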