1414import glob
1515import logging
1616import os
17+ import pathlib
1718import re
19+ import shutil
20+ import tarfile
21+ import tempfile
22+ import urllib .request
23+ from datetime import datetime
1824from importlib .util import module_from_spec , spec_from_file_location
19- from itertools import groupby
25+ from itertools import chain , groupby
2026from types import ModuleType
2127from typing import List
2228
2329_PROJECT_ROOT = os .path .dirname (os .path .dirname (__file__ ))
2430_PACKAGE_MAPPING = {"pytorch" : "pytorch_lightning" , "app" : "lightning_app" }
2531
32+ # TODO: remove this once lightning-ui package is ready as a dependency
33+ _LIGHTNING_FRONTEND_RELEASE_URL = "https://storage.googleapis.com/grid-packages/lightning-ui/v0.0.0/build.tar.gz"
34+
2635
2736def _load_py_module (name : str , location : str ) -> ModuleType :
2837 spec = spec_from_file_location (name , location )
@@ -36,7 +45,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
3645def load_requirements (
3746 path_dir : str , file_name : str = "base.txt" , comment_char : str = "#" , unfreeze : bool = True
3847) -> List [str ]:
39- """Load requirements from a file.
48+ """Loading requirements from a file.
4049
4150 >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
4251 >>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -142,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
142151 ... lines = [ln.rstrip() for ln in fp.readlines()]
143152 >>> lines = replace_vars_with_imports(lines, import_path)
144153 """
154+ copied = []
145155 body , tracking , skip_offset = [], False , 0
146156 for ln in lines :
147157 offset = len (ln ) - len (ln .lstrip ())
@@ -152,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
152162 if var :
153163 name = var .groups ()[0 ]
154164 # skip private or apply white-list for allowed vars
155- if not name .startswith ("__" ) or name in ("__all__" ,):
165+ if name not in copied and ( not name .startswith ("__" ) or name in ("__all__" ,) ):
156166 body .append (f"{ ' ' * offset } from { import_path } import { name } # noqa: F401" )
167+ copied .append (name )
157168 tracking , skip_offset = True , offset
158169 continue
159170 if not tracking :
@@ -188,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
188199 return body
189200
190201
def prune_func_calls(lines: List[str]) -> List[str]:
    """Prune calling functions from a file, even multi-line.

    A call is recognized as a dotted name (optionally prefixed with ``@``)
    immediately followed by an opening parenthesis; import statements are
    exempt.  Parenthesis balance is tracked so calls spanning several lines
    are dropped whole.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_func_calls(lines)
    """
    kept: List[str] = []
    in_call, depth = False, 0
    for line in lines:
        # catching callable: ``name(`` or ``@name(`` at the start of the statement
        is_call = re.match(r"^@?[\w_\d\.]+ *\(", line.lstrip())
        if is_call and " import " not in line:
            in_call, depth = True, 0
        if not in_call:
            kept.append(line)
        else:
            # keep consuming lines until the parentheses balance out
            depth += line.count("(") - line.count(")")
            if depth == 0:
                in_call = False
    return kept
225+
226+
191227def prune_empty_statements (lines : List [str ]) -> List [str ]:
192228 """Prune emprty if/else and try/except.
193229
@@ -262,6 +298,46 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
262298 return body
263299
264300
def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
    """Wrap the file body in try/except for better traceability of import misalignment.

    On ``ImportError`` the generated code re-raises with a hint naming the
    package/version this meta-package was built against.  A body with no
    non-blank lines is returned untouched.
    """
    if not any(ln for ln in body):
        return body
    wrapped = ["try:"]
    # indent the original body one level; blank lines stay blank
    wrapped += [f"    {ln}" if ln else "" for ln in body]
    wrapped += [
        "",
        "except ImportError as err:",
        "",
        "    from os import linesep",
        f"    from {pkg} import __version__",
        f"    msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
        "    raise type(err)(str(err) + linesep + msg)",
    ]
    return wrapped
317+
318+
def parse_version_from_file(pkg_root: str) -> str:
    """Load the package version from ``__version__.py`` or ``__about__.py``.

    Returns an empty string when neither file exists — this covers the case
    of a meta-package build where no additional source files are present.
    """
    version_file = os.path.join(pkg_root, "__version__.py")
    if os.path.isfile(version_file):
        return _load_py_module("version", version_file).version
    about_file = os.path.join(pkg_root, "__about__.py")
    if os.path.isfile(about_file):
        return _load_py_module("about", about_file).__version__
    return ""
330+
331+
def prune_duplicate_lines(body):
    """Drop repeated lines, keeping the first occurrence.

    Closing parentheses and blank lines are always kept.  NOTE(review): the
    membership test compares the *stripped* line against the unstripped lines
    already kept, so only duplicates without extra indentation are actually
    pruned — this mirrors the historical behavior.
    """
    unique = []
    for ln in body:
        stripped = ln.lstrip()
        if stripped in (")", "") or stripped not in unique:
            unique.append(ln)
    return unique
339+
340+
265341def create_meta_package (src_folder : str , pkg_name : str = "pytorch_lightning" , lit_name : str = "pytorch" ):
266342 """Parse the real python package and for each module create a mirroe version with repalcing all function and
267343 class implementations by cross-imports to the true package.
@@ -271,6 +347,7 @@ class implementations by cross-imports to the true package.
271347 >>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
272348 """
273349 package_dir = os .path .join (src_folder , pkg_name )
350+ pkg_ver = parse_version_from_file (package_dir )
274351 # shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
275352 py_files = glob .glob (os .path .join (src_folder , pkg_name , "**" , "*.py" ), recursive = True )
276353 for py_file in py_files :
@@ -290,30 +367,99 @@ class implementations by cross-imports to the true package.
290367 logging .warning (f"unsupported file: { local_path } " )
291368 continue
292369 # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
293- body = prune_comments_docstrings (lines )
370+ body = prune_comments_docstrings ([ ln . rstrip () for ln in lines ] )
294371 if fname not in ("__init__.py" , "__main__.py" ):
295372 body = prune_imports_callables (body )
296- body = replace_block_with_imports ([ln .rstrip () for ln in body ], import_path , "class" )
297- body = replace_block_with_imports (body , import_path , "def" )
298- body = replace_block_with_imports (body , import_path , "async def" )
373+ for key_word in ("class" , "def" , "async def" ):
374+ body = replace_block_with_imports (body , import_path , key_word )
375+ # TODO: fix reimporting which is artefact after replacing var assignment with import;
376+ # after fixing , update CI by remove F811 from CI/check pkg
299377 body = replace_vars_with_imports (body , import_path )
378+ if fname not in ("__main__.py" ,):
379+ body = prune_func_calls (body )
300380 body_len = - 1
301381 # in case of several in-depth statements
302382 while body_len != len (body ):
303383 body_len = len (body )
384+ body = prune_duplicate_lines (body )
304385 body = prune_empty_statements (body )
305- # TODO: add try/catch wrapper for whole body,
386+ # add try/catch wrapper for whole body,
306387 # so when import fails it tells you what is the package version this meta package was generated for...
388+ body = wrap_try_except (body , pkg_name , pkg_ver )
307389
308390 # todo: apply pre-commit formatting
391+ # clean to many empty lines
309392 body = [ln for ln , _group in groupby (body )]
310- lines = []
311393 # drop duplicated lines
312- for ln in body :
313- if ln + os .linesep not in lines or ln in (")" , "" ):
314- lines .append (ln + os .linesep )
394+ body = prune_duplicate_lines (body )
315395 # compose the target file name
316396 new_file = os .path .join (src_folder , "lightning" , lit_name , local_path )
317397 os .makedirs (os .path .dirname (new_file ), exist_ok = True )
318398 with open (new_file , "w" , encoding = "utf-8" ) as fp :
319- fp .writelines (lines )
399+ fp .writelines ([ln + os .linesep for ln in body ])
400+
401+
def set_version_today(fpath: str) -> None:
    """Replace the ``YYYY.-M.-D`` template date in *fpath* with today's date."""
    now = datetime.now()
    stamp = f"{now.year}.{now.month}.{now.day}"
    with open(fpath) as fp:
        content = fp.read()
    with open(fpath, "w") as fp:
        fp.write(content.replace("YYYY.-M.-D", stamp))
414+
415+
def _download_frontend(root: str = _PROJECT_ROOT):
    """Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
    directory.

    Fixes over the previous version: the tarfile handle is now closed via a
    context manager, and the existing ``ui`` directory is only removed after
    the archive has been downloaded and extracted successfully — so a failed
    download no longer leaves the installation without any frontend at all.
    """

    try:
        frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
        download_dir = tempfile.mkdtemp()

        response = urllib.request.urlopen(_LIGHTNING_FRONTEND_RELEASE_URL)
        # stream-decompress the gzip tarball straight from the HTTP response
        with tarfile.open(fileobj=response, mode="r|gz") as archive:
            archive.extractall(path=download_dir)

        # only now replace the previous frontend build
        shutil.rmtree(frontend_dir, ignore_errors=True)
        shutil.move(os.path.join(download_dir, "build"), frontend_dir)
        print("The Lightning UI has successfully been downloaded!")

    # If installing from source without internet connection, we don't want to break the installation
    except Exception:
        print("The Lightning UI downloading has failed!")
436+
437+
def _adjust_require_versions(source_dir: str = "src", req_dir: str = "requirements") -> None:
    """Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`."""
    reqs = load_requirements(req_dir, file_name="base.txt")
    for idx, requirement in enumerate(reqs):
        # the package name ends at the first version-specifier character
        cut = min(requirement.index(c) for c in ">=" if c in requirement)
        pkg_name = requirement[:cut]
        version = parse_version_from_file(os.path.join(source_dir, pkg_name))
        if not version:
            continue
        # pin to the same major.minor with a wildcard patch version
        wildcard_ver = ".".join(version.split(".")[:2] + ["*"])
        reqs[idx] = f"{requirement}, =={wildcard_ver}"

    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        fp.writelines([ln + os.linesep for ln in reqs])
451+
452+
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
    """Load all base requirements from all particular packages and prune duplicates."""
    package_dirs = [d for d in glob.glob(os.path.join(req_dir, "*")) if os.path.isdir(d)]
    per_package = [
        load_requirements(d, file_name="base.txt", unfreeze=not freeze_requirements) for d in package_dirs
    ]
    if not per_package:
        return None
    # TODO: add some smarter version aggregation per each package
    aggregated = list(chain(*per_package))
    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        fp.writelines([ln + os.linesep for ln in aggregated])
0 commit comments