1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import glob
15- import logging
1615import os
1716import pathlib
1817import re
2221import urllib .request
2322from importlib .util import module_from_spec , spec_from_file_location
2423from itertools import groupby
24+ from pathlib import Path
2525from types import ModuleType
26- from typing import List
26+ from typing import Any , Iterable , Iterator , List , Optional
27+
28+ import pkg_resources
2729
2830_PROJECT_ROOT = os .path .dirname (os .path .dirname (__file__ ))
2931_PACKAGE_MAPPING = {"pytorch" : "pytorch_lightning" , "app" : "lightning_app" }
@@ -41,33 +43,56 @@ def _load_py_module(name: str, location: str) -> ModuleType:
4143 return py
4244
4345
44- def load_requirements (
45- path_dir : str , file_name : str = "base.txt" , comment_char : str = "#" , unfreeze : bool = True
46- ) -> List [str ]:
46+ class _RequirementWithComment (pkg_resources .Requirement ):
47+ def __init__ (self , * args : Any , comment : str = "" , pip_argument : Optional [str ] = None , ** kwargs : Any ) -> None :
48+ super ().__init__ (* args , ** kwargs )
49+ self .comment = comment
50+ assert pip_argument is None or pip_argument # sanity check that it's not an empty str
51+ self .pip_argument = pip_argument
52+ self .strict = "# strict" in comment .lower ()
53+
54+ def clean_str (self , unfreeze : bool ) -> str :
55+ # remove version restrictions unless they are strict
56+ return self .project_name if unfreeze and not self .strict else str (self )
57+
58+
59+ def _parse_requirements (strs : Iterable ) -> Iterator [_RequirementWithComment ]:
60+ """Adapted from `pkg_resources.parse_requirements` to include comments."""
61+ lines = pkg_resources .yield_lines (strs )
62+ pip_argument = None
63+ for line in lines :
64+ # Drop comments -- a hash without a space may be in a URL.
65+ if " #" in line :
66+ comment_pos = line .find (" #" )
67+ line , comment = line [:comment_pos ], line [comment_pos :]
68+ else :
69+ comment = ""
70+ # If there is a line continuation, drop it, and append the next line.
71+ if line .endswith ("\\ " ):
72+ line = line [:- 2 ].strip ()
73+ try :
74+ line += next (lines )
75+ except StopIteration :
76+ return
77+ # If there's a pip argument, save it
78+ if line .startswith ("--" ):
79+ pip_argument = line
80+ continue
81+ yield _RequirementWithComment (line , comment = comment , pip_argument = pip_argument )
82+ pip_argument = None
83+
84+
85+ def load_requirements (path_dir : str , file_name : str = "base.txt" , unfreeze : bool = True ) -> List [str ]:
4786 """Load requirements from a file.
4887
49- >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
88+ >>> path_req = os.path.join(_PROJECT_ROOT, "requirements", "pytorch" )
5089 >>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
5190 ['numpy...', 'torch...', ...]
5291 """
53- with open (os .path .join (path_dir , file_name )) as file :
54- lines = [ln .strip () for ln in file .readlines ()]
55- reqs = []
56- for ln in lines :
57- # filer all comments
58- comment = ""
59- if comment_char in ln :
60- comment = ln [ln .index (comment_char ) :]
61- ln = ln [: ln .index (comment_char )]
62- req = ln .strip ()
63- # skip directly installed dependencies
64- if not req or req .startswith ("http" ) or "@http" in req :
65- continue
66- # remove version restrictions unless they are strict
67- if unfreeze and "<" in req and "strict" not in comment :
68- req = re .sub (r",? *<=? *[\d\.\*]+" , "" , req ).strip ()
69- reqs .append (req )
70- return reqs
92+ path = Path (path_dir ) / file_name
93+ assert path .exists (), (path_dir , file_name , path )
94+ text = path .read_text ()
95+ return [req .clean_str (unfreeze ) for req in _parse_requirements (text )]
7196
7297
7398def load_readme_description (path_dir : str , homepage : str , version : str ) -> str :
@@ -294,9 +319,8 @@ class implementations by cross-imports to the true package.
294319 if fname in ("__about__.py" , "__version__.py" ):
295320 body = lines
296321 else :
297- if fname .startswith ("_" ) and fname not in ("__init__.py" , "__main__.py" ):
298- logging .warning (f"unsupported file: { local_path } " )
299- continue
322+ if fname .startswith ("_" ) and fname not in ("__init__.py" , "__main__.py" , "__setup__.py" ):
323+ raise ValueError (f"Unsupported file: { fname } " )
300324 # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
301325 body = prune_comments_docstrings (lines )
302326 if fname not in ("__init__.py" , "__main__.py" ):
0 commit comments