- 
                Notifications
    You must be signed in to change notification settings 
- Fork 7.2k
port special tests from CircleCI to GHA #7396
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1c8fdf9
              18fbc1c
              0daaf9c
              f800fa5
              19b7607
              7e94344
              5d26aad
              5d6f391
              933a78b
              4909c84
              85e0b08
              2b94a36
              3cafef2
              3a3b300
              e670854
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| #!/usr/bin/env bash | ||
|  | ||
| set -euo pipefail | ||
|  | ||
| ./.github/scripts/setup-env.sh | ||
|  | ||
| # Prepare conda | ||
| CONDA_PATH=$(which conda) | ||
| eval "$(${CONDA_PATH} shell.bash hook)" | ||
| conda activate ci | ||
|  | ||
| echo '::group::Install testing utilities' | ||
| pip install --progress-bar=off pytest pytest-mock pytest-cov | ||
| echo '::endgroup::' | ||
|  | ||
| echo '::group::Run unittests' | ||
| pytest --durations=25 | ||
| echo '::endgroup::' | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,41 @@ | ||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file is a implementation of Lines 171 to 189 in 5850f37 
 in Python. The old version relied on  There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One difference is that this PR uses async downloads, while the old version used multiprocessing. It seems async is roughly 5x slower: 
 I'll try multiprocessing and see if this actually is the root cause or this just comes from the environment change between CircleCI and GHA. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've tried multiprocessing with threads in 5d6f391. The run aborted to a  Thus, I would go with the async solution since that worked. I'm no expert in async / multiprocessing though. If someone sees possible perf improvements for either implementations, feel free to suggest. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've tried the solution with  Meaning, I'm totally fine using the async solution. | ||||||||||||||||||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||
| from time import perf_counter | ||||||||||||||||||||||||||||||||||||||||
| from urllib.parse import urlsplit | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| import aiofiles | ||||||||||||||||||||||||||||||||||||||||
| import aiohttp | ||||||||||||||||||||||||||||||||||||||||
| from torchvision import models | ||||||||||||||||||||||||||||||||||||||||
| from tqdm.asyncio import tqdm | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| async def main(download_root): | ||||||||||||||||||||||||||||||||||||||||
| download_root.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||||||||||||
| urls = {weight.url for name in models.list_models() for weight in iter(models.get_model_weights(name))} | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) as session: | ||||||||||||||||||||||||||||||||||||||||
| await tqdm.gather(*[download(download_root, session, url) for url in urls]) | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| async def download(download_root, session, url): | ||||||||||||||||||||||||||||||||||||||||
| response = await session.get(url, params=dict(source="ci")) | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| assert response.ok | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| file_name = Path(urlsplit(url).path).name | ||||||||||||||||||||||||||||||||||||||||
| async with aiofiles.open(download_root / file_name, "wb") as f: | ||||||||||||||||||||||||||||||||||||||||
| async for data in response.content.iter_any(): | ||||||||||||||||||||||||||||||||||||||||
| await f.write(data) | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||
| download_root = ( | ||||||||||||||||||||||||||||||||||||||||
| (Path(sys.argv[1]) if len(sys.argv) > 1 else Path("~/.cache/torch/hub/checkpoints")).expanduser().resolve() | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| print(f"Downloading model weights to {download_root}") | ||||||||||||||||||||||||||||||||||||||||
| start = perf_counter() | ||||||||||||||||||||||||||||||||||||||||
| asyncio.get_event_loop().run_until_complete(main(download_root)) | ||||||||||||||||||||||||||||||||||||||||
| stop = perf_counter() | ||||||||||||||||||||||||||||||||||||||||
| minutes, seconds = divmod(stop - start, 60) | ||||||||||||||||||||||||||||||||||||||||
| print(f"Download took {minutes:2.0f}m {seconds:2.0f}s") | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@osalpekar #7189 (comment) becomes even more relevant now. Without it, we need to repeat the top two lines everywhere. I'll get on it.