Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit fa72355

Browse files
authored
Merge pull request #1053 from janhq/fix/dynamically-get-cuda-dependency-version
fix: dynamically get cuda toolkit version
2 parents a5a30c2 + 65de91a commit fa72355

File tree

3 files changed

+139
-45
lines changed

3 files changed

+139
-45
lines changed

engine/commands/engine_init_cmd.cc

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
// clang-format on
1010
#include "utils/cuda_toolkit_utils.h"
1111
#include "utils/engine_matcher_utils.h"
12+
#if defined(_WIN32) || defined(__linux__)
13+
#include "utils/file_manager_utils.h"
14+
#endif
1215

1316
namespace commands {
1417

@@ -60,21 +63,22 @@ bool EngineInitCmd::Exec() const {
6063
variants.push_back(asset_name);
6164
}
6265

63-
auto cuda_version = system_info_utils::GetCudaVersion();
64-
LOG_INFO << "engineName_: " << engineName_;
65-
LOG_INFO << "CUDA version: " << cuda_version;
66-
std::string matched_variant = "";
66+
auto cuda_driver_version = system_info_utils::GetCudaVersion();
67+
LOG_INFO << "Engine: " << engineName_
68+
<< ", CUDA driver version: " << cuda_driver_version;
69+
70+
std::string matched_variant{""};
6771
if (engineName_ == "cortex.tensorrt-llm") {
6872
matched_variant = engine_matcher_utils::ValidateTensorrtLlm(
69-
variants, system_info.os, cuda_version);
73+
variants, system_info.os, cuda_driver_version);
7074
} else if (engineName_ == "cortex.onnx") {
7175
matched_variant = engine_matcher_utils::ValidateOnnx(
7276
variants, system_info.os, system_info.arch);
7377
} else if (engineName_ == "cortex.llamacpp") {
7478
auto suitable_avx = engine_matcher_utils::GetSuitableAvxVariant();
7579
matched_variant = engine_matcher_utils::Validate(
7680
variants, system_info.os, system_info.arch, suitable_avx,
77-
cuda_version);
81+
cuda_driver_version);
7882
}
7983
LOG_INFO << "Matched variant: " << matched_variant;
8084
if (matched_variant.empty()) {
@@ -105,17 +109,46 @@ bool EngineInitCmd::Exec() const {
105109
}}};
106110

107111
DownloadService download_service;
108-
download_service.AddDownloadTask(downloadTask, [](const std::string&
109-
absolute_path,
110-
bool unused) {
112+
download_service.AddDownloadTask(downloadTask, [this](
113+
const std::string&
114+
absolute_path,
115+
bool unused) {
111116
// try to unzip the downloaded file
112117
std::filesystem::path downloadedEnginePath{absolute_path};
113118
LOG_INFO << "Downloaded engine path: "
114119
<< downloadedEnginePath.string();
115120

116-
archive_utils::ExtractArchive(
117-
downloadedEnginePath.string(),
118-
downloadedEnginePath.parent_path().parent_path().string());
121+
std::filesystem::path extract_path =
122+
downloadedEnginePath.parent_path().parent_path();
123+
124+
archive_utils::ExtractArchive(downloadedEnginePath.string(),
125+
extract_path.string());
126+
#if defined(_WIN32) || defined(__linux__)
127+
// FIXME: hacky attempt to copy the files. Remove this once we are able to set the library path
128+
auto engine_path = extract_path / engineName_;
129+
LOG_INFO << "Source path: " << engine_path.string();
130+
auto executable_path =
131+
file_manager_utils::GetExecutableFolderContainerPath();
132+
for (const auto& entry :
133+
std::filesystem::recursive_directory_iterator(engine_path)) {
134+
if (entry.is_regular_file() &&
135+
entry.path().extension() != ".gz") {
136+
std::filesystem::path relative_path =
137+
std::filesystem::relative(entry.path(), engine_path);
138+
std::filesystem::path destFile =
139+
executable_path / relative_path;
140+
141+
std::filesystem::create_directories(destFile.parent_path());
142+
std::filesystem::copy_file(
143+
entry.path(), destFile,
144+
std::filesystem::copy_options::overwrite_existing);
145+
146+
std::cout << "Copied: " << entry.path().filename().string()
147+
<< " to " << destFile.string() << std::endl;
148+
}
149+
}
150+
std::cout << "DLL copying completed successfully." << std::endl;
151+
#endif
119152

120153
// remove the downloaded file
121154
// TODO(any) Could not delete file on Windows because it is currently held by httplib(?)
@@ -128,23 +161,47 @@ bool EngineInitCmd::Exec() const {
128161
LOG_INFO << "Finished!";
129162
});
130163
if (system_info.os == "mac" || engineName_ == "cortex.onnx") {
131-
return false;
164+
// mac and onnx engine does not require cuda toolkit
165+
return true;
132166
}
167+
133168
// download cuda toolkit
134169
const std::string jan_host = "https://catalog.jan.ai";
135170
const std::string cuda_toolkit_file_name = "cuda.tar.gz";
136171
const std::string download_id = "cuda";
137172

138-
auto gpu_driver_version = system_info_utils::GetDriverVersion();
173+
// TODO: we don't have an API to retrieve the list of cuda toolkit dependencies atm because we are hosting them at jan
174+
// will have better logic after https://github.com/janhq/cortex/issues/1046 finished
175+
// for now, assume that we have only 11.7 and 12.4
176+
auto suitable_toolkit_version = "";
177+
if (engineName_ == "cortex.tensorrt-llm") {
178+
// for tensorrt-llm, we need to download cuda toolkit v12.4
179+
suitable_toolkit_version = "12.4";
180+
} else {
181+
// llamacpp
182+
auto cuda_driver_semver =
183+
semantic_version_utils::SplitVersion(cuda_driver_version);
184+
if (cuda_driver_semver.major == 11) {
185+
suitable_toolkit_version = "11.7";
186+
} else if (cuda_driver_semver.major == 12) {
187+
suitable_toolkit_version = "12.4";
188+
}
189+
}
139190

140-
auto cuda_runtime_version =
141-
cuda_toolkit_utils::GetCompatibleCudaToolkitVersion(
142-
gpu_driver_version, system_info.os, engineName_);
191+
// compare cuda driver version with cuda toolkit version
192+
// cuda driver version must be greater than or equal to the toolkit version to ensure compatibility
193+
if (semantic_version_utils::CompareSemanticVersion(
194+
cuda_driver_version, suitable_toolkit_version) < 0) {
195+
LOG_ERROR << "Your Cuda driver version " << cuda_driver_version
196+
<< " is not compatible with cuda toolkit version "
197+
<< suitable_toolkit_version;
198+
return false;
199+
}
143200

144201
std::ostringstream cuda_toolkit_path;
145-
cuda_toolkit_path << "dist/cuda-dependencies/" << 11.7 << "/"
146-
<< system_info.os << "/"
147-
<< cuda_toolkit_file_name;
202+
cuda_toolkit_path << "dist/cuda-dependencies/"
203+
<< cuda_driver_version << "/" << system_info.os
204+
<< "/" << cuda_toolkit_file_name;
148205

149206
LOG_DEBUG << "Cuda toolkit download url: " << jan_host
150207
<< cuda_toolkit_path.str();

engine/utils/engine_matcher_utils.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
#include <trantor/utils/Logger.h>
12
#include <algorithm>
2-
#include <iostream>
33
#include <iterator>
44
#include <regex>
55
#include <string>
@@ -93,9 +93,19 @@ inline std::string GetSuitableCudaVariant(
9393
bestMatchMinor = variantMinor;
9494
}
9595
}
96-
} else if (cuda_version.empty() && selectedVariant.empty()) {
97-
// If no CUDA version is provided, select the variant without any CUDA in the name
98-
selectedVariant = variant;
96+
}
97+
}
98+
99+
// If no CUDA variant was selected above (e.g. no CUDA version is provided), fall back to the first variant without CUDA in the name
100+
if (selectedVariant.empty()) {
101+
LOG_WARN
102+
<< "No suitable CUDA variant found, selecting a variant without CUDA";
103+
for (const auto& variant : variants) {
104+
if (variant.find("cuda") == std::string::npos) {
105+
selectedVariant = variant;
106+
LOG_INFO << "Found variant without CUDA: " << selectedVariant << "\n";
107+
break;
108+
}
99109
}
100110
}
101111

@@ -177,4 +187,4 @@ inline std::string Validate(const std::vector<std::string>& variants,
177187

178188
return cuda_compatible;
179189
}
180-
} // namespace engine_matcher_utils
190+
} // namespace engine_matcher_utils
Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,61 @@
1+
#include <trantor/utils/Logger.h>
12
#include <sstream>
2-
#include <vector>
33

44
namespace semantic_version_utils {
5-
inline std::vector<int> SplitVersion(const std::string& version) {
6-
std::vector<int> parts;
7-
std::stringstream ss(version);
8-
std::string part;
5+
struct SemVer {
6+
int major;
7+
int minor;
8+
int patch;
9+
};
910

10-
while (std::getline(ss, part, '.')) {
11-
parts.push_back(std::stoi(part));
11+
inline SemVer SplitVersion(const std::string& version) {
12+
if (version.empty()) {
13+
LOG_WARN << "Passed in version is empty!";
1214
}
15+
SemVer semVer = {0, 0, 0}; // default value
16+
std::stringstream ss(version);
17+
std::string part;
1318

14-
while (parts.size() < 3) {
15-
parts.push_back(0);
19+
int index = 0;
20+
while (std::getline(ss, part, '.') && index < 3) {
21+
int value = std::stoi(part);
22+
switch (index) {
23+
case 0:
24+
semVer.major = value;
25+
break;
26+
case 1:
27+
semVer.minor = value;
28+
break;
29+
case 2:
30+
semVer.patch = value;
31+
break;
32+
}
33+
++index;
1634
}
1735

18-
return parts;
36+
return semVer;
1937
}
2038

2139
inline int CompareSemanticVersion(const std::string& version1,
2240
const std::string& version2) {
23-
std::vector<int> v1 = SplitVersion(version1);
24-
std::vector<int> v2 = SplitVersion(version2);
25-
26-
for (size_t i = 0; i < 3; ++i) {
27-
if (v1[i] < v2[i])
28-
return -1;
29-
if (v1[i] > v2[i])
30-
return 1;
31-
}
41+
SemVer v1 = SplitVersion(version1);
42+
SemVer v2 = SplitVersion(version2);
43+
44+
if (v1.major < v2.major)
45+
return -1;
46+
if (v1.major > v2.major)
47+
return 1;
48+
49+
if (v1.minor < v2.minor)
50+
return -1;
51+
if (v1.minor > v2.minor)
52+
return 1;
53+
54+
if (v1.patch < v2.patch)
55+
return -1;
56+
if (v1.patch > v2.patch)
57+
return 1;
58+
3259
return 0;
3360
}
34-
} // namespace semantic_version_utils
61+
} // namespace semantic_version_utils

0 commit comments

Comments
 (0)