Skip to content

Commit b40b2ed

Browse files
authored
Improve requirements parser (#13912)
1 parent c2c363d commit b40b2ed

File tree

1 file changed

+97
-78
lines changed

1 file changed

+97
-78
lines changed

.actions/assistant.py

Lines changed: 97 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
from itertools import chain
2424
from os.path import dirname, isfile
2525
from pathlib import Path
26-
from typing import Dict, List, Optional, Sequence, Tuple
26+
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
2727

28-
from pkg_resources import parse_requirements
28+
from pkg_resources import parse_requirements, Requirement, yield_lines
2929

3030
REQUIREMENT_FILES = {
3131
"pytorch": (
@@ -49,86 +49,106 @@
4949
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
5050

5151

52-
def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: str = "all") -> str:
53-
"""Adjust the upper version contrains.
54-
55-
Args:
56-
ln: raw line from requirement
57-
comment_char: charter marking comment
58-
unfreeze: Enum or "all"|"major"|""
59-
60-
Returns:
61-
adjusted requirement
62-
63-
>>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # anything", unfreeze="none")
64-
'arrow<=1.2.2,>=1.2.0'
65-
>>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # strict", unfreeze="none")
66-
'arrow<=1.2.2,>=1.2.0 # strict'
67-
>>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # my name", unfreeze="all")
68-
'arrow>=1.2.0'
69-
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="all")
70-
'arrow>=1.2.0, <=1.2.2 # strict'
71-
>>> _augment_requirement("arrow", unfreeze="all")
72-
'arrow'
73-
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # cool", unfreeze="major")
74-
'arrow>=1.2.0, <2.0 # strict'
75-
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="major")
76-
'arrow>=1.2.0, <=1.2.2 # strict'
77-
>>> _augment_requirement("arrow>=1.2.0", unfreeze="major")
78-
'arrow>=1.2.0, <2.0 # strict'
79-
>>> _augment_requirement("arrow", unfreeze="major")
80-
'arrow'
52+
class _RequirementWithComment(Requirement):
53+
strict_string = "# strict"
54+
55+
def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None:
56+
super().__init__(*args, **kwargs)
57+
self.comment = comment
58+
assert pip_argument is None or pip_argument # sanity check that it's not an empty str
59+
self.pip_argument = pip_argument
60+
self.strict = self.strict_string in comment.lower()
61+
62+
def adjust(self, unfreeze: str) -> str:
63+
"""Remove version restrictions unless they are strict.
64+
65+
>>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none")
66+
'arrow<=1.2.2,>=1.2.0'
67+
>>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none")
68+
'arrow<=1.2.2,>=1.2.0 # strict'
69+
>>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all")
70+
'arrow>=1.2.0'
71+
>>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all")
72+
'arrow<=1.2.2,>=1.2.0 # strict'
73+
>>> _RequirementWithComment("arrow").adjust("all")
74+
'arrow'
75+
>>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major")
76+
'arrow<2.0,>=1.2.0'
77+
>>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major")
78+
'arrow<=1.2.2,>=1.2.0 # strict'
79+
>>> _RequirementWithComment("arrow>=1.2.0").adjust("major")
80+
'arrow>=1.2.0'
81+
>>> _RequirementWithComment("arrow").adjust("major")
82+
'arrow'
83+
"""
84+
out = str(self)
85+
if self.strict:
86+
return f"{out} {self.strict_string}"
87+
if unfreeze == "major":
88+
for operator, version in self.specs:
89+
if operator in ("<", "<="):
90+
major = LooseVersion(version).version[0]
91+
# replace upper bound with major version increased by one
92+
return out.replace(f"{operator}{version}", f"<{major + 1}.0")
93+
elif unfreeze == "all":
94+
for operator, version in self.specs:
95+
if operator in ("<", "<="):
96+
# drop upper bound
97+
return out.replace(f"{operator}{version},", "")
98+
elif unfreeze != "none":
99+
raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.")
100+
return out
101+
102+
103+
def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]:
104+
"""Adapted from `pkg_resources.parse_requirements` to include comments.
105+
106+
>>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt']
107+
>>> [r.adjust('none') for r in _parse_requirements(txt)]
108+
['this', 'example', 'foo # strict', 'thing']
109+
>>> txt = '\\n'.join(txt)
110+
>>> [r.adjust('none') for r in _parse_requirements(txt)]
111+
['this', 'example', 'foo # strict', 'thing']
81112
"""
82-
assert unfreeze in {"none", "major", "all"}
83-
# filer all comments
84-
if comment_char in ln:
85-
comment = ln[ln.index(comment_char) :]
86-
ln = ln[: ln.index(comment_char)]
87-
is_strict = "strict" in comment
88-
else:
89-
is_strict = False
90-
req = ln.strip()
91-
# skip directly installed dependencies
92-
if not req or any(c in req for c in ["http:", "https:", "@"]):
93-
return ""
94-
# extract the major version from all listed versions
95-
if unfreeze == "major":
96-
req_ = list(parse_requirements([req]))[0]
97-
vers = [LooseVersion(v) for s, v in req_.specs if s not in ("==", "~=")]
98-
ver_major = sorted(vers)[-1].version[0] if vers else None
99-
else:
100-
ver_major = None
101-
102-
# remove version restrictions unless they are strict
103-
if unfreeze != "none" and "<" in req and not is_strict:
104-
req = re.sub(r",? *<=? *[\d\.\*]+,? *", "", req).strip()
105-
if ver_major is not None and not is_strict:
106-
# add , only if there are already some versions
107-
req += f"{',' if any(c in req for c in '<=>') else ''} <{int(ver_major) + 1}.0"
108-
109-
# adding strict back to the comment
110-
if is_strict or ver_major is not None:
111-
req += " # strict"
112-
113-
return req
114-
115-
116-
def load_requirements(
117-
path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: str = "all"
118-
) -> List[str]:
113+
lines = yield_lines(strs)
114+
pip_argument = None
115+
for line in lines:
116+
# Drop comments -- a hash without a space may be in a URL.
117+
if " #" in line:
118+
comment_pos = line.find(" #")
119+
line, comment = line[:comment_pos], line[comment_pos:]
120+
else:
121+
comment = ""
122+
# If there is a line continuation, drop it, and append the next line.
123+
if line.endswith("\\"):
124+
line = line[:-2].strip()
125+
try:
126+
line += next(lines)
127+
except StopIteration:
128+
return
129+
# If there's a pip argument, save it
130+
if line.startswith("--"):
131+
pip_argument = line
132+
continue
133+
if line.startswith("-r "):
134+
# linked requirement files are unsupported
135+
continue
136+
yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument)
137+
pip_argument = None
138+
139+
140+
def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]:
119141
"""Loading requirements from a file.
120142
121143
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
122144
>>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
123-
['sphinx>=4.0, <6.0 # strict', ...]
145+
['sphinx<6.0,>=4.0', ...]
124146
"""
125147
assert unfreeze in {"none", "major", "all"}
126-
with open(os.path.join(path_dir, file_name)) as file:
127-
lines = [ln.strip() for ln in file.readlines()]
128-
reqs = [_augment_requirement(ln, comment_char=comment_char, unfreeze=unfreeze) for ln in lines]
129-
# filter empty lines and containing @ which means redirect to some git/http
130-
reqs = [str(req) for req in reqs if req and not any(c in req for c in ["@", "http:", "https:"])]
131-
return reqs
148+
path = Path(path_dir) / file_name
149+
assert path.exists(), (path_dir, file_name, path)
150+
text = path.read_text()
151+
return [req.adjust(unfreeze) for req in _parse_requirements(text)]
132152

133153

134154
def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
@@ -213,14 +233,13 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
213233
>>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))
214234
"""
215235
requires = [
216-
# TODO: consider passing unfreeze as string instead
217-
load_requirements(d, file_name="base.txt", unfreeze="none" if freeze_requirements else "major")
236+
load_requirements(d, unfreeze="none" if freeze_requirements else "major")
218237
for d in glob.glob(os.path.join(req_dir, "*"))
219238
# skip empty folder as git artefacts, and resolving Will's special issue
220239
if os.path.isdir(d) and len(glob.glob(os.path.join(d, "*"))) > 0 and "__pycache__" not in d
221240
]
222241
if not requires:
223-
return None
242+
return
224243
# TODO: add some smarter version aggregation per each package
225244
requires = sorted(set(chain(*requires)))
226245
with open(os.path.join(req_dir, "base.txt"), "w") as fp:

0 commit comments

Comments
 (0)