@@ -95,30 +95,79 @@ def stream_download(
9595 timeout : float | None = None ,
9696) -> None :
9797 tmp = savepath .with_suffix (savepath .suffix + ".tmp" )
98+ progress_printed = False
9899 try :
99100 with requests .get (url , stream = True , timeout = timeout ) as r :
100101 r .raise_for_status ()
102+ total = int (r .headers .get ("Content-Length" , 0 ) or 0 )
103+ chunk_size = 64 * 1024
104+ report_every = max (total // 100 , chunk_size ) if total else 1_000_000
105+ next_report = report_every
106+ downloaded = 0
107+ mb = 1_000_000
108+
101109 with open (tmp , "wb" ) as f :
102- for chunk in r .iter_content ():
103- if chunk :
104- f .write (chunk )
110+ for chunk in r .iter_content (chunk_size = chunk_size ):
111+ if not chunk :
112+ continue
113+ f .write (chunk )
114+ downloaded += len (chunk )
115+
116+ if total :
117+ if downloaded >= next_report or downloaded >= total :
118+ percent = min (int (downloaded * 100 / total ), 100 )
119+ current_mb = downloaded / mb
120+ total_mb = total / mb
121+ print (
122+ f" { percent :3d} % ({ current_mb :.1f} /{ total_mb :.1f} MB)" ,
123+ end = "\r " ,
124+ flush = True ,
125+ )
126+ progress_printed = True
127+ next_report = min (total , downloaded + report_every )
128+ else :
129+ if downloaded >= next_report :
130+ current_mb = downloaded / mb
131+ print (
132+ f" downloaded { current_mb :.1f} MB" ,
133+ end = "\r " ,
134+ flush = True ,
135+ )
136+ progress_printed = True
137+ next_report = downloaded + report_every
138+
139+ if total :
140+ total_mb = total / mb
141+ downloaded_mb = downloaded / mb
142+ print (
143+ f" 100% ({ downloaded_mb :.1f} /{ total_mb :.1f} MB)" ,
144+ flush = True ,
145+ )
146+ else :
147+ downloaded_mb = downloaded / mb
148+ print (f" downloaded { downloaded_mb :.1f} MB" , flush = True )
105149 tmp .replace (savepath )
106150 except Exception :
151+ if progress_printed :
152+ print ()
107153 if tmp .exists ():
108154 tmp .unlink (missing_ok = True )
109155 raise
110156
111157
112158def download_model (
113159 model : str ,
114- models_dir : Path | None = None ,
160+ models_dir : Path | str | None = None ,
115161 overwrite : bool = False ,
116162 timeout : float | None = None ,
117163) -> DownloadResult :
164+ if isinstance (models_dir , str ):
165+ models_dir = Path (models_dir )
118166 url , savepath = prepare_download (model , models_dir )
119167 existed = savepath .is_file ()
120168 if existed and not overwrite :
121169 return DownloadResult (model = model , url = url , dest = savepath , existed = True )
170+ print (f"downloading { model } to { savepath .resolve ()} " )
122171 stream_download (url , savepath , timeout = timeout )
123172 return DownloadResult (model = model , url = url , dest = savepath , existed = existed )
124173
0 commit comments