 #include "CThreadSettings.h"
 
 #include <ATen/Parallel.h>
+#include <ATen/ops/cat.h>
 #include <torch/csrc/api/include/torch/types.h>
 #include <torch/script.h>
 
@@ -44,26 +45,39 @@ torch::Tensor infer(torch::jit::script::Module& module_,
     std::vector<torch::jit::IValue> inputs;
     inputs.reserve(1 + request.s_SecondaryArguments.size());
 
-    std::array<std::int64_t, 2> dimensions = {request.s_NumberInferences,
-                                              request.s_NumberInputTokens};
+    std::array<std::int64_t, 2> dimensions = {1, request.s_NumberInputTokens};
     at::IntArrayRef inputSize{dimensions};
 
-    // Sequence tokens.
-    inputs.emplace_back(torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
-                                         inputSize, at::dtype(torch::kInt64)));
-    // Attention mask.
-    for (auto& args : request.s_SecondaryArguments) {
-        inputs.emplace_back(torch::from_blob(static_cast<void*>(args.data()),
-                                             inputSize, at::dtype(torch::kInt64)));
-    }
+    std::vector<at::Tensor> all;
 
     torch::InferenceMode inferenceModeGuard;
-    auto result = module_.forward(inputs);
-    if (result.isTuple()) {
-        // For transformers the result tensor is the first element in a tuple.
-        return result.toTuple()->elements()[0].toTensor();
+
+    for (int i = 0; i < request.s_NumberInferences; i++) {
+
+        std::size_t offset = i * request.s_NumberInputTokens;
+
+        // Sequence tokens.
+        inputs.emplace_back(
+            torch::from_blob(static_cast<void*>(request.s_Tokens.data() + offset),
+                             inputSize, at::dtype(torch::kInt64)));
+        // Attention mask etc.
+        for (auto& args : request.s_SecondaryArguments) {
+            inputs.emplace_back(torch::from_blob(static_cast<void*>(args.data() + offset),
+                                                 inputSize, at::dtype(torch::kInt64)));
+        }
+
+        auto output = module_.forward(inputs);
+        if (output.isTuple()) {
+            // For transformers the result tensor is the first element in a tuple.
+            all.push_back(output.toTuple()->elements()[0].toTensor());
+        } else {
+            all.push_back(output.toTensor());
+        }
+
+        inputs.clear();
     }
-    return result.toTensor();
+
+    return at::cat(all, 0);
 }
 
 bool handleRequest(ml::torch::CCommandParser::CRequestCacheInterface& cache,
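
The patch above replaces a single [s_NumberInferences, s_NumberInputTokens] forward call with a loop of [1, s_NumberInputTokens] calls whose results are concatenated with at::cat. The standalone sketch below illustrates the two primitives the loop relies on: torch::from_blob views caller-owned memory without copying, and at::cat rebuilds the batch dimension. The buffer sizes and the "row * 2" stand-in for module_.forward are hypothetical.

// Sketch of the zero-copy slicing pattern used in infer() above.
#include <ATen/ops/cat.h>
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
    // Hypothetical stand-ins for request.s_NumberInferences and
    // request.s_NumberInputTokens.
    std::int64_t numberInferences{3};
    std::int64_t numberInputTokens{4};
    std::vector<std::int64_t> tokens(numberInferences * numberInputTokens);
    std::iota(tokens.begin(), tokens.end(), 0);

    std::vector<at::Tensor> all;
    all.reserve(numberInferences);
    for (std::int64_t i = 0; i < numberInferences; ++i) {
        // Zero-copy view of one [1, numberInputTokens] row of the flat
        // buffer; the vector must outlive the tensor because from_blob
        // does not take ownership of the memory.
        at::Tensor row = torch::from_blob(tokens.data() + i * numberInputTokens,
                                          {1, numberInputTokens},
                                          at::dtype(torch::kInt64));
        all.push_back(row * 2); // stand-in for module_.forward(inputs)
    }

    // Concatenate the per-inference results along the batch dimension.
    at::Tensor batched = at::cat(all, 0);
    std::cout << batched.sizes() << '\n'; // prints [3, 4]
    return 0;
}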