Skip to content
108 changes: 80 additions & 28 deletions .actions/setup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import tempfile
import urllib.request
from datetime import datetime
from distutils.version import LooseVersion
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain, groupby
from types import ModuleType
from typing import List

from pkg_resources import parse_requirements

_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
_PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"}

Expand All @@ -42,45 +45,92 @@ def _load_py_module(name: str, location: str) -> ModuleType:
return py


def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: str = "all") -> str:
"""Adjust the upper version contrains.

Args:
ln: raw line from requirement
comment_char: charter marking comment
unfreeze: Enum or "all"|"major"|""

Returns:
adjusted requirement

>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # anything", unfreeze="")
'arrow>=1.2.0, <=1.2.2'
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="")
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # my name", unfreeze="all")
'arrow>=1.2.0'
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="all")
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow", unfreeze="all")
'arrow'
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # cool", unfreeze="major")
'arrow>=1.2.0, <2.0 # strict'
>>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="major")
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow>=1.2.0", unfreeze="major")
'arrow>=1.2.0, <2.0 # strict'
>>> _augment_requirement("arrow", unfreeze="major")
'arrow'
"""
# filer all comments
if comment_char in ln:
comment = ln[ln.index(comment_char) :]
ln = ln[: ln.index(comment_char)]
is_strict = "strict" in comment
else:
is_strict = False
req = ln.strip()
# skip directly installed dependencies
if not req or req.startswith("http") or "@http" in req:
return ""
# extract the major version from all listed versions
if unfreeze == "major":
req_ = list(parse_requirements([req]))[0]
vers = [LooseVersion(v) for s, v in req_.specs if s not in ("==", "~=")]
ver_major = sorted(vers)[-1].version[0] if vers else None
else:
ver_major = None

# remove version restrictions unless they are strict
if unfreeze and "<" in req and not is_strict:
req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip()
if ver_major is not None and not is_strict:
# add , only if there are already some versions
req += f"{',' if any(c in req for c in '<=>') else ''} <{int(ver_major) + 1}.0"

# adding strict back to the comment
if is_strict or ver_major is not None:
req += " # strict"

return req


def load_requirements(
path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True
path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: str = "all"
) -> List[str]:
"""Loading requirements from a file.

>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
>>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['numpy...', 'torch...', ...]
>>> load_requirements(path_req, unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['pytorch_lightning...', 'lightning_app...']
"""
with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
reqs = []
for ln in lines:
# filer all comments
comment = ""
if comment_char in ln:
comment = ln[ln.index(comment_char) :]
ln = ln[: ln.index(comment_char)]
req = ln.strip()
# skip directly installed dependencies
if not req or req.startswith("http") or "@http" in req:
continue
# remove version restrictions unless they are strict
if unfreeze and "<" in req and "strict" not in comment:
req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip()

# adding strict back to the comment
if "strict" in comment:
req += " # strict"

reqs.append(req)
return reqs
reqs.append(_augment_requirement(ln, comment_char=comment_char, unfreeze=unfreeze))
# filter empty lines
return [str(req) for req in reqs if req]


def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
"""Load readme as decribtion.

>>> load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'<div align="center">...'
'...'
"""
path_readme = os.path.join(path_dir, "README.md")
text = open(path_readme, encoding="utf-8").read()
Expand Down Expand Up @@ -439,12 +489,14 @@ def _download_frontend(root: str = _PROJECT_ROOT):
print("The Lightning UI downloading has failed!")


def _adjust_require_versions(source_dir: str = "src", req_dir: str = "requirements") -> None:
"""Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`."""
def _relax_require_versions(source_dir: str = "src", req_dir: str = "requirements") -> None:
"""Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`.

>>> _relax_require_versions("../src", "../requirements")
"""
reqs = load_requirements(req_dir, file_name="base.txt")
for i, req in enumerate(reqs):
pkg_name = req[: min(req.index(c) for c in ">=" if c in req)]
ver_ = parse_version_from_file(os.path.join(source_dir, pkg_name))
for i, req in enumerate(parse_requirements(reqs)):
ver_ = parse_version_from_file(os.path.join(source_dir, req.name))
if not ver_:
continue
ver2 = ".".join(ver_.split(".")[:2] + ["*"])
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
_SETUP_TOOLS = _load_py_module(name="setup_tools", location=os.path.join(".actions", "setup_tools.py"))

if _PACKAGE_NAME == "lightning": # install just the meta package
_SETUP_TOOLS._adjust_require_versions(_PATH_SRC, _PATH_REQUIRE)
_SETUP_TOOLS._relax_require_versions(_PATH_SRC, _PATH_REQUIRE)
elif _PACKAGE_NAME not in _PACKAGE_MAPPING: # install everything
_SETUP_TOOLS._load_aggregate_requirements(_PATH_REQUIRE, _FREEZE_REQUIREMENTS)

Expand Down
2 changes: 1 addition & 1 deletion src/lightning/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _setup_args(**kwargs: Any) -> Dict[str, Any]:
],
},
setup_requires=[],
install_requires=_SETUP_TOOLS.load_requirements(_PATH_REQUIREMENTS, unfreeze=True),
install_requires=_SETUP_TOOLS.load_requirements(_PATH_REQUIREMENTS, unfreeze="all"),
extras_require={}, # todo: consider porting all other packages extras with prefix
project_urls={
"Bug Tracker": "https://github.com/Lightning-AI/lightning/issues",
Expand Down
6 changes: 4 additions & 2 deletions src/lightning_app/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _prepare_extras(**kwargs: Any) -> Dict[str, Any]:
# Define package extras. These are only installed if you specify them.
# From remote, use like `pip install pytorch-lightning[dev, docs]`
# From local copy of repo, use like `pip install ".[dev, docs]"`
common_args = dict(path_dir=_PATH_REQUIREMENTS, unfreeze=not _FREEZE_REQUIREMENTS)
common_args = dict(path_dir=_PATH_REQUIREMENTS, unfreeze="major" if _FREEZE_REQUIREMENTS else "all")
extras = {
# 'docs': load_requirements(file_name='docs.txt'),
"cloud": _setup_tools.load_requirements(file_name="cloud.txt", **common_args),
Expand Down Expand Up @@ -95,7 +95,9 @@ def _setup_args(**__: Any) -> Dict[str, Any]:
],
},
setup_requires=["wheel"],
install_requires=_setup_tools.load_requirements(_PATH_REQUIREMENTS, unfreeze=not _FREEZE_REQUIREMENTS),
install_requires=_setup_tools.load_requirements(
_PATH_REQUIREMENTS, unfreeze="major" if _FREEZE_REQUIREMENTS else "all"
),
extras_require=_prepare_extras(),
project_urls={
"Bug Tracker": "https://github.com/Lightning-AI/lightning/issues",
Expand Down
6 changes: 4 additions & 2 deletions src/pytorch_lightning/__setup__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _prepare_extras(**kwargs: Any) -> Dict[str, Any]:
# Define package extras. These are only installed if you specify them.
# From remote, use like `pip install pytorch-lightning[dev, docs]`
# From local copy of repo, use like `pip install ".[dev, docs]"`
common_args = dict(path_dir=_PATH_REQUIREMENTS, unfreeze=not _FREEZE_REQUIREMENTS)
common_args = dict(path_dir=_PATH_REQUIREMENTS, unfreeze="" if _FREEZE_REQUIREMENTS else "all")
extras = {
# 'docs': load_requirements(file_name='docs.txt'),
"examples": _setup_tools.load_requirements(file_name="examples.txt", **common_args),
Expand Down Expand Up @@ -99,7 +99,9 @@ def _setup_args(**__: Any) -> Dict[str, Any]:
keywords=["deep learning", "pytorch", "AI"],
python_requires=">=3.7",
setup_requires=[],
install_requires=_setup_tools.load_requirements(_PATH_REQUIREMENTS, unfreeze=not _FREEZE_REQUIREMENTS),
install_requires=_setup_tools.load_requirements(
_PATH_REQUIREMENTS, unfreeze="" if _FREEZE_REQUIREMENTS else "all"
),
extras_require=_prepare_extras(),
project_urls={
"Bug Tracker": "https://github.com/Lightning-AI/lightning/issues",
Expand Down