2121import tempfile
2222import urllib .request
2323from datetime import datetime
24+ from distutils .version import LooseVersion
2425from importlib .util import module_from_spec , spec_from_file_location
2526from itertools import chain , groupby
2627from types import ModuleType
2728from typing import List
2829
30+ from pkg_resources import parse_requirements
31+
2932_PROJECT_ROOT = os .path .dirname (os .path .dirname (__file__ ))
3033_PACKAGE_MAPPING = {"pytorch" : "pytorch_lightning" , "app" : "lightning_app" }
3134
@@ -42,45 +45,92 @@ def _load_py_module(name: str, location: str) -> ModuleType:
4245 return py
4346
4447
48+ def _augment_requirement (ln : str , comment_char : str = "#" , unfreeze : str = "all" ) -> str :
49+ """Adjust the upper version contrains.
50+
51+ Args:
52+ ln: raw line from requirement
53+ comment_char: charter marking comment
54+ unfreeze: Enum or "all"|"major"|""
55+
56+ Returns:
57+ adjusted requirement
58+
59+ >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # anything", unfreeze="")
60+ 'arrow>=1.2.0, <=1.2.2'
61+ >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="")
62+ 'arrow>=1.2.0, <=1.2.2 # strict'
63+ >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # my name", unfreeze="all")
64+ 'arrow>=1.2.0'
65+ >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="all")
66+ 'arrow>=1.2.0, <=1.2.2 # strict'
67+ >>> _augment_requirement("arrow", unfreeze="all")
68+ 'arrow'
69+ >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # cool", unfreeze="major")
70+ 'arrow>=1.2.0, <2.0 # strict'
71+ >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="major")
72+ 'arrow>=1.2.0, <=1.2.2 # strict'
73+ >>> _augment_requirement("arrow>=1.2.0", unfreeze="major")
74+ 'arrow>=1.2.0, <2.0 # strict'
75+ >>> _augment_requirement("arrow", unfreeze="major")
76+ 'arrow'
77+ """
78+ # filer all comments
79+ if comment_char in ln :
80+ comment = ln [ln .index (comment_char ) :]
81+ ln = ln [: ln .index (comment_char )]
82+ is_strict = "strict" in comment
83+ else :
84+ is_strict = False
85+ req = ln .strip ()
86+ # skip directly installed dependencies
87+ if not req or req .startswith ("http" ) or "@http" in req :
88+ return ""
89+ # extract the major version from all listed versions
90+ if unfreeze == "major" :
91+ req_ = list (parse_requirements ([req ]))[0 ]
92+ vers = [LooseVersion (v ) for s , v in req_ .specs if s not in ("==" , "~=" )]
93+ ver_major = sorted (vers )[- 1 ].version [0 ] if vers else None
94+ else :
95+ ver_major = None
96+
97+ # remove version restrictions unless they are strict
98+ if unfreeze and "<" in req and not is_strict :
99+ req = re .sub (r",? *<=? *[\d\.\*]+" , "" , req ).strip ()
100+ if ver_major is not None and not is_strict :
101+ # add , only if there are already some versions
102+ req += f"{ ',' if any (c in req for c in '<=>' ) else '' } <{ int (ver_major ) + 1 } .0"
103+
104+ # adding strict back to the comment
105+ if is_strict or ver_major is not None :
106+ req += " # strict"
107+
108+ return req
109+
110+
45111def load_requirements (
46- path_dir : str , file_name : str = "base.txt" , comment_char : str = "#" , unfreeze : bool = True
112+ path_dir : str , file_name : str = "base.txt" , comment_char : str = "#" , unfreeze : str = "all"
47113) -> List [str ]:
48114 """Loading requirements from a file.
49115
50116 >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
51- >>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
52- ['numpy ...', 'torch ...', ... ]
117+ >>> load_requirements(path_req, unfreeze="major" ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
118+ ['pytorch_lightning ...', 'lightning_app ...']
53119 """
54120 with open (os .path .join (path_dir , file_name )) as file :
55121 lines = [ln .strip () for ln in file .readlines ()]
56122 reqs = []
57123 for ln in lines :
58- # filer all comments
59- comment = ""
60- if comment_char in ln :
61- comment = ln [ln .index (comment_char ) :]
62- ln = ln [: ln .index (comment_char )]
63- req = ln .strip ()
64- # skip directly installed dependencies
65- if not req or req .startswith ("http" ) or "@http" in req :
66- continue
67- # remove version restrictions unless they are strict
68- if unfreeze and "<" in req and "strict" not in comment :
69- req = re .sub (r",? *<=? *[\d\.\*]+" , "" , req ).strip ()
70-
71- # adding strict back to the comment
72- if "strict" in comment :
73- req += " # strict"
74-
75- reqs .append (req )
76- return reqs
124+ reqs .append (_augment_requirement (ln , comment_char = comment_char , unfreeze = unfreeze ))
125+ # filter empty lines
126+ return [str (req ) for req in reqs if req ]
77127
78128
79129def load_readme_description (path_dir : str , homepage : str , version : str ) -> str :
80130 """Load readme as decribtion.
81131
82132 >>> load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
83- '<div align="center"> ...'
133+ '...'
84134 """
85135 path_readme = os .path .join (path_dir , "README.md" )
86136 text = open (path_readme , encoding = "utf-8" ).read ()
@@ -439,12 +489,14 @@ def _download_frontend(root: str = _PROJECT_ROOT):
439489 print ("The Lightning UI downloading has failed!" )
440490
441491
442- def _adjust_require_versions (source_dir : str = "src" , req_dir : str = "requirements" ) -> None :
443- """Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`."""
492+ def _relax_require_versions (source_dir : str = "src" , req_dir : str = "requirements" ) -> None :
493+ """Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`.
494+
495+ >>> _relax_require_versions("../src", "../requirements")
496+ """
444497 reqs = load_requirements (req_dir , file_name = "base.txt" )
445- for i , req in enumerate (reqs ):
446- pkg_name = req [: min (req .index (c ) for c in ">=" if c in req )]
447- ver_ = parse_version_from_file (os .path .join (source_dir , pkg_name ))
498+ for i , req in enumerate (parse_requirements (reqs )):
499+ ver_ = parse_version_from_file (os .path .join (source_dir , req .name ))
448500 if not ver_ :
449501 continue
450502 ver2 = "." .join (ver_ .split ("." )[:2 ] + ["*" ])
0 commit comments