|
3 | 3 | import os |
4 | 4 | from functools import partial |
5 | 5 | from pathlib import Path |
6 | | -from typing import Union |
| 6 | +from tempfile import TemporaryDirectory |
| 7 | +from typing import Optional, Union |
7 | 8 |
|
8 | 9 | import torch |
9 | 10 | from torch.hub import HASH_REGEX, download_url_to_file, urlparse |
| 11 | + |
10 | 12 | try: |
11 | 13 | from torch.hub import get_dir |
12 | 14 | except ImportError: |
|
15 | 17 | from timm import __version__ |
16 | 18 |
|
17 | 19 | try: |
18 | | - from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url |
| 20 | + from huggingface_hub import (create_repo, get_hf_file_metadata, |
| 21 | + hf_hub_download, hf_hub_url, |
| 22 | + repo_type_and_id_from_hf_id, upload_folder) |
| 23 | + from huggingface_hub.utils import EntryNotFoundError |
19 | 24 | hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) |
20 | 25 | _has_hf_hub = True |
21 | 26 | except ImportError: |
@@ -121,56 +126,45 @@ def save_for_hf(model, save_directory, model_config=None): |
121 | 126 |
|
122 | 127 | def push_to_hf_hub( |
123 | 128 | model, |
124 | | - local_dir, |
125 | | - repo_namespace_or_url=None, |
126 | | - commit_message='Add model', |
127 | | - use_auth_token=True, |
128 | | - git_email=None, |
129 | | - git_user=None, |
130 | | - revision=None, |
131 | | - model_config=None, |
| 129 | + repo_id: str, |
| 130 | + commit_message: str ='Add model', |
| 131 | + token: Optional[str] = None, |
| 132 | + revision: Optional[str] = None, |
| 133 | + private: bool = False, |
| 134 | + create_pr: bool = False, |
| 135 | + model_config: Optional[dict] = None, |
132 | 136 | ): |
133 | | - if isinstance(use_auth_token, str): |
134 | | - token = use_auth_token |
135 | | - else: |
136 | | - token = HfFolder.get_token() |
137 | | - if token is None: |
138 | | - raise ValueError( |
139 | | - "You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and " |
140 | | - "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " |
141 | | - "token as the `use_auth_token` argument." |
142 | | - ) |
143 | | - |
144 | | - if repo_namespace_or_url: |
145 | | - repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:] |
146 | | - else: |
147 | | - repo_owner = HfApi().whoami(token)['name'] |
148 | | - repo_name = Path(local_dir).name |
149 | | - |
150 | | - repo_id = f'{repo_owner}/{repo_name}' |
151 | | - repo_url = f'https://huggingface.co/{repo_id}' |
152 | | - |
153 | 137 | # Create repo if doesn't exist yet |
154 | | - HfApi().create_repo(repo_id, token=use_auth_token, exist_ok=True) |
155 | | - |
156 | | - repo = Repository( |
157 | | - local_dir, |
158 | | - clone_from=repo_url, |
159 | | - use_auth_token=use_auth_token, |
160 | | - git_user=git_user, |
161 | | - git_email=git_email, |
162 | | - revision=revision, |
163 | | - ) |
164 | | - |
165 | | - # Prepare a default model card that includes the necessary tags to enable inference. |
166 | | - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}' |
167 | | - with repo.commit(commit_message): |
| 138 | + repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) |
| 139 | + |
| 140 | + # Infer complete repo_id from repo_url |
| 141 | + # Can be different from the input `repo_id` if repo_owner was implicit |
| 142 | + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) |
| 143 | + repo_id = f"{repo_owner}/{repo_name}" |
| 144 | + |
| 145 | + # Check if README file already exist in repo |
| 146 | + try: |
| 147 | + get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) |
| 148 | + has_readme = True |
| 149 | + except EntryNotFoundError: |
| 150 | + has_readme = False |
| 151 | + |
| 152 | + # Dump model and push to Hub |
| 153 | + with TemporaryDirectory() as tmpdir: |
168 | 154 | # Save model weights and config. |
169 | | - save_for_hf(model, repo.local_dir, model_config=model_config) |
| 155 | + save_for_hf(model, tmpdir, model_config=model_config) |
170 | 156 |
|
171 | | - # Save a model card if it doesn't exist. |
172 | | - readme_path = Path(repo.local_dir) / 'README.md' |
173 | | - if not readme_path.exists(): |
| 157 | + # Add readme if does not exist |
| 158 | + if not has_readme: |
| 159 | + readme_path = Path(tmpdir) / "README.md" |
| 160 | + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}' |
174 | 161 | readme_path.write_text(readme_text) |
175 | 162 |
|
176 | | - return repo.git_remote_url() |
| 163 | + # Upload model and return |
| 164 | + return upload_folder( |
| 165 | + repo_id=repo_id, |
| 166 | + folder_path=tmpdir, |
| 167 | + revision=revision, |
| 168 | + create_pr=create_pr, |
| 169 | + commit_message=commit_message, |
| 170 | + ) |
0 commit comments