2020import tarfile
2121import tempfile
2222import urllib .request
23+ from datetime import datetime
2324from importlib .util import module_from_spec , spec_from_file_location
2425from itertools import groupby
2526from types import ModuleType
@@ -150,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
150151 ... lines = [ln.rstrip() for ln in fp.readlines()]
151152 >>> lines = replace_vars_with_imports(lines, import_path)
152153 """
154+ copied = []
153155 body , tracking , skip_offset = [], False , 0
154156 for ln in lines :
155157 offset = len (ln ) - len (ln .lstrip ())
@@ -160,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
160162 if var :
161163 name = var .groups ()[0 ]
162164 # skip private or apply white-list for allowed vars
163- if not name .startswith ("__" ) or name in ("__all__" ,):
165+ if name not in copied and ( not name .startswith ("__" ) or name in ("__all__" ,) ):
164166 body .append (f"{ ' ' * offset } from { import_path } import { name } # noqa: F401" )
167+ copied .append (name )
165168 tracking , skip_offset = True , offset
166169 continue
167170 if not tracking :
@@ -196,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
196199 return body
197200
198201
202+ def prune_func_calls (lines : List [str ]) -> List [str ]:
203+ """Prune calling functions from a file, even multi-line.
204+
205+ >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
206+ >>> import_path = ".".join(["pytorch_lightning", "loggers"])
207+ >>> with open(py_file, encoding="utf-8") as fp:
208+ ... lines = [ln.rstrip() for ln in fp.readlines()]
209+ >>> lines = prune_func_calls(lines)
210+ """
211+ body , tracking , score = [], False , 0
212+ for ln in lines :
213+ # catching callable
214+ calling = re .match (r"^@?[\w_\d\.]+ *\(" , ln .lstrip ())
215+ if calling and " import " not in ln :
216+ tracking = True
217+ score = 0
218+ if tracking :
219+ score += ln .count ("(" ) - ln .count (")" )
220+ if score == 0 :
221+ tracking = False
222+ else :
223+ body .append (ln )
224+ return body
225+
226+
199227def prune_empty_statements (lines : List [str ]) -> List [str ]:
200228 """Prune emprty if/else and try/except.
201229
@@ -270,6 +298,46 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
270298 return body
271299
272300
301+ def wrap_try_except (body : List [str ], pkg : str , ver : str ) -> List [str ]:
302+ """Wrap the file with try/except for better traceability of import misalignment."""
303+ not_empty = sum (1 for ln in body if ln )
304+ if not_empty == 0 :
305+ return body
306+ body = ["try:" ] + [f" { ln } " if ln else "" for ln in body ]
307+ body += [
308+ "" ,
309+ "except ImportError as err:" ,
310+ "" ,
311+ " from os import linesep" ,
312+ f" from { pkg } import __version__" ,
313+ f" msg = f'Your `lightning` package was built for `{ pkg } =={ ver } `," + " but you are running {__version__}'" ,
314+ " raise type(err)(str(err) + linesep + msg)" ,
315+ ]
316+ return body
317+
318+
319+ def parse_version_from_file (pkg_root : str ) -> str :
320+ """Loading the package version from file."""
321+ file_ver = os .path .join (pkg_root , "__version__.py" )
322+ file_about = os .path .join (pkg_root , "__about__.py" )
323+ if os .path .isfile (file_ver ):
324+ ver = _load_py_module ("version" , file_ver ).version
325+ elif os .path .isfile (file_about ):
326+ ver = _load_py_module ("about" , file_about ).__version__
327+ else : # this covers case you have build only meta-package so not additional source files are present
328+ ver = ""
329+ return ver
330+
331+
332+ def prune_duplicate_lines (body ):
333+ body_ = []
334+ # drop duplicated lines
335+ for ln in body :
336+ if ln .lstrip () not in body_ or ln .lstrip () in (")" , "" ):
337+ body_ .append (ln )
338+ return body_
339+
340+
273341def create_meta_package (src_folder : str , pkg_name : str = "pytorch_lightning" , lit_name : str = "pytorch" ):
274342 """Parse the real python package and for each module create a mirroe version with repalcing all function and
275343 class implementations by cross-imports to the true package.
@@ -279,6 +347,7 @@ class implementations by cross-imports to the true package.
279347 >>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
280348 """
281349 package_dir = os .path .join (src_folder , pkg_name )
350+ pkg_ver = parse_version_from_file (package_dir )
282351 # shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
283352 py_files = glob .glob (os .path .join (src_folder , pkg_name , "**" , "*.py" ), recursive = True )
284353 for py_file in py_files :
@@ -298,41 +367,57 @@ class implementations by cross-imports to the true package.
298367 logging .warning (f"unsupported file: { local_path } " )
299368 continue
300369 # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
301- body = prune_comments_docstrings (lines )
370+ body = prune_comments_docstrings ([ ln . rstrip () for ln in lines ] )
302371 if fname not in ("__init__.py" , "__main__.py" ):
303372 body = prune_imports_callables (body )
304- body = replace_block_with_imports ([ln .rstrip () for ln in body ], import_path , "class" )
305- body = replace_block_with_imports (body , import_path , "def" )
306- body = replace_block_with_imports (body , import_path , "async def" )
373+ for key_word in ("class" , "def" , "async def" ):
374+ body = replace_block_with_imports (body , import_path , key_word )
375+ # TODO: fix reimporting which is artefact after replacing var assignment with import;
376+ # after fixing , update CI by remove F811 from CI/check pkg
307377 body = replace_vars_with_imports (body , import_path )
378+ if fname not in ("__main__.py" ,):
379+ body = prune_func_calls (body )
308380 body_len = - 1
309381 # in case of several in-depth statements
310382 while body_len != len (body ):
311383 body_len = len (body )
384+ body = prune_duplicate_lines (body )
312385 body = prune_empty_statements (body )
313- # TODO: add try/catch wrapper for whole body,
386+ # add try/catch wrapper for whole body,
314387 # so when import fails it tells you what is the package version this meta package was generated for...
388+ body = wrap_try_except (body , pkg_name , pkg_ver )
315389
316390 # todo: apply pre-commit formatting
391+ # clean to many empty lines
317392 body = [ln for ln , _group in groupby (body )]
318- lines = []
319393 # drop duplicated lines
320- for ln in body :
321- if ln + os .linesep not in lines or ln in (")" , "" ):
322- lines .append (ln + os .linesep )
394+ body = prune_duplicate_lines (body )
323395 # compose the target file name
324396 new_file = os .path .join (src_folder , "lightning" , lit_name , local_path )
325397 os .makedirs (os .path .dirname (new_file ), exist_ok = True )
326398 with open (new_file , "w" , encoding = "utf-8" ) as fp :
327- fp .writelines (lines )
399+ fp .writelines ([ln + os .linesep for ln in body ])
400+
401+
402+ def set_version_today (fpath : str ) -> None :
403+ """Replace the template date with today."""
404+ with open (fpath ) as fp :
405+ lines = fp .readlines ()
406+
407+ def _replace_today (ln ):
408+ today = datetime .now ()
409+ return ln .replace ("YYYY.-M.-D" , f"{ today .year } .{ today .month } .{ today .day } " )
410+
411+ lines = list (map (_replace_today , lines ))
412+ with open (fpath , "w" ) as fp :
413+ fp .writelines (lines )
328414
329415
330416def _download_frontend (root : str = _PROJECT_ROOT ):
331417 """Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
332418 directory."""
333419
334420 try :
335- build_dir = "build"
336421 frontend_dir = pathlib .Path (root , "src" , "lightning_app" , "ui" )
337422 download_dir = tempfile .mkdtemp ()
338423
@@ -342,7 +427,7 @@ def _download_frontend(root: str = _PROJECT_ROOT):
342427 file = tarfile .open (fileobj = response , mode = "r|gz" )
343428 file .extractall (path = download_dir )
344429
345- shutil .move (os .path .join (download_dir , build_dir ), frontend_dir )
430+ shutil .move (os .path .join (download_dir , "build" ), frontend_dir )
346431 print ("The Lightning UI has successfully been downloaded!" )
347432
348433 # If installing from source without internet connection, we don't want to break the installation
0 commit comments