99// clang-format on
1010#include " utils/cuda_toolkit_utils.h"
1111#include " utils/engine_matcher_utils.h"
12+ #if defined(_WIN32) || defined(__linux__)
13+ #include " utils/file_manager_utils.h"
14+ #endif
1215
1316namespace commands {
1417
@@ -60,21 +63,22 @@ bool EngineInitCmd::Exec() const {
6063 variants.push_back (asset_name);
6164 }
6265
63- auto cuda_version = system_info_utils::GetCudaVersion ();
64- LOG_INFO << " engineName_: " << engineName_;
65- LOG_INFO << " CUDA version: " << cuda_version;
66- std::string matched_variant = " " ;
66+ auto cuda_driver_version = system_info_utils::GetCudaVersion ();
67+ LOG_INFO << " Engine: " << engineName_
68+ << " , CUDA driver version: " << cuda_driver_version;
69+
70+ std::string matched_variant{" " };
6771 if (engineName_ == " cortex.tensorrt-llm" ) {
6872 matched_variant = engine_matcher_utils::ValidateTensorrtLlm (
69- variants, system_info.os , cuda_version );
73+ variants, system_info.os , cuda_driver_version );
7074 } else if (engineName_ == " cortex.onnx" ) {
7175 matched_variant = engine_matcher_utils::ValidateOnnx (
7276 variants, system_info.os , system_info.arch );
7377 } else if (engineName_ == " cortex.llamacpp" ) {
7478 auto suitable_avx = engine_matcher_utils::GetSuitableAvxVariant ();
7579 matched_variant = engine_matcher_utils::Validate (
7680 variants, system_info.os , system_info.arch , suitable_avx,
77- cuda_version );
81+ cuda_driver_version );
7882 }
7983 LOG_INFO << " Matched variant: " << matched_variant;
8084 if (matched_variant.empty ()) {
@@ -105,17 +109,46 @@ bool EngineInitCmd::Exec() const {
105109 }}};
106110
107111 DownloadService download_service;
108- download_service.AddDownloadTask (downloadTask, [](const std::string&
109- absolute_path,
110- bool unused) {
112+ download_service.AddDownloadTask (downloadTask, [this ](
113+ const std::string&
114+ absolute_path,
115+ bool unused) {
111116 // try to unzip the downloaded file
112117 std::filesystem::path downloadedEnginePath{absolute_path};
113118 LOG_INFO << " Downloaded engine path: "
114119 << downloadedEnginePath.string ();
115120
116- archive_utils::ExtractArchive (
117- downloadedEnginePath.string (),
118- downloadedEnginePath.parent_path ().parent_path ().string ());
121+ std::filesystem::path extract_path =
122+ downloadedEnginePath.parent_path ().parent_path ();
123+
124+ archive_utils::ExtractArchive (downloadedEnginePath.string (),
125+ extract_path.string ());
126+ #if defined(_WIN32) || defined(__linux__)
127+ // FIXME: hacky try to copy the file. Remove this when we are able to set the library path
128+ auto engine_path = extract_path / engineName_;
129+ LOG_INFO << " Source path: " << engine_path.string ();
130+ auto executable_path =
131+ file_manager_utils::GetExecutableFolderContainerPath ();
132+ for (const auto & entry :
133+ std::filesystem::recursive_directory_iterator (engine_path)) {
134+ if (entry.is_regular_file () &&
135+ entry.path ().extension () != " .gz" ) {
136+ std::filesystem::path relative_path =
137+ std::filesystem::relative (entry.path (), engine_path);
138+ std::filesystem::path destFile =
139+ executable_path / relative_path;
140+
141+ std::filesystem::create_directories (destFile.parent_path ());
142+ std::filesystem::copy_file (
143+ entry.path (), destFile,
144+ std::filesystem::copy_options::overwrite_existing);
145+
146+ std::cout << " Copied: " << entry.path ().filename ().string ()
147+ << " to " << destFile.string () << std::endl;
148+ }
149+ }
150+ std::cout << " DLL copying completed successfully." << std::endl;
151+ #endif
119152
120153 // remove the downloaded file
121154 // TODO(any) Could not delete file on Windows because it is currently hold by httplib(?)
@@ -128,23 +161,47 @@ bool EngineInitCmd::Exec() const {
128161 LOG_INFO << " Finished!" ;
129162 });
130163 if (system_info.os == " mac" || engineName_ == " cortex.onnx" ) {
131- return false ;
164+ // mac and onnx engine does not require cuda toolkit
165+ return true ;
132166 }
167+
133168 // download cuda toolkit
134169 const std::string jan_host = " https://catalog.jan.ai" ;
135170 const std::string cuda_toolkit_file_name = " cuda.tar.gz" ;
136171 const std::string download_id = " cuda" ;
137172
138- auto gpu_driver_version = system_info_utils::GetDriverVersion ();
173+ // TODO: we don't have API to retrieve list of cuda toolkit dependencies atm because we hosting it at jan
174+ // will have better logic after https://github.com/janhq/cortex/issues/1046 finished
175+ // for now, assume that we have only 11.7 and 12.4
176+ auto suitable_toolkit_version = " " ;
177+ if (engineName_ == " cortex.tensorrt-llm" ) {
178+ // for tensorrt-llm, we need to download cuda toolkit v12.4
179+ suitable_toolkit_version = " 12.4" ;
180+ } else {
181+ // llamacpp
182+ auto cuda_driver_semver =
183+ semantic_version_utils::SplitVersion (cuda_driver_version);
184+ if (cuda_driver_semver.major == 11 ) {
185+ suitable_toolkit_version = " 11.7" ;
186+ } else if (cuda_driver_semver.major == 12 ) {
187+ suitable_toolkit_version = " 12.4" ;
188+ }
189+ }
139190
140- auto cuda_runtime_version =
141- cuda_toolkit_utils::GetCompatibleCudaToolkitVersion (
142- gpu_driver_version, system_info.os , engineName_);
191+ // compare cuda driver version with cuda toolkit version
192+ // cuda driver version should be greater than toolkit version to ensure compatibility
193+ if (semantic_version_utils::CompareSemanticVersion (
194+ cuda_driver_version, suitable_toolkit_version) < 0 ) {
195+ LOG_ERROR << " Your Cuda driver version " << cuda_driver_version
196+ << " is not compatible with cuda toolkit version "
197+ << suitable_toolkit_version;
198+ return false ;
199+ }
143200
144201 std::ostringstream cuda_toolkit_path;
145- cuda_toolkit_path << " dist/cuda-dependencies/" << 11.7 << " / "
146- << system_info. os << " /"
147- << cuda_toolkit_file_name;
202+ cuda_toolkit_path << " dist/cuda-dependencies/"
203+ << cuda_driver_version << " /" << system_info. os
204+ << " / " << cuda_toolkit_file_name;
148205
149206 LOG_DEBUG << " Cuda toolkit download url: " << jan_host
150207 << cuda_toolkit_path.str ();
0 commit comments