Skip to content
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ else()
FetchContent_Declare(
repo-ft
GIT_REPOSITORY https://github.com/neevaco/FasterTransformer.git
GIT_TAG main
GIT_TAG affa1ef1c175d03db8ff5b14824cc58dd2c52c2b
GIT_SHALLOW ON
)
endif()
Expand Down
16 changes: 16 additions & 0 deletions src/libfastertransformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "src/fastertransformer/triton_backend/gptj/GptJTritonModelInstance.h"
#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModel.h"
#include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h"
#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
#include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h"
#include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h"
#include "src/fastertransformer/triton_backend/t5/T5TritonModel.h"
Expand Down Expand Up @@ -328,6 +329,21 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
} else if (data_type == "bf16") {
ft_model = std::make_shared<BertTritonModel<__nv_bfloat16>>(
tp, pp, custom_ar, model_dir, int8_mode, is_sparse, remove_padding);
#endif
}
} else if (model_type == "llama") {
const int int8_mode = param_get_int(param, "int8_mode");

if (data_type == "fp16") {
ft_model = std::make_shared<LlamaTritonModel<half>>(
tp, pp, custom_ar, model_dir, int8_mode);
} else if (data_type == "fp32") {
ft_model = std::make_shared<LlamaTritonModel<float>>(
tp, pp, custom_ar, model_dir, int8_mode);
#ifdef ENABLE_BF16
} else if (data_type == "bf16") {
ft_model = std::make_shared<LlamaTritonModel<__nv_bfloat16>>(
tp, pp, custom_ar, model_dir, int8_mode);
#endif
}
} else if (model_type == "deberta") {
Expand Down