44#include < fstream>
55#include < functional>
66#include < mutex>
7+ #include < condition_variable>
78#include < regex>
89#include < set>
910#include < string>
@@ -2036,6 +2037,8 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
20362037 size_t total_tensors_processed = 0 ;
20372038 const size_t total_tensors_to_process = processed_tensor_storages.size ();
20382039 const int64_t t_start = ggml_time_ms ();
2040+ std::mutex mtx;
2041+ std::condition_variable cv;
20392042
20402043 for (size_t file_index = 0 ; file_index < file_paths_.size (); file_index++) {
20412044 std::string file_path = file_paths_[file_index];
@@ -2065,37 +2068,60 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
20652068 n_threads = 1 ;
20662069 }
20672070
2068- std::atomic<size_t > tensor_idx (0 );
2069- std::atomic<bool > failed (false );
2071+ std::ifstream single_file;
2072+ bool is_single_file (false );
2073+
2074+ struct zip_t * zip = NULL ;
2075+ if (is_zip) {
2076+ zip = zip_open (file_path.c_str (), 0 , ' r' );
2077+ if (zip == NULL ) {
2078+ LOG_ERROR (" failed to open zip '%s'" , file_path.c_str ());
2079+ success = false ;
2080+ break ;
2081+ }
2082+ } else {
2083+ const char * load_single_file = getenv (" SD_LOAD_MODEL_SINGLEFILE" );
2084+ if (load_single_file && *load_single_file == ' 1' ) {
2085+ single_file.open (file_path, std::ios::binary);
2086+ if (!single_file.is_open ()) {
2087+ LOG_ERROR (" failed to open '%s'" , file_path.c_str ());
2088+ success = false ;
2089+ break ;
2090+ }
2091+ }
2092+ }
2093+
2094+ size_t tensor_idx (0 );
2095+ bool loading_failed (false );
20702096 std::vector<std::thread> workers;
20712097
20722098 for (int i = 0 ; i < n_threads; ++i) {
20732099 workers.emplace_back ([&, file_path, is_zip]() {
20742100 std::ifstream file;
2075- struct zip_t * zip = NULL ;
2076- if (is_zip) {
2077- zip = zip_open (file_path.c_str (), 0 , ' r' );
2078- if (zip == NULL ) {
2079- LOG_ERROR (" failed to open zip '%s'" , file_path.c_str ());
2080- failed = true ;
2081- return ;
2082- }
2083- } else {
2101+ bool failed (false );
2102+
2103+ if (!is_zip && !is_single_file) {
20842104 file.open (file_path, std::ios::binary);
20852105 if (!file.is_open ()) {
20862106 LOG_ERROR (" failed to open '%s'" , file_path.c_str ());
20872107 failed = true ;
2088- return ;
20892108 }
20902109 }
20912110
20922111 std::vector<uint8_t > read_buffer;
20932112 std::vector<uint8_t > convert_buffer;
20942113
20952114 while (true ) {
2096- size_t idx = tensor_idx.fetch_add (1 );
2097- if (idx >= file_tensors.size () || failed) {
2098- break ;
2115+ size_t idx;
2116+
2117+ {
2118+ std::lock_guard<std::mutex> lock (mtx);
2119+ idx = tensor_idx++;
2120+ loading_failed = loading_failed || failed;
2121+ if (idx >= file_tensors.size () || loading_failed) {
2122+ cv.notify_one ();
2123+ break ;
2124+ }
20992125 }
21002126
21012127 const TensorStorage& tensor_storage = *file_tensors[idx];
@@ -2104,7 +2130,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
21042130 if (!on_new_tensor_cb (tensor_storage, &dst_tensor)) {
21052131 LOG_WARN (" process tensor failed: '%s'" , tensor_storage.name .c_str ());
21062132 failed = true ;
2107- break ;
2133+ continue ;
21082134 }
21092135
21102136 if (dst_tensor == NULL ) {
@@ -2115,6 +2141,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
21152141
21162142 auto read_data = [&](char * buf, size_t n) {
21172143 if (zip != NULL ) {
2144+ std::lock_guard<std::mutex> lock (mtx);
21182145 zip_entry_openbyindex (zip, tensor_storage.index_in_zip );
21192146 size_t entry_size = zip_entry_size (zip);
21202147 if (entry_size != n) {
@@ -2134,6 +2161,17 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
21342161 read_time_ms += curr_time_ms - prev_time_ms;
21352162 }
21362163 zip_entry_close (zip);
2164+ } else if (is_single_file) {
2165+ std::lock_guard<std::mutex> lock (mtx);
2166+ prev_time_ms = ggml_time_ms ();
2167+ single_file.seekg (tensor_storage.offset );
2168+ single_file.read (buf, n);
2169+ curr_time_ms = ggml_time_ms ();
2170+ read_time_ms += curr_time_ms - prev_time_ms;
2171+ if (!single_file) {
2172+ LOG_ERROR (" read tensor data failed: '%s'" , file_path.c_str ());
2173+ failed = true ;
2174+ }
21372175 } else {
21382176 prev_time_ms = ggml_time_ms ();
21392177 file.seekg (tensor_storage.offset );
@@ -2245,27 +2283,32 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread
22452283 copy_to_backend_time_ms += curr_time_ms - prev_time_ms;
22462284 }
22472285 }
2286+
22482287 }
2249- if (zip != NULL ) {
2250- zip_close (zip);
2251- }
2288+
22522289 });
22532290 }
22542291
2255- while (true ) {
2256- size_t current_idx = tensor_idx.load ();
2257- if (current_idx >= file_tensors.size () || failed) {
2258- break ;
2292+ {
2293+ std::unique_lock<std::mutex> lock (mtx);
2294+ while (true ) {
2295+ if (tensor_idx >= file_tensors.size () || loading_failed) {
2296+ break ;
2297+ }
2298+ pretty_progress (total_tensors_processed + tensor_idx, total_tensors_to_process, (ggml_time_ms () - t_start) / 1000 .0f );
2299+ cv.wait_for (lock, std::chrono::milliseconds (200 ));
22592300 }
2260- pretty_progress (total_tensors_processed + current_idx, total_tensors_to_process, (ggml_time_ms () - t_start) / 1000 .0f );
2261- std::this_thread::sleep_for (std::chrono::milliseconds (200 ));
22622301 }
22632302
22642303 for (auto & w : workers) {
22652304 w.join ();
22662305 }
22672306
2268- if (failed) {
2307+ if (zip != NULL ) {
2308+ zip_close (zip);
2309+ }
2310+
2311+ if (loading_failed) {
22692312 success = false ;
22702313 break ;
22712314 }
0 commit comments