Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytorch_lightning/setup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ def _load_requirements(
reqs = []
for ln in lines:
# filer all comments
comment = ""
if comment_char in ln:
comment = ln[ln.index(comment_char) :]
ln = ln[: ln.index(comment_char)]
comment = ln[ln.index(comment_char) :] if comment_char in ln else ""
req = ln.strip()
# skip directly installed dependencies
if not req or req.startswith("http") or "@http" in req:
Expand Down
1 change: 1 addition & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ torchmetrics>=0.4.1, <=0.7.2
pyDeprecate>=0.3.1, <=0.3.2
packaging>=17.0, <=21.3
typing-extensions>=4.0.0, <4.2.1
protobuf<=3.20.1 # strict. TODO: Remove after tensorboard gets compatible https://github.com/tensorflow/tensorboard/issues/5708
2 changes: 1 addition & 1 deletion requirements/strategies.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fairscale>=0.4.5, <=0.4.6
deepspeed<0.7.0
deepspeed<0.6.0
# no need to install with [pytorch] as pytorch is already installed
horovod>=0.21.2,!=0.24.0, <=0.24.3
hivemind>=1.0.1, <=1.0.1; sys_platform == 'linux'
8 changes: 7 additions & 1 deletion tests/callbacks/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional import mean_absolute_percentage_error as mape

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import get_model_size_mb
Expand All @@ -35,9 +36,14 @@
@RunIf(quantization=True)
def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
"""Parity test for quant model."""
cuda_available = GPUAccelerator.is_available()

if observe == "average" and not fuse and GPUAccelerator.is_available():
pytest.xfail("TODO: flakiness in GPU CI")

seed_everything(42)
dm = RegressDataModule()
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
accelerator = "gpu" if cuda_available else "cpu"
trainer_args = dict(default_root_dir=tmpdir, max_epochs=7, accelerator=accelerator, devices=1)
model = RegressionModel()
qmodel = copy.deepcopy(model)
Expand Down
6 changes: 3 additions & 3 deletions tests/standalone_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ files=$(echo "$grep_output" | cut -f1 -d: | sort | uniq)
# get the list of parametrizations. we need to call them separately. the last two lines are removed.
# note: if there's a syntax error, this will fail with some garbled output
if [[ "$OSTYPE" == "darwin"* ]]; then
parametrizations=$(pytest $files --collect-only --quiet "$@" | tail -r | sed -e '1,3d' | tail -r)
parametrizations=$(python -m pytest $files --collect-only --quiet "$@" | tail -r | sed -e '1,3d' | tail -r)
else
parametrizations=$(pytest $files --collect-only --quiet "$@" | head -n -2)
parametrizations=$(python -m pytest $files --collect-only --quiet "$@" | head -n -2)
fi
parametrizations_arr=($parametrizations)

# tests to skip - space separated
blocklist='tests/profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx'
blocklist='tests/profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx tests/utilities/test_warnings.py'
report=''

for i in "${!parametrizations_arr[@]}"; do
Expand Down
2 changes: 1 addition & 1 deletion tests/utilities/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pytorch_lightning.utilities.warnings import WarningCache

standalone = os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1"
if standalone:
if standalone and __name__ == "__main__":

stderr = StringIO()
# recording
Expand Down