Skip to content

Commit 1b77c58

Browse files
Merge branch 'master' into mypy_quant
2 parents e534e9c + b3203d9 commit 1b77c58

File tree

343 files changed

+9287
-5273
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

343 files changed

+9287
-5273
lines changed

.actions/setup_tools.py

Lines changed: 159 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,24 @@
1414
import glob
1515
import logging
1616
import os
17+
import pathlib
1718
import re
19+
import shutil
20+
import tarfile
21+
import tempfile
22+
import urllib.request
23+
from datetime import datetime
1824
from importlib.util import module_from_spec, spec_from_file_location
19-
from itertools import groupby
25+
from itertools import chain, groupby
2026
from types import ModuleType
2127
from typing import List
2228

2329
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
2430
_PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"}
2531

32+
# TODO: remove this once lightning-ui package is ready as a dependency
33+
_LIGHTNING_FRONTEND_RELEASE_URL = "https://storage.googleapis.com/grid-packages/lightning-ui/v0.0.0/build.tar.gz"
34+
2635

2736
def _load_py_module(name: str, location: str) -> ModuleType:
2837
spec = spec_from_file_location(name, location)
@@ -36,7 +45,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
3645
def load_requirements(
3746
path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True
3847
) -> List[str]:
39-
"""Load requirements from a file.
48+
"""Loading requirements from a file.
4049
4150
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
4251
>>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
@@ -142,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
142151
... lines = [ln.rstrip() for ln in fp.readlines()]
143152
>>> lines = replace_vars_with_imports(lines, import_path)
144153
"""
154+
copied = []
145155
body, tracking, skip_offset = [], False, 0
146156
for ln in lines:
147157
offset = len(ln) - len(ln.lstrip())
@@ -152,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
152162
if var:
153163
name = var.groups()[0]
154164
# skip private or apply white-list for allowed vars
155-
if not name.startswith("__") or name in ("__all__",):
165+
if name not in copied and (not name.startswith("__") or name in ("__all__",)):
156166
body.append(f"{' ' * offset}from {import_path} import {name} # noqa: F401")
167+
copied.append(name)
157168
tracking, skip_offset = True, offset
158169
continue
159170
if not tracking:
@@ -188,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
188199
return body
189200

190201

202+
def prune_func_calls(lines: List[str]) -> List[str]:
203+
"""Prune calling functions from a file, even multi-line.
204+
205+
>>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
206+
>>> import_path = ".".join(["pytorch_lightning", "loggers"])
207+
>>> with open(py_file, encoding="utf-8") as fp:
208+
... lines = [ln.rstrip() for ln in fp.readlines()]
209+
>>> lines = prune_func_calls(lines)
210+
"""
211+
body, tracking, score = [], False, 0
212+
for ln in lines:
213+
# catching callable
214+
calling = re.match(r"^@?[\w_\d\.]+ *\(", ln.lstrip())
215+
if calling and " import " not in ln:
216+
tracking = True
217+
score = 0
218+
if tracking:
219+
score += ln.count("(") - ln.count(")")
220+
if score == 0:
221+
tracking = False
222+
else:
223+
body.append(ln)
224+
return body
225+
226+
191227
def prune_empty_statements(lines: List[str]) -> List[str]:
192228
"""Prune emprty if/else and try/except.
193229
@@ -262,6 +298,46 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
262298
return body
263299

264300

301+
def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
302+
"""Wrap the file with try/except for better traceability of import misalignment."""
303+
not_empty = sum(1 for ln in body if ln)
304+
if not_empty == 0:
305+
return body
306+
body = ["try:"] + [f" {ln}" if ln else "" for ln in body]
307+
body += [
308+
"",
309+
"except ImportError as err:",
310+
"",
311+
" from os import linesep",
312+
f" from {pkg} import __version__",
313+
f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
314+
" raise type(err)(str(err) + linesep + msg)",
315+
]
316+
return body
317+
318+
319+
def parse_version_from_file(pkg_root: str) -> str:
320+
"""Loading the package version from file."""
321+
file_ver = os.path.join(pkg_root, "__version__.py")
322+
file_about = os.path.join(pkg_root, "__about__.py")
323+
if os.path.isfile(file_ver):
324+
ver = _load_py_module("version", file_ver).version
325+
elif os.path.isfile(file_about):
326+
ver = _load_py_module("about", file_about).__version__
327+
else: # this covers case you have build only meta-package so not additional source files are present
328+
ver = ""
329+
return ver
330+
331+
332+
def prune_duplicate_lines(body):
333+
body_ = []
334+
# drop duplicated lines
335+
for ln in body:
336+
if ln.lstrip() not in body_ or ln.lstrip() in (")", ""):
337+
body_.append(ln)
338+
return body_
339+
340+
265341
def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
266342
"""Parse the real python package and for each module create a mirroe version with repalcing all function and
267343
class implementations by cross-imports to the true package.
@@ -271,6 +347,7 @@ class implementations by cross-imports to the true package.
271347
>>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
272348
"""
273349
package_dir = os.path.join(src_folder, pkg_name)
350+
pkg_ver = parse_version_from_file(package_dir)
274351
# shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
275352
py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
276353
for py_file in py_files:
@@ -290,30 +367,99 @@ class implementations by cross-imports to the true package.
290367
logging.warning(f"unsupported file: {local_path}")
291368
continue
292369
# ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
293-
body = prune_comments_docstrings(lines)
370+
body = prune_comments_docstrings([ln.rstrip() for ln in lines])
294371
if fname not in ("__init__.py", "__main__.py"):
295372
body = prune_imports_callables(body)
296-
body = replace_block_with_imports([ln.rstrip() for ln in body], import_path, "class")
297-
body = replace_block_with_imports(body, import_path, "def")
298-
body = replace_block_with_imports(body, import_path, "async def")
373+
for key_word in ("class", "def", "async def"):
374+
body = replace_block_with_imports(body, import_path, key_word)
375+
# TODO: fix reimporting which is artefact after replacing var assignment with import;
376+
# after fixing , update CI by remove F811 from CI/check pkg
299377
body = replace_vars_with_imports(body, import_path)
378+
if fname not in ("__main__.py",):
379+
body = prune_func_calls(body)
300380
body_len = -1
301381
# in case of several in-depth statements
302382
while body_len != len(body):
303383
body_len = len(body)
384+
body = prune_duplicate_lines(body)
304385
body = prune_empty_statements(body)
305-
# TODO: add try/catch wrapper for whole body,
386+
# add try/catch wrapper for whole body,
306387
# so when import fails it tells you what is the package version this meta package was generated for...
388+
body = wrap_try_except(body, pkg_name, pkg_ver)
307389

308390
# todo: apply pre-commit formatting
391+
# clean to many empty lines
309392
body = [ln for ln, _group in groupby(body)]
310-
lines = []
311393
# drop duplicated lines
312-
for ln in body:
313-
if ln + os.linesep not in lines or ln in (")", ""):
314-
lines.append(ln + os.linesep)
394+
body = prune_duplicate_lines(body)
315395
# compose the target file name
316396
new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
317397
os.makedirs(os.path.dirname(new_file), exist_ok=True)
318398
with open(new_file, "w", encoding="utf-8") as fp:
319-
fp.writelines(lines)
399+
fp.writelines([ln + os.linesep for ln in body])
400+
401+
402+
def set_version_today(fpath: str) -> None:
403+
"""Replace the template date with today."""
404+
with open(fpath) as fp:
405+
lines = fp.readlines()
406+
407+
def _replace_today(ln):
408+
today = datetime.now()
409+
return ln.replace("YYYY.-M.-D", f"{today.year}.{today.month}.{today.day}")
410+
411+
lines = list(map(_replace_today, lines))
412+
with open(fpath, "w") as fp:
413+
fp.writelines(lines)
414+
415+
416+
def _download_frontend(root: str = _PROJECT_ROOT):
417+
"""Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
418+
directory."""
419+
420+
try:
421+
frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
422+
download_dir = tempfile.mkdtemp()
423+
424+
shutil.rmtree(frontend_dir, ignore_errors=True)
425+
response = urllib.request.urlopen(_LIGHTNING_FRONTEND_RELEASE_URL)
426+
427+
file = tarfile.open(fileobj=response, mode="r|gz")
428+
file.extractall(path=download_dir)
429+
430+
shutil.move(os.path.join(download_dir, "build"), frontend_dir)
431+
print("The Lightning UI has successfully been downloaded!")
432+
433+
# If installing from source without internet connection, we don't want to break the installation
434+
except Exception:
435+
print("The Lightning UI downloading has failed!")
436+
437+
438+
def _adjust_require_versions(source_dir: str = "src", req_dir: str = "requirements") -> None:
439+
"""Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`."""
440+
reqs = load_requirements(req_dir, file_name="base.txt")
441+
for i, req in enumerate(reqs):
442+
pkg_name = req[: min(req.index(c) for c in ">=" if c in req)]
443+
ver_ = parse_version_from_file(os.path.join(source_dir, pkg_name))
444+
if not ver_:
445+
continue
446+
ver2 = ".".join(ver_.split(".")[:2] + ["*"])
447+
reqs[i] = f"{req}, =={ver2}"
448+
449+
with open(os.path.join(req_dir, "base.txt"), "w") as fp:
450+
fp.writelines([ln + os.linesep for ln in reqs])
451+
452+
453+
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
454+
"""Load all base requirements from all particular packages and prune duplicates."""
455+
requires = [
456+
load_requirements(d, file_name="base.txt", unfreeze=not freeze_requirements)
457+
for d in glob.glob(os.path.join(req_dir, "*"))
458+
if os.path.isdir(d)
459+
]
460+
if not requires:
461+
return None
462+
# TODO: add some smarter version aggregation per each package
463+
requires = list(chain(*requires))
464+
with open(os.path.join(req_dir, "base.txt"), "w") as fp:
465+
fp.writelines([ln + os.linesep for ln in requires])

.azure/gpu-benchmark.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
cancelTimeoutInMinutes: "2"
2929
pool: azure-jirka-spot
3030
container:
31-
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11"
31+
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12"
3232
options: "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all --shm-size=32g"
3333
workspace:
3434
clean: all

.azure/gpu-tests.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
strategy:
2727
matrix:
2828
'PyTorch - stable':
29-
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.11"
29+
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.9-torch1.12"
3030
# how long to run the job before automatically cancelling
3131
timeoutInMinutes: "80"
3232
# how much time to give 'run always even if cancelled tasks' before stopping them
@@ -44,7 +44,7 @@ jobs:
4444

4545
- bash: |
4646
CHANGED_FILES=$(git diff --name-status origin/master -- . | awk '{print $2}')
47-
FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*'
47+
FILTER='src/pytorch_lightning|requirements/pytorch|tests/tests_pytorch|examples/pl_*|.azure/*'
4848
echo $CHANGED_FILES > changed_files.txt
4949
MATCHES=$(cat changed_files.txt | grep -E $FILTER)
5050
echo $MATCHES
@@ -69,10 +69,13 @@ jobs:
6969
condition: eq(variables['continue'], '1')
7070
7171
- bash: |
72+
set -e
7273
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
74+
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
7375
CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
7476
pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0"
7577
pip install -e .[strategies]
78+
pip install deepspeed>0.6.4 # TODO: remove when docker images are upgraded
7679
pip install --requirement requirements/pytorch/devel.txt
7780
pip list
7881
env:
@@ -116,6 +119,15 @@ jobs:
116119
timeoutInMinutes: "35"
117120
condition: eq(variables['continue'], '1')
118121

122+
- bash: bash run_standalone_tasks.sh
123+
workingDirectory: tests/tests_pytorch
124+
env:
125+
PL_USE_MOCKED_MNIST: "1"
126+
PL_RUN_CUDA_TESTS: "1"
127+
displayName: 'Testing: PyTorch standalone tasks'
128+
timeoutInMinutes: "10"
129+
condition: eq(variables['continue'], '1')
130+
119131
- bash: |
120132
python -m coverage report
121133
python -m coverage xml

0 commit comments

Comments
 (0)