Skip to content

Commit 111007f

Browse files
committed
download progress
1 parent 457a343 commit 111007f

File tree

1 file changed

+53
-4
lines changed

1 file changed

+53
-4
lines changed

whispercpppy/model.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,30 +95,79 @@ def stream_download(
9595
timeout: float | None = None,
9696
) -> None:
9797
tmp = savepath.with_suffix(savepath.suffix + ".tmp")
98+
progress_printed = False
9899
try:
99100
with requests.get(url, stream=True, timeout=timeout) as r:
100101
r.raise_for_status()
102+
total = int(r.headers.get("Content-Length", 0) or 0)
103+
chunk_size = 64 * 1024
104+
report_every = max(total // 100, chunk_size) if total else 1_000_000
105+
next_report = report_every
106+
downloaded = 0
107+
mb = 1_000_000
108+
101109
with open(tmp, "wb") as f:
102-
for chunk in r.iter_content():
103-
if chunk:
104-
f.write(chunk)
110+
for chunk in r.iter_content(chunk_size=chunk_size):
111+
if not chunk:
112+
continue
113+
f.write(chunk)
114+
downloaded += len(chunk)
115+
116+
if total:
117+
if downloaded >= next_report or downloaded >= total:
118+
percent = min(int(downloaded * 100 / total), 100)
119+
current_mb = downloaded / mb
120+
total_mb = total / mb
121+
print(
122+
f" {percent:3d}% ({current_mb:.1f}/{total_mb:.1f} MB)",
123+
end="\r",
124+
flush=True,
125+
)
126+
progress_printed = True
127+
next_report = min(total, downloaded + report_every)
128+
else:
129+
if downloaded >= next_report:
130+
current_mb = downloaded / mb
131+
print(
132+
f" downloaded {current_mb:.1f} MB",
133+
end="\r",
134+
flush=True,
135+
)
136+
progress_printed = True
137+
next_report = downloaded + report_every
138+
139+
if total:
140+
total_mb = total / mb
141+
downloaded_mb = downloaded / mb
142+
print(
143+
f" 100% ({downloaded_mb:.1f}/{total_mb:.1f} MB)",
144+
flush=True,
145+
)
146+
else:
147+
downloaded_mb = downloaded / mb
148+
print(f" downloaded {downloaded_mb:.1f} MB", flush=True)
105149
tmp.replace(savepath)
106150
except Exception:
151+
if progress_printed:
152+
print()
107153
if tmp.exists():
108154
tmp.unlink(missing_ok=True)
109155
raise
110156

111157

112158
def download_model(
113159
model: str,
114-
models_dir: Path | None = None,
160+
models_dir: Path | str | None = None,
115161
overwrite: bool = False,
116162
timeout: float | None = None,
117163
) -> DownloadResult:
164+
if isinstance(models_dir, str):
165+
models_dir = Path(models_dir)
118166
url, savepath = prepare_download(model, models_dir)
119167
existed = savepath.is_file()
120168
if existed and not overwrite:
121169
return DownloadResult(model=model, url=url, dest=savepath, existed=True)
170+
print(f"downloading {model} to {savepath.resolve()}")
122171
stream_download(url, savepath, timeout=timeout)
123172
return DownloadResult(model=model, url=url, dest=savepath, existed=existed)
124173

0 commit comments

Comments
 (0)