|  | 
| 6 | 6 | import pathlib | 
| 7 | 7 | import subprocess | 
| 8 | 8 | from concurrent.futures import Future, ThreadPoolExecutor | 
| 9 |  | -from typing import IO, ClassVar | 
|  | 9 | +from typing import ClassVar | 
| 10 | 10 | 
 | 
| 11 | 11 | import requests | 
| 12 | 12 | import requests.adapters | 
|  | 
| 16 | 16 | 
 | 
| 17 | 17 | from zimscraperlib import logger | 
| 18 | 18 | from zimscraperlib.constants import DEFAULT_WEB_REQUESTS_TIMEOUT | 
|  | 19 | +from zimscraperlib.typing import SupportsSeekableWrite, SupportsWrite | 
| 19 | 20 | 
 | 
| 20 | 21 | 
 | 
| 21 | 22 | class YoutubeDownloader: | 
| @@ -59,11 +60,10 @@ def download( | 
| 59 | 60 |         future = self.executor.submit(self._run_youtube_dl, url, options or {}) | 
| 60 | 61 |         if not wait: | 
| 61 | 62 |             return future | 
| 62 |  | -        if not future.exception(): | 
| 63 |  | -            # return the result | 
| 64 |  | -            return future.result()  # pyright: ignore | 
| 65 |  | -        # raise the exception | 
| 66 |  | -        raise future.exception()  # pyright: ignore | 
|  | 63 | +        exc = future.exception() | 
|  | 64 | +        if isinstance(exc, BaseException): | 
|  | 65 | +            raise exc | 
|  | 66 | +        return True | 
| 67 | 67 | 
 | 
| 68 | 68 | 
 | 
| 69 | 69 | class YoutubeConfig(dict): | 
| @@ -176,7 +176,7 @@ def get_session(max_retries: int | None = 5) -> requests.Session: | 
| 176 | 176 | def stream_file( | 
| 177 | 177 |     url: str, | 
| 178 | 178 |     fpath: pathlib.Path | None = None, | 
| 179 |  | -    byte_stream: IO[bytes] | None = None, | 
|  | 179 | +    byte_stream: SupportsWrite[bytes] | SupportsSeekableWrite[bytes] | None = None, | 
| 180 | 180 |     block_size: int | None = 1024, | 
| 181 | 181 |     proxies: dict[str, str] | None = None, | 
| 182 | 182 |     max_retries: int | None = 5, | 
| @@ -216,24 +216,25 @@ def stream_file( | 
| 216 | 216 | 
 | 
| 217 | 217 |     total_downloaded = 0 | 
| 218 | 218 |     if fpath is not None: | 
| 219 |  | -        fp = open(fpath, "wb") | 
| 220 |  | -    elif ( | 
| 221 |  | -        byte_stream is not None | 
| 222 |  | -    ):  # pragma: no branch (we use a precise condition to help type checker) | 
| 223 |  | -        fp = byte_stream | 
|  | 219 | +        fpath_handler = open(fpath, "wb") | 
|  | 220 | +    else: | 
|  | 221 | +        fpath_handler = None | 
| 224 | 222 | 
 | 
| 225 | 223 |     for data in resp.iter_content(block_size): | 
| 226 | 224 |         total_downloaded += len(data) | 
| 227 |  | -        fp.write(data) | 
|  | 225 | +        if fpath_handler: | 
|  | 226 | +            fpath_handler.write(data) | 
|  | 227 | +        if byte_stream: | 
|  | 228 | +            byte_stream.write(data) | 
| 228 | 229 | 
 | 
| 229 | 230 |         # stop downloading/reading if we're just testing first block | 
| 230 | 231 |         if only_first_block: | 
| 231 | 232 |             break | 
| 232 | 233 | 
 | 
| 233 | 234 |     logger.debug(f"Downloaded {total_downloaded} bytes from {url}") | 
| 234 | 235 | 
 | 
| 235 |  | -    if fpath: | 
| 236 |  | -        fp.close() | 
| 237 |  | -    else: | 
| 238 |  | -        fp.seek(0) | 
|  | 236 | +    if fpath_handler: | 
|  | 237 | +        fpath_handler.close() | 
|  | 238 | +    elif isinstance(byte_stream, SupportsSeekableWrite) and byte_stream.seekable(): | 
|  | 239 | +        byte_stream.seek(0) | 
| 239 | 240 |     return total_downloaded, resp.headers | 
0 commit comments