Skip to content

Commit 022e149

Browse files
committed
setup_tools refactor
1 parent 32fddd0 commit 022e149

File tree

1 file changed

+51
-27
lines changed

1 file changed

+51
-27
lines changed

.actions/setup_tools.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import glob
15-
import logging
1615
import os
1716
import pathlib
1817
import re
@@ -22,8 +21,11 @@
2221
import urllib.request
2322
from importlib.util import module_from_spec, spec_from_file_location
2423
from itertools import groupby
24+
from pathlib import Path
2525
from types import ModuleType
26-
from typing import List
26+
from typing import Any, Iterable, Iterator, List, Optional
27+
28+
import pkg_resources
2729

2830
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
2931
_PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"}
@@ -41,33 +43,56 @@ def _load_py_module(name: str, location: str) -> ModuleType:
4143
return py
4244

4345

44-
def load_requirements(
45-
path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True
46-
) -> List[str]:
46+
class _RequirementWithComment(pkg_resources.Requirement):
47+
def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
48+
super().__init__(*args, **kwargs)
49+
self.comment = comment
50+
assert pip_argument is None or pip_argument # sanity check that it's not an empty str
51+
self.pip_argument = pip_argument
52+
self.strict = "# strict" in comment.lower()
53+
54+
def clean_str(self, unfreeze: bool) -> str:
55+
# remove version restrictions unless they are strict
56+
return self.project_name if unfreeze and not self.strict else str(self)
57+
58+
59+
def _parse_requirements(strs: Iterable) -> Iterator[_RequirementWithComment]:
60+
"""Adapted from `pkg_resources.parse_requirements` to include comments."""
61+
lines = pkg_resources.yield_lines(strs)
62+
pip_argument = None
63+
for line in lines:
64+
# Drop comments -- a hash without a space may be in a URL.
65+
if " #" in line:
66+
comment_pos = line.find(" #")
67+
line, comment = line[:comment_pos], line[comment_pos:]
68+
else:
69+
comment = ""
70+
# If there is a line continuation, drop it, and append the next line.
71+
if line.endswith("\\"):
72+
line = line[:-2].strip()
73+
try:
74+
line += next(lines)
75+
except StopIteration:
76+
return
77+
# If there's a pip argument, save it
78+
if line.startswith("--"):
79+
pip_argument = line
80+
continue
81+
yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument)
82+
pip_argument = None
83+
84+
85+
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: bool = True) -> List[str]:
4786
"""Load requirements from a file.
4887
49-
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
88+
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements", "pytorch")
5089
>>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
5190
['numpy...', 'torch...', ...]
5291
"""
53-
with open(os.path.join(path_dir, file_name)) as file:
54-
lines = [ln.strip() for ln in file.readlines()]
55-
reqs = []
56-
for ln in lines:
57-
# filer all comments
58-
comment = ""
59-
if comment_char in ln:
60-
comment = ln[ln.index(comment_char) :]
61-
ln = ln[: ln.index(comment_char)]
62-
req = ln.strip()
63-
# skip directly installed dependencies
64-
if not req or req.startswith("http") or "@http" in req:
65-
continue
66-
# remove version restrictions unless they are strict
67-
if unfreeze and "<" in req and "strict" not in comment:
68-
req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip()
69-
reqs.append(req)
70-
return reqs
92+
path = Path(path_dir) / file_name
93+
assert path.exists(), (path_dir, file_name, path)
94+
text = path.read_text()
95+
return [req.clean_str(unfreeze) for req in _parse_requirements(text)]
7196

7297

7398
def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
@@ -294,9 +319,8 @@ class implementations by cross-imports to the true package.
294319
if fname in ("__about__.py", "__version__.py"):
295320
body = lines
296321
else:
297-
if fname.startswith("_") and fname not in ("__init__.py", "__main__.py"):
298-
logging.warning(f"unsupported file: {local_path}")
299-
continue
322+
if fname.startswith("_") and fname not in ("__init__.py", "__main__.py", "__setup__.py"):
323+
raise ValueError(f"Unsupported file: {fname}")
300324
# ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
301325
body = prune_comments_docstrings(lines)
302326
if fname not in ("__init__.py", "__main__.py"):

0 commit comments

Comments
 (0)