diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..d4d02071 --- /dev/null +++ b/.clang-format @@ -0,0 +1,88 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false 
+ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 2000000 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 00000000..cfef0358 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,22 @@ +version = 1 + +test_patterns = ["tests/**"] + +[[analyzers]] +name = "python" +enabled = true + + [analyzers.meta] + runtime_version = "3.x.x" + +[[analyzers]] +name = "test-coverage" +enabled = true + +[[analyzers]] +name = "docker" +enabled = true + +[[analyzers]] +name = "shell" +enabled = true diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml new file mode 100644 index 00000000..0d629c29 --- /dev/null +++ b/.github/workflows/pythonapp-min.yml @@ -0,0 +1,171 @@ +# Jenkinsfile.monai-premerge +name: premerge-min + +on: + # quick tests for pull requests and the releasing branches + push: + branches: + - main + pull_request: + +concurrency: + # automatically cancel the previously triggered workflows when there's a newer version + group: build-min-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + # caching of these jobs: + # - docker-py3-pip- (shared) + # - ubuntu py37 pip- + # - os-latest-pip- (shared) + min-dep-os: # min dependencies installed tests for different OS + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [windows-latest, macOS-latest, 
ubuntu-latest] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Prepare pip wheel + run: | + which python + python -m pip install --upgrade pip wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + shell: bash + - name: cache for pip + uses: actions/cache@v3 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }} + - if: runner.os == 'windows' + name: Install torch cpu from pytorch.org (Windows only) + run: | + python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install the dependencies + run: | + # min. requirements + python -m pip install torch==1.13.1 + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (CPU ${{ runner.os }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + shell: bash + env: + QUICKTEST: True + + min-dep-py3: # min dependencies installed tests for different python + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Prepare pip wheel + run: | + which python + python -m pip install --user --upgrade pip setuptools wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + shell: bash + 
- name: cache for pip + uses: actions/cache@v3 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install the dependencies + run: | + # min. requirements + python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (CPU ${{ runner.os }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + env: + QUICKTEST: True + + min-dep-pytorch: # min dependencies installed tests for different pytorch + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest'] + timeout-minutes: 40 + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Prepare pip wheel + run: | + which python + python -m pip install --user --upgrade pip setuptools wheel + - name: cache weekly timestamp + id: pip-cache + run: | + echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + shell: bash + - name: cache for pip + uses: actions/cache@v3 + id: cache + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ubuntu-latest-latest-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install the dependencies + run: | + # min. 
requirements + if [ ${{ matrix.pytorch-version }} == "latest" ]; then + python -m pip install torch + elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then + python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu + elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then + python -m pip install torch==1.9.1 + elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then + python -m pip install torch==1.10.2 + elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then + python -m pip install torch==1.11.0 + elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then + python -m pip install torch==1.12.1 + fi + python -m pip install -r requirements-min.txt + python -m pip list + BUILD_MONAI=0 python setup.py develop # no compile of extensions + shell: bash + - name: Run quick tests (pytorch ${{ matrix.pytorch-version }}) + run: | + python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' + python -c "import monai; monai.config.print_config()" + ./runtests.sh --min + env: + QUICKTEST: True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80998b97..48cea289 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,15 @@ +default_language_version: + python: python3.8 + +ci: + autofix_prs: true + autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' + autoupdate_schedule: quarterly + # submodules: true + repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -17,24 +26,49 @@ repos: args: ['--autofix', '--no-sort-keys', '--indent=4'] - id: end-of-file-fixer - id: mixed-line-ending + + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: [--py37-plus] + name: Upgrade code + exclude: | + (?x)^( + versioneer.py| + monai/_version.py + )$ + + - repo: https://github.com/asottile/yesqa + rev: v1.4.0 + hooks: + - id: yesqa + name: Unused noqa + 
additional_dependencies: + - flake8>=3.8.1 + - flake8-bugbear + - flake8-comprehensions + - flake8-executable + - flake8-pyi + - pep8-naming + exclude: | + (?x)^( + generative/__init__.py| + docs/source/conf.py + )$ + + - repo: https://github.com/hadialqattan/pycln + rev: v2.1.2 + hooks: + - id: pycln + args: [--config=pyproject.toml] + - repo: https://github.com/psf/black rev: 22.3.0 hooks: - id: black + - repo: https://github.com/PyCQA/isort rev: 5.9.3 hooks: - id: isort - - repo: https://github.com/PyCQA/flake8 - rev: 3.9.2 - hooks: - - id: flake8 - args: # arguments to configure flake8 - # these are errors that will be ignored by flake8 - # check out their meaning here - # https://flake8.pycqa.org/en/latest/user/error-codes.html - - "--ignore=E203,E266,E501,W503,E731,F541,F841" - # Adding args to work with black format - - "--max-line-length=120" - - "--max-complexity=18" - - "--per-file-ignores=__init__.py:F401" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bc9ff7dc..edf45d4c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,27 +71,38 @@ If you intend for any variables/functions/classes to be available outside of the - Add to the `__init__.py` file. #### Unit testing ->In progress. Please wait for more instructions to follow - MONAI Generative Models tests are located under `tests/`. -##### Set environment -To use the tests already available at MONAI core, first we clone it: -```shell -git clone https://github.com/Project-MONAI/MONAI --branch main -``` +- The unit test's file name currently follows `test_[module_name].py` or `test_[module_name]_dist.py`. +- The `test_[module_name]_dist.py` subset of unit tests requires a distributed environment to verify the module with distributed GPU-based computation. +- The integration test's file name follows `test_integration_[workflow_name].py`. -Then we add it to PYTHONPATH -```shell -export PYTHONPATH="${PYTHONPATH}:./MONAI/" +A bash script (`runtests.sh`) is provided to run all tests locally. 
+Please run ``./runtests.sh -h`` to see all options. + +To run a particular test, for example `tests/test_spectral_loss.py`: +``` +python -m tests.test_spectral_loss ``` -##### Executing tests -To run tests, use the following command: +Before submitting a pull request, we recommend that all linting and unit tests +should pass, by running the following command locally: -```shell script - python -m unittest discover tests +```bash +./runtests.sh -u --net ``` +or (for new features that would not break existing functionality): + +```bash +./runtests.sh --quick --unittests +``` + +It is recommended that the new test `test_[module_name].py` is constructed by using only +python 3.7+ built-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages. +If it requires any other external packages, please make sure: +- the packages are listed in [`requirements-dev.txt`](requirements-dev.txt) +- the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that +the minimal CI runner will not execute it. ### Submitting pull requests diff --git a/pyproject.toml b/pyproject.toml index a5762118..ac6bd03e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,60 @@ -[tool.isort] -line_length = 120 -src_paths = ["generative"] -profile = "black" - [tool.black] line-length = 120 +target-version = ['py37', 'py38', 'py39', 'py310'] +include = '\.pyi?$' +exclude = ''' +( + /( + # exclude a few common directories in the root of the project + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | venv + | \.pytype + | _build + | buck-out + | build + | dist + )/ + # also separately exclude a file named versioneer.py + | generative/_version.py +) +''' + +[tool.pycln] +all = true + +[tool.pytype] +# Space-separated list of files or directories to exclude. 
+exclude = ["versioneer.py", "_version.py"] +# Space-separated list of files or directories to process. +inputs = ["generative"] +# Keep going past errors to analyze as many files as possible. +keep_going = true +# Run N jobs in parallel. +jobs = 8 +# All pytype output goes here. +output = ".pytype" +# Paths to source code directories, separated by ':'. +pythonpath = "." +# Check attribute values against their annotations. +check_attribute_types = true +# Check container mutations against their annotations. +check_container_types = true +# Check parameter defaults and assignments against their annotations. +check_parameter_types = true +# Check variable values against their annotations. +check_variable_types = true +# Comma or space separated list of error names to ignore. +disable = ["pyi-error"] +# Report errors. +report_errors = true +# Experimental: Infer precise return types even for invalid function calls. +precise_return = true +# Experimental: solve unknown types to label with structural types. +protocols = true +# Experimental: Only load submodules that are explicitly imported. 
+strict_import = false diff --git a/requirements-dev.txt b/requirements-dev.txt index 6c023b84..1a9ddd18 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,16 +1,56 @@ # Full requirements for developments -r requirements-min.txt +pytorch-ignite==0.4.10 +gdown>=4.4.0 +scipy +itk>=5.2 +nibabel +pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 +tensorboard>=2.6 # https://github.com/Project-MONAI/MONAI/issues/5776 +scikit-image>=0.19.0 +tqdm>=4.47.0 +lmdb flake8>=3.8.1 flake8-bugbear flake8-comprehensions flake8-executable +pylint!=2.13 # https://github.com/PyCQA/pylint/issues/5969 +mccabe +pep8-naming +pycodestyle pyflakes black isort +pytype>=2020.6.1; platform_system != "Windows" +types-pkg_resources mypy>=0.790 -pre-commit -matplotlib!=3.5.0 +ninja +torchvision +psutil +Sphinx==3.5.3 +recommonmark==0.6.0 +sphinx-autodoc-typehints==1.11.1 +sphinx-rtd-theme==0.5.2 +cucim==22.8.1; platform_system == "Linux" +openslide-python==1.1.2 +imagecodecs; platform_system == "Linux" or platform_system == "Darwin" +tifffile; platform_system == "Linux" or platform_system == "Darwin" +pandas +requests einops -tensorboard>=2.11.0 -nibabel>=4.0.2 -gdown>=4.4.0 +transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +mlflow +matplotlib!=3.5.0 +tensorboardX +types-PyYAML +pyyaml +fire +jsonschema +pynrrd +pre-commit +pydicom +h5py +nni +optuna +git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded +lpips==0.1.4 diff --git a/requirements-min.txt b/requirements-min.txt index 63906b4a..ad0bb1ef 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -1,5 +1,5 @@ # Requirements for minimal tests -r requirements.txt -setuptools>=50.3.0,!=60.0.0,!=60.6.0 +setuptools>=50.3.0,<66.0.0,!=60.6.0 coverage>=5.5 parameterized diff --git a/requirements.txt b/requirements.txt index 88b7af44..4254669e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -lpips==0.1.4 -monai-weekly==1.1.dev2248 
numpy>=1.17 torch>=1.8 -tqdm +monai-weekly==1.1.dev2248 diff --git a/runtests.sh b/runtests.sh new file mode 100755 index 00000000..7b2c8dfd --- /dev/null +++ b/runtests.sh @@ -0,0 +1,698 @@ +#! /bin/bash + +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# script for running all tests +set -e + +# output formatting +separator="" +blue="" +green="" +red="" +noColor="" + +if [[ -t 1 ]] # stdout is a terminal +then + separator=$'--------------------------------------------------------------------------------\n' + blue="$(tput bold; tput setaf 4)" + green="$(tput bold; tput setaf 2)" + red="$(tput bold; tput setaf 1)" + noColor="$(tput sgr0)" +fi + +# configuration values +doCoverage=false +doQuickTests=false +doMinTests=false +doNetTests=false +doDryRun=false +doZooTests=false +doUnitTests=false +doBuild=false +doBlackFormat=false +doBlackFix=false +doIsortFormat=false +doIsortFix=false +doFlake8Format=false +doPylintFormat=false +doClangFormat=false +doCopyRight=false +doPytypeFormat=false +doMypyFormat=false +doCleanup=false +doDistTests=false +doPrecommit=false + +NUM_PARALLEL=1 + +PY_EXE=${MONAI_PY_EXE:-$(which python)} + +function print_usage { + echo "runtests.sh [--codeformat] [--autofix] [--black] [--isort] [--flake8] [--pylint] [--clangformat] [--pytype] [--mypy]" + echo " [--unittests] [--disttests] [--coverage] [--quick] [--min] [--net] [--dryrun] [-j number] [--list_tests]" + echo " [--copyright] [--build] [--clean] [--precommit] 
[--help] [--version]" + echo "" + echo "MONAI unit testing utilities." + echo "" + echo "Examples:" + echo "./runtests.sh -f -u --net --coverage # run style checks, full tests, print code coverage (${green}recommended for pull requests${noColor})." + echo "./runtests.sh -f -u # run style checks and unit tests." + echo "./runtests.sh -f # run coding style and static type checking." + echo "./runtests.sh --quick --unittests # run minimal unit tests, for quick verification during code developments." + echo "./runtests.sh --autofix # run automatic code formatting using \"isort\" and \"black\"." + echo "./runtests.sh --clean # clean up temporary files and run \"${PY_EXE} setup.py develop --uninstall\"." + echo "" + echo "Code style check options:" + echo " --black : perform \"black\" code format checks" + echo " --autofix : format code using \"isort\" and \"black\"" + echo " --isort : perform \"isort\" import sort checks" + echo " --flake8 : perform \"flake8\" code format checks" + echo " --pylint : perform \"pylint\" code format checks" + echo " --clangformat : format csrc code using \"clang-format\"" + echo " --precommit : perform source code format check and fix using \"pre-commit\"" + echo "" + echo "Python type check options:" + echo " --pytype : perform \"pytype\" static type checks" + echo " --mypy : perform \"mypy\" static type checks" + echo " -j, --jobs : number of parallel jobs to run \"pytype\" (default $NUM_PARALLEL)" + echo "" + echo "MONAI unit testing options:" + echo " -u, --unittests : perform unit testing" + echo " --disttests : perform distributed unit testing" + echo " --coverage : report testing code coverage, to be used with \"--net\", \"--unittests\"" + echo " -q, --quick : skip long running unit tests and integration tests" + echo " -m, --min : only run minimal unit tests which do not require optional packages" + echo " --net : perform integration testing" + echo " -b, --build : compile and install the source code folder as an editable release." 
+ echo " --list_tests : list unit tests and exit" + echo "" + echo "Misc. options:" + echo " --dryrun : display the commands to the screen without running" + echo " --copyright : check whether every source code has a copyright header" + echo " -f, --codeformat : shorthand to run all code style and static analysis tests" + echo " -c, --clean : clean temporary files from tests and exit" + echo " -h, --help : show this help message and exit" + echo " -v, --version : show MONAI and system version information and exit" + echo "" + echo "${separator}For bug reports and feature requests, please file an issue at:" + echo " https://github.com/Project-MONAI/MONAI/issues/new/choose" + echo "" + echo "To choose an alternative python executable, set the environmental variable, \"MONAI_PY_EXE\"." + exit 1 +} + +# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354 +protobuf_major_version=$(${PY_EXE} -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1) +if [ "$protobuf_major_version" -ge "4" ] +then + export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +fi + +function check_import { + echo "Python: ${PY_EXE}" + ${cmdPrefix}${PY_EXE} -W error -W ignore::DeprecationWarning -c "import generative" +} + +function print_version { + ${cmdPrefix}${PY_EXE} -c 'import monai; monai.config.print_config()' +} + +function install_deps { + echo "Pip installing MONAI development dependencies and compile MONAI cpp extensions..." + ${cmdPrefix}${PY_EXE} -m pip install -r requirements-dev.txt +} + +function compile_cpp { + echo "Compiling and installing MONAI cpp extensions..." 
+ # depends on setup.py behaviour for building + # currently setup.py uses environment variables: BUILD_MONAI and FORCE_CUDA + ${cmdPrefix}${PY_EXE} setup.py develop --user --uninstall + if [[ "$OSTYPE" == "darwin"* ]]; + then # clang for mac os + CC=clang CXX=clang++ ${cmdPrefix}${PY_EXE} setup.py develop --user + else + ${cmdPrefix}${PY_EXE} setup.py develop --user + fi +} + +function clang_format { + echo "Running clang-format..." + ${cmdPrefix}${PY_EXE} -m tests.clang_format_utils + clang_format_tool='.clang-format-bin/clang-format' + # Verify that the clang-format binary was fetched and is executable before using it. + if ! type -p "$clang_format_tool" >/dev/null; then + echo "'clang-format' not found, skipping the formatting." + exit 1 + fi + find generative/csrc -type f | while read i; do $clang_format_tool -style=file -i $i; done + find generative/_extensions -type f -name "*.cpp" -o -name "*.h" -o -name "*.cuh" -o -name "*.cu" |\ + while read i; do $clang_format_tool -style=file -i $i; done +} + +function is_pip_installed() { + return $(${PY_EXE} -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1) +} + +function clean_py { + if is_pip_installed coverage + then + # remove coverage history + ${cmdPrefix}${PY_EXE} -m coverage erase + fi + + # uninstall the development package + echo "Uninstalling MONAI development files..." 
+ ${cmdPrefix}${PY_EXE} setup.py develop --user --uninstall + + # remove temporary files (in the directory of this script) + TO_CLEAN="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + echo "Removing temporary files in ${TO_CLEAN}" + + find ${TO_CLEAN}/generative -type f -name "*.py[co]" -delete + find ${TO_CLEAN}/generative -type f -name "*.so" -delete + find ${TO_CLEAN}/generative -type d -name "__pycache__" -delete + find ${TO_CLEAN} -maxdepth 1 -type f -name ".coverage.*" -delete + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".eggs" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "generative.egg-info" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "build" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "dist" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".mypy_cache" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".pytype" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".coverage" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "__pycache__" -exec rm -r "{}" + +} + +function torch_validate { + ${cmdPrefix}${PY_EXE} -c 'import torch; print(torch.__version__); print(torch.rand(5,3))' +} + +function print_error_msg() { + echo "${red}Error: $1.${noColor}" + echo "" +} + +function print_style_fail_msg() { + echo "${red}Check failed!${noColor}" + echo "Please run auto style fixes: ${green}./runtests.sh --autofix${noColor}" +} + +function list_unittests() { + ${PY_EXE} - << END +import unittest +def print_suite(suite): + if hasattr(suite, "__iter__"): + for x in suite: + print_suite(x) + else: + print(suite) +print_suite(unittest.defaultTestLoader.discover('./tests')) +END + exit 0 +} + +if [ -z "$1" ] +then + print_error_msg "Too few arguments to $0" + print_usage +fi + +# parse arguments +while [[ $# -gt 0 ]] +do + key="$1" + case $key in + --coverage) + doCoverage=true + ;; 
+ -q|--quick) + doQuickTests=true + ;; + -m|--min) + doMinTests=true + ;; + --net) + doNetTests=true + ;; + --list_tests) + list_unittests + ;; + --dryrun) + doDryRun=true + ;; + -u|--u*) # allow --unittest | --unittests | --unittesting etc. + doUnitTests=true + ;; + -f|--codeformat) + doBlackFormat=true + doIsortFormat=true + doFlake8Format=true + doPylintFormat=true + doPytypeFormat=true + doMypyFormat=true + doCopyRight=true + ;; + --disttests) + doDistTests=true + ;; + --black) + doBlackFormat=true + ;; + --autofix) + doIsortFix=true + doBlackFix=true + doIsortFormat=true + doBlackFormat=true + doCopyRight=true + ;; + --clangformat) + doClangFormat=true + ;; + --isort) + doIsortFormat=true + ;; + --flake8) + doFlake8Format=true + ;; + --pylint) + doPylintFormat=true + ;; + --precommit) + doPrecommit=true + ;; + --pytype) + doPytypeFormat=true + ;; + --mypy) + doMypyFormat=true + ;; + -j|--jobs) + NUM_PARALLEL=$2 + shift + ;; + --copyright) + doCopyRight=true + ;; + -b|--build) + doBuild=true + ;; + -c|--clean) + doCleanup=true + ;; + -h|--help) + print_usage + ;; + -v|--version) + print_version + exit 1 + ;; + --nou*) # allow --nounittest | --nounittests | --nounittesting etc. 
+ print_error_msg "nounittest option is deprecated, no unit tests is the default setting" + print_usage + ;; + *) + print_error_msg "Incorrect commandline provided, invalid key: $key" + print_usage + ;; + esac + shift +done + +# home directory +homedir="$( cd -P "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$homedir" + +# python path +export PYTHONPATH="$homedir:$PYTHONPATH" +echo "PYTHONPATH: $PYTHONPATH" + +# by default do nothing +cmdPrefix="" + +if [ $doDryRun = true ] +then + echo "${separator}${blue}dryrun${noColor}" + + # commands are echoed instead of ran + cmdPrefix="dryrun " + function dryrun { echo " " "$@"; } +else + check_import +fi + +if [ $doBuild = true ] +then + echo "${separator}${blue}compile and install${noColor}" + # try to compile MONAI cpp + compile_cpp + + echo "${green}done! (to uninstall and clean up, please use \"./runtests.sh --clean\")${noColor}" +fi + +if [ $doCleanup = true ] +then + echo "${separator}${blue}clean${noColor}" + + clean_py + + echo "${green}done!${noColor}" + exit +fi + +if [ $doClangFormat = true ] +then + echo "${separator}${blue}clang-formatting${noColor}" + + clang_format + + echo "${green}done!${noColor}" +fi + +# unconditionally report on the state of monai +print_version + +if [ $doCopyRight = true ] +then + # check copyright headers + copyright_bad=0 + copyright_all=0 + while read -r fname; do + copyright_all=$((copyright_all + 1)) + if ! grep "http://www.apache.org/licenses/LICENSE-2.0" "$fname" > /dev/null; then + print_error_msg "Missing the license header in file: $fname" + copyright_bad=$((copyright_bad + 1)) + fi + done <<< "$(find "$(pwd)/generative" "$(pwd)/tests" -type f \ + ! -wholename "*_version.py" -and -name "*.py" -or -name "*.cpp" -or -name "*.cu" -or -name "*.h")" + if [[ ${copyright_bad} -eq 0 ]]; + then + echo "${green}Source code copyright headers checked ($copyright_all).${noColor}" + else + echo "Please add the licensing header to the file ($copyright_bad of $copyright_all files)." 
+ echo " See also: https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md#checking-the-coding-style" + echo "" + exit 1 + fi +fi + + +if [ $doPrecommit = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}pre-commit${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed pre_commit + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files + + pre_commit_status=$? + if [ ${pre_commit_status} -ne 0 ] + then + print_style_fail_msg + exit ${pre_commit_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + + +if [ $doIsortFormat = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + if [ $doIsortFix = true ] + then + echo "${separator}${blue}isort-fix${noColor}" + else + echo "${separator}${blue}isort${noColor}" + fi + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed isort + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m isort --version + + if [ $doIsortFix = true ] + then + ${cmdPrefix}${PY_EXE} -m isort "$(pwd)" + else + ${cmdPrefix}${PY_EXE} -m isort --check "$(pwd)" + fi + + isort_status=$? + if [ ${isort_status} -ne 0 ] + then + print_style_fail_msg + exit ${isort_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + + +if [ $doBlackFormat = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + if [ $doBlackFix = true ] + then + echo "${separator}${blue}black-fix${noColor}" + else + echo "${separator}${blue}black${noColor}" + fi + + # ensure that the necessary packages for code format testing are installed + if ! 
is_pip_installed black + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m black --version + + if [ $doBlackFix = true ] + then + ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma "$(pwd)" + else + ${cmdPrefix}${PY_EXE} -m black --skip-magic-trailing-comma --check "$(pwd)" + fi + + black_status=$? + if [ ${black_status} -ne 0 ] + then + print_style_fail_msg + exit ${black_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + + +if [ $doFlake8Format = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}flake8${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed flake8 + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m flake8 --version + + ${cmdPrefix}${PY_EXE} -m flake8 "$(pwd)" --count --statistics + + flake8_status=$? + if [ ${flake8_status} -ne 0 ] + then + print_style_fail_msg + exit ${flake8_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + +if [ $doPylintFormat = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}pylint${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed pylint + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m pylint --version + + ignore_codes="C,R,W,E1101,E1102,E0601,E1130,E1123,E0102,E1120,E1137,E1136" + ${cmdPrefix}${PY_EXE} -m pylint generative tests --disable=$ignore_codes -j $NUM_PARALLEL + pylint_status=$? 
+ + if [ ${pylint_status} -ne 0 ] + then + print_style_fail_msg + exit ${pylint_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi + + +if [ $doPytypeFormat = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}pytype${noColor}" + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed pytype + then + install_deps + fi + pytype_ver=$(${cmdPrefix}${PY_EXE} -m pytype --version) + if [[ "$OSTYPE" == "darwin"* && "$pytype_ver" == "2021."* ]]; then + echo "${red}pytype not working on macOS 2021 (https://github.com/Project-MONAI/MONAI/issues/2391). Please upgrade to 2022*.${noColor}" + exit 1 + else + ${cmdPrefix}${PY_EXE} -m pytype --version + + ${cmdPrefix}${PY_EXE} -m pytype -j ${NUM_PARALLEL} --python-version="$(${PY_EXE} -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")" "$(pwd)" + + pytype_status=$? + if [ ${pytype_status} -ne 0 ] + then + echo "${red}failed!${noColor}" + exit ${pytype_status} + else + echo "${green}passed!${noColor}" + fi + fi + set -e # enable exit on failure +fi + + +if [ $doMypyFormat = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}mypy${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed mypy + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m mypy --version + ${cmdPrefix}${PY_EXE} -m mypy "$(pwd)" + + mypy_status=$? + if [ ${mypy_status} -ne 0 ] + then + : # mypy output already follows format + exit ${mypy_status} + else + : # mypy output already follows format + fi + set -e # enable exit on failure +fi + + +# testing command to run +cmd="${PY_EXE}" + +# When running --quick, require doCoverage as well and set QUICKTEST environmental +# variable to disable slow unit tests from running. 
+if [ $doQuickTests = true ] +then + echo "${separator}${blue}quick${noColor}" + doCoverage=true + export QUICKTEST=True +fi + +if [ $doMinTests = true ] +then + echo "${separator}${blue}min${noColor}" + ${cmdPrefix}${PY_EXE} -m tests.min_tests +fi + +# set coverage command +if [ $doCoverage = true ] +then + echo "${separator}${blue}coverage${noColor}" + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed coverage + then + install_deps + fi + cmd="${PY_EXE} -m coverage run --append" +fi + +# # download test data if needed +# if [ ! -d testing_data ] && [ "$doDryRun" != 'true' ] +# then +# fi + +# unit tests +if [ $doUnitTests = true ] +then + echo "${separator}${blue}unittests${noColor}" + torch_validate + ${cmdPrefix}${cmd} ./tests/runner.py -p "^(?!test_integration).*(?= threshold + results = dict(filter(lambda x: x[1] > thresh, results.items())) + if len(results) == 0: + return + print(f"\n\n{status}, printing completed times >{thresh}s in ascending order...\n") + timings = dict(sorted(results.items(), key=lambda item: item[1])) + + for r in timings: + if timings[r] >= thresh: + print(f"{r} ({timings[r]:.03}s)") + print(f"test discovery time: {discovery_time:.03}s") + print(f"total testing time: {sum(results.values()):.03}s") + print("Remember to check above times for any errors!") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Runner for MONAI unittests with timing.") + parser.add_argument( + "-s", action="store", dest="path", default=".", help="Directory to start discovery (default: '%(default)s')" + ) + parser.add_argument( + "-p", + action="store", + dest="pattern", + default="test_*.py", + help="Pattern to match tests (default: '%(default)s')", + ) + parser.add_argument( + "-t", + "--thresh", + dest="thresh", + default=10.0, + type=float, + help="Display tests longer than given threshold (default: %(default)d)", + ) + parser.add_argument( + "-v", + "--verbosity", + action="store", + 
        dest="verbosity",
        type=int,
        default=1,
        help="Verbosity level (default: %(default)d)",
    )
    parser.add_argument("-q", "--quick", action="store_true", dest="quick", default=False, help="Only do quick tests")
    parser.add_argument(
        "-f", "--failfast", action="store_true", dest="failfast", default=False, help="Stop testing on first failure"
    )
    args = parser.parse_args()
    print(f"Running tests in folder: '{args.path}'")
    if args.pattern:
        print(f"With file pattern: '{args.pattern}'")

    return args


def get_default_pattern(loader):
    # Read the default value of the `pattern` parameter from the signature of
    # ``loader.discover`` so the runner stays in sync with unittest's own default.
    signature = inspect.signature(loader.discover)
    params = {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
    return params["pattern"]


if __name__ == "__main__":

    # Parse input arguments
    args = parse_args()

    # If quick is desired, set environment variable
    if args.quick:
        os.environ["QUICKTEST"] = "True"

    # Get all test names (optionally from some path with some pattern)
    with PerfContext() as pc:
        # the files are searched from `tests/` folder, starting with `test_`
        files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py"))
        cases = []
        # strip the trailing ".py" to obtain importable module names
        for test_module in {os.path.basename(f)[:-3] for f in files}:
            if re.match(args.pattern, test_module):
                cases.append(f"tests.{test_module}")
            else:
                print(f"monai test runner: excluding tests.{test_module}")
        tests = unittest.TestLoader().loadTestsFromNames(cases)
    discovery_time = pc.total_time
    print(f"time to discover tests: {discovery_time}s, total cases: {tests.countTestCases()}.")

    test_runner = unittest.runner.TextTestRunner(
        resultclass=TimeLoggingTestResult, verbosity=args.verbosity, failfast=args.failfast
    )

    # Use try catches to print the current results if encountering exception or keyboard interruption
    try:
        test_result = test_runner.run(tests)
        print_results(results, discovery_time, args.thresh, "tests finished")
        # exit code 0 only when the full suite passed
        sys.exit(not test_result.wasSuccessful())
    except
KeyboardInterrupt: + print_results(results, discovery_time, args.thresh, "tests cancelled") + sys.exit(1) + except Exception: + print_results(results, discovery_time, args.thresh, "exception reached") + raise diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..cb1cabdc --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,810 @@ +# COPIED FROM https://github.com/Project-MONAI/MONAI/blob/fdd07f36ecb91cfcd491533f4792e1a67a9f89fc/tests/utils.py +# --------------------------------------------------------------- + +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import copy +import datetime +import functools +import importlib +import json +import operator +import os +import queue +import ssl +import subprocess +import sys +import tempfile +import time +import traceback +import unittest +import warnings +from contextlib import contextmanager +from functools import partial, reduce +from subprocess import PIPE, Popen +from typing import Callable +from urllib.error import ContentTooShortError, HTTPError + +import numpy as np +import torch +import torch.distributed as dist +from monai.apps.utils import download_url +from monai.config import NdarrayTensor +from monai.config.deviceconfig import USE_COMPILED +from monai.config.type_definitions import NdarrayOrTensor +from monai.data import create_test_image_2d, create_test_image_3d +from monai.data.meta_tensor import MetaTensor, get_track_meta +from monai.networks import convert_to_torchscript +from monai.utils import optional_import +from monai.utils.module import pytorch_after, version_leq +from monai.utils.type_conversion import convert_data_type + +nib, _ = optional_import("nibabel") +http_error, has_requests = optional_import("requests", name="HTTPError") + +quick_test_var = "QUICKTEST" +_tf32_enabled = None +_test_data_config: dict = {} + + +def testing_data_config(*keys): + """get _test_data_config[keys0][keys1]...[keysN]""" + if not _test_data_config: + with open(os.path.join(os.path.dirname(__file__), "testing_data", "data_config.json")) as c: + _config = json.load(c) + for k, v in _config.items(): + _test_data_config[k] = v + return reduce(operator.getitem, keys, _test_data_config) + + +def clone(data: NdarrayTensor) -> NdarrayTensor: + """ + Clone data independent of type. + + Args: + data (NdarrayTensor): This can be a Pytorch Tensor or numpy array. 
+ + Returns: + Any: Cloned data object + """ + return copy.deepcopy(data) + + +def assert_allclose( + actual: NdarrayOrTensor, + desired: NdarrayOrTensor, + type_test: bool | str = True, + device_test: bool = False, + *args, + **kwargs, +): + """ + Assert that types and all values of two data objects are close. + + Args: + actual: Pytorch Tensor or numpy array for comparison. + desired: Pytorch Tensor or numpy array to compare against. + type_test: whether to test that `actual` and `desired` are both numpy arrays or torch tensors. + if type_test == "tensor", it checks whether the `actual` is a torch.tensor or metatensor according to + `get_track_meta`. + device_test: whether to test the device property. + args: extra arguments to pass on to `np.testing.assert_allclose`. + kwargs: extra arguments to pass on to `np.testing.assert_allclose`. + + + """ + if isinstance(type_test, str) and type_test == "tensor": + if get_track_meta(): + np.testing.assert_equal(isinstance(actual, MetaTensor), True, "must be a MetaTensor") + else: + np.testing.assert_equal( + isinstance(actual, torch.Tensor) and not isinstance(actual, MetaTensor), True, "must be a torch.Tensor" + ) + elif type_test: + # check both actual and desired are of the same type + np.testing.assert_equal(isinstance(actual, np.ndarray), isinstance(desired, np.ndarray), "numpy type") + np.testing.assert_equal(isinstance(actual, torch.Tensor), isinstance(desired, torch.Tensor), "torch type") + + if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor): + if device_test: + np.testing.assert_equal(str(actual.device), str(desired.device), "torch device check") # type: ignore + actual = actual.detach().cpu().numpy() if isinstance(actual, torch.Tensor) else actual + desired = desired.detach().cpu().numpy() if isinstance(desired, torch.Tensor) else desired + np.testing.assert_allclose(actual, desired, *args, **kwargs) + + +@contextmanager +def skip_if_downloading_fails(): + try: + yield + except 
(ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_requests else () as e: + raise unittest.SkipTest(f"error while downloading: {e}") from e + except ssl.SSLError as ssl_e: + if "decryption failed" in str(ssl_e): + raise unittest.SkipTest(f"SSL error while downloading: {ssl_e}") from ssl_e + except (RuntimeError, OSError) as rt_e: + err_str = str(rt_e) + if any( + k in err_str + for k in ( + "unexpected EOF", # incomplete download + "network issue", + "gdown dependency", # gdown not installed + "md5 check", + "limit", # HTTP Error 503: Egress is over the account limit + "authenticate", + ) + ): + raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download + + raise rt_e + + +def test_pretrained_networks(network, input_param, device): + with skip_if_downloading_fails(): + return network(**input_param).to(device) + + +def test_is_quick(): + return os.environ.get(quick_test_var, "").lower() == "true" + + +def is_tf32_env(): + """ + The environment variable NVIDIA_TF32_OVERRIDE=0 will override any defaults + or programmatic configuration of NVIDIA libraries, and consequently, + cuBLAS will not accelerate FP32 computations with TF32 tensor cores. 
+ """ + global _tf32_enabled + if _tf32_enabled is None: + _tf32_enabled = False + if ( + torch.cuda.is_available() + and not version_leq(f"{torch.version.cuda}", "10.100") + and os.environ.get("NVIDIA_TF32_OVERRIDE", "1") != "0" + and torch.cuda.device_count() > 0 # at least 11.0 + ): + try: + # with TF32 enabled, the speed is ~8x faster, but the precision has ~2 digits less in the result + g_gpu = torch.Generator(device="cuda") + g_gpu.manual_seed(2147483647) + a_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) + b_full = torch.randn(1024, 1024, dtype=torch.double, device="cuda", generator=g_gpu) + _tf32_enabled = (a_full.float() @ b_full.float() - a_full @ b_full).abs().max().item() > 0.001 # 0.1713 + except BaseException: + pass + print(f"tf32 enabled: {_tf32_enabled}") + return _tf32_enabled + + +def skip_if_quick(obj): + """ + Skip the unit tests if environment variable `quick_test_var=true`. + For example, the user can skip the relevant tests by setting ``export QUICKTEST=true``. 
def skip_if_quick(obj):
    """Skip decorator applied when the ``QUICKTEST`` environment variable is true.

    Users can bypass the slow tests with ``export QUICKTEST=true``.
    """
    quick_mode = test_is_quick()
    return unittest.skipIf(quick_mode, "Skipping slow tests")(obj)


class SkipIfNoModule:
    """Skip a test when the named optional module cannot be imported."""

    def __init__(self, module_name):
        self.module_name = module_name
        # optional_import returns (module, success_flag); record the inverse flag
        self.module_missing = not optional_import(self.module_name)[1]

    def __call__(self, obj):
        message = f"optional module not present: {self.module_name}"
        return unittest.skipIf(self.module_missing, message)(obj)


class SkipIfModule:
    """Skip a test when the named optional module IS importable."""

    def __init__(self, module_name):
        self.module_name = module_name
        self.module_avail = optional_import(self.module_name)[1]

    def __call__(self, obj):
        message = f"Skipping because optional module present: {self.module_name}"
        return unittest.skipIf(self.module_avail, message)(obj)


def skip_if_no_cpp_extension(obj):
    """Skip the test unless the C++ extensions were compiled (``USE_COMPILED``)."""
    return unittest.skipUnless(USE_COMPILED, "Skipping cpp extension tests")(obj)


def skip_if_no_cuda(obj):
    """Skip the test unless a CUDA device is available."""
    return unittest.skipUnless(torch.cuda.is_available(), "Skipping CUDA-based tests")(obj)


def skip_if_windows(obj):
    """Skip the test when running on Windows."""
    return unittest.skipIf(sys.platform == "win32", "Skipping tests on Windows")(obj)
def skip_if_darwin(obj):
    """Skip the test when running on macOS (Darwin)."""
    return unittest.skipIf(sys.platform == "darwin", "Skipping tests on macOS/Darwin")(obj)


class SkipIfBeforePyTorchVersion:
    """Skip a test when the installed PyTorch is older than the given version tuple."""

    def __init__(self, pytorch_version_tuple):
        self.min_version = pytorch_version_tuple
        self.version_too_old = not pytorch_after(*pytorch_version_tuple)

    def __call__(self, obj):
        message = f"Skipping tests that fail on PyTorch versions before: {self.min_version}"
        return unittest.skipIf(self.version_too_old, message)(obj)


class SkipIfAtLeastPyTorchVersion:
    """Skip a test when the installed PyTorch is at least the given version tuple."""

    def __init__(self, pytorch_version_tuple):
        self.max_version = pytorch_version_tuple
        self.version_too_new = pytorch_after(*pytorch_version_tuple)

    def __call__(self, obj):
        message = f"Skipping tests that fail on PyTorch versions at least: {self.max_version}"
        return unittest.skipIf(self.version_too_new, message)(obj)


def is_main_test_process():
    """Return True when called from the main multiprocessing process (not a worker)."""
    proc = torch.multiprocessing.current_process()
    usable = bool(proc) and hasattr(proc, "name")
    return usable and proc.name.startswith("Main")
+ """ + cp, has_cp = optional_import("cupy") + if not is_main_test_process(): + return has_cp # skip the check if we are running in subprocess + if not has_cp: + return False + try: # test cupy installation with a basic example + x = cp.arange(6, dtype="f").reshape(2, 3) + y = cp.arange(3, dtype="f") + kernel = cp.ElementwiseKernel( + "float32 x, float32 y", "float32 z", """ if (x - 2 > y) { z = x * y; } else { z = x + y; } """, "my_kernel" + ) + flag = kernel(x, y)[0, 0] == 0 + del x, y, kernel + cp.get_default_memory_pool().free_all_blocks() + return flag + except Exception: + return False + + +HAS_CUPY = has_cupy() + + +def make_nifti_image( + array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=".nii.gz", verbose=False, dtype=float +): + """ + Create a temporary nifti image on the disk and return the image name. + User is responsible for deleting the temporary file when done with it. + """ + if isinstance(array, torch.Tensor): + array, *_ = convert_data_type(array, np.ndarray) + if isinstance(affine, torch.Tensor): + affine, *_ = convert_data_type(affine, np.ndarray) + if affine is None: + affine = np.eye(4) + test_image = nib.Nifti1Image(array.astype(dtype), affine) # type: ignore + + # if dir not given, create random. Else, make sure it exists. + if dir is None: + dir = tempfile.mkdtemp() + else: + os.makedirs(dir, exist_ok=True) + + # If fname not given, get random one. Else, concat dir, fname and suffix. 
def make_nifti_image(
    array: NdarrayOrTensor, affine=None, dir=None, fname=None, suffix=".nii.gz", verbose=False, dtype=float
):
    """Write ``array`` (with an optional ``affine``) to a NIfTI file on disk and return its path.

    The caller is responsible for deleting the temporary file when done with it.
    """
    if isinstance(array, torch.Tensor):
        array, *_ = convert_data_type(array, np.ndarray)
    if isinstance(affine, torch.Tensor):
        affine, *_ = convert_data_type(affine, np.ndarray)
    if affine is None:
        affine = np.eye(4)
    nifti_obj = nib.Nifti1Image(array.astype(dtype), affine)  # type: ignore

    # use a fresh random temp directory when none is given; otherwise ensure it exists
    if dir is None:
        out_dir = tempfile.mkdtemp()
    else:
        os.makedirs(dir, exist_ok=True)
        out_dir = dir

    # random file name when none is given; otherwise join dir, fname and suffix
    if fname is None:
        handle, out_name = tempfile.mkstemp(suffix=suffix, dir=out_dir)
        os.close(handle)
    else:
        out_name = os.path.join(out_dir, fname + suffix)

    nib.save(nifti_obj, out_name)
    if verbose:
        print(f"File written: {out_name}.")
    return out_name


def make_rand_affine(ndim: int = 3, random_state: np.random.RandomState | None = None):
    """Create random affine transformation (with values == -1, 0 or 1)."""
    rng = np.random.random.__self__ if random_state is None else random_state  # type: ignore

    signs = rng.choice([-1, 1], size=ndim)
    columns = rng.choice(range(ndim), size=ndim, replace=False)
    affine = np.zeros([ndim + 1, ndim + 1])
    affine[ndim, ndim] = 1
    for row, (sign, col) in enumerate(zip(signs, columns)):
        affine[row, col] = sign
    return affine


def get_arange_img(size, dtype=np.float32, offset=0):
    """Return an arange-filled image of the given ``size`` with a leading channel dim (dim 0)."""
    total = np.prod(size)
    img = np.arange(offset, offset + total, dtype=dtype).reshape(size)
    return img[None]


class DistTestCase(unittest.TestCase):
    """TestCase without ``_outcome`` so that instances stay picklable across processes."""

    def __getstate__(self):
        state = dict(self.__dict__)
        del state["_outcome"]
        return state

    def __setstate__(self, data_dict):
        self.__dict__.update(data_dict)
+ + Multi-node tests require a fixed master_addr:master_port, with node_rank set manually in multiple scripts + or from environment variable "NODE_RANK". + """ + + def __init__( + self, + nnodes: int = 1, + nproc_per_node: int = 1, + master_addr: str = "localhost", + master_port: int | None = None, + node_rank: int | None = None, + timeout=60, + init_method=None, + backend: str | None = None, + daemon: bool | None = None, + method: str | None = "spawn", + verbose: bool = False, + ): + """ + + Args: + nnodes: The number of nodes to use for distributed call. + nproc_per_node: The number of processes to call on each node. + master_addr: Master node (rank 0)'s address, should be either the IP address or the hostname of node 0. + master_port: Master node (rank 0)'s free port. + node_rank: The rank of the node, this could be set via environment variable "NODE_RANK". + timeout: Timeout for operations executed against the process group. + init_method: URL specifying how to initialize the process group. + Default is "env://" or "file:///d:/a_temp" (windows) if unspecified. + If ``"no_init"``, the `dist.init_process_group` must be called within the code to be tested. + backend: The backend to use. Depending on build-time configurations, + valid values include ``mpi``, ``gloo``, and ``nccl``. + daemon: the process’s daemon flag. + When daemon=None, the initial value is inherited from the creating process. + method: set the method which should be used to start a child process. + method can be 'fork', 'spawn' or 'forkserver'. + verbose: whether to print NCCL debug info. 
+ """ + self.nnodes = int(nnodes) + self.nproc_per_node = int(nproc_per_node) + if self.nnodes < 1 or self.nproc_per_node < 1: + raise ValueError( + f"number of nodes and processes per node must be >= 1, got {self.nnodes} and {self.nproc_per_node}" + ) + self.node_rank = int(os.environ.get("NODE_RANK", "0")) if node_rank is None else int(node_rank) + self.master_addr = master_addr + self.master_port = np.random.randint(10000, 20000) if master_port is None else master_port + + if backend is None: + self.backend = "nccl" if torch.distributed.is_nccl_available() and torch.cuda.is_available() else "gloo" + else: + self.backend = backend + self.init_method = init_method + if self.init_method is None and sys.platform == "win32": + self.init_method = "file:///d:/a_temp" + self.timeout = datetime.timedelta(0, timeout) + self.daemon = daemon + self.method = method + self.verbose = verbose + + def run_process(self, func, local_rank, args, kwargs, results): + _env = os.environ.copy() # keep the original system env + try: + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["LOCAL_RANK"] = str(local_rank) + if self.verbose: + os.environ["NCCL_DEBUG"] = "INFO" + os.environ["NCCL_DEBUG_SUBSYS"] = "ALL" + os.environ["NCCL_BLOCKING_WAIT"] = str(1) + os.environ["OMP_NUM_THREADS"] = str(1) + os.environ["WORLD_SIZE"] = str(self.nproc_per_node * self.nnodes) + os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank) + + if torch.cuda.is_available(): + torch.cuda.set_device(int(local_rank)) # using device ids from CUDA_VISIBILE_DEVICES + + if self.init_method != "no_init": + dist.init_process_group( + backend=self.backend, + init_method=self.init_method, + timeout=self.timeout, + world_size=int(os.environ["WORLD_SIZE"]), + rank=int(os.environ["RANK"]), + ) + func(*args, **kwargs) + # the primary node lives longer to + # avoid _store_based_barrier, RuntimeError: Broken pipe + # as the TCP store daemon is on 
the rank 0 + if int(os.environ["RANK"]) == 0: + time.sleep(0.1) + results.put(True) + except Exception as e: + results.put(False) + raise e + finally: + os.environ.clear() + os.environ.update(_env) + try: + dist.destroy_process_group() + except RuntimeError as e: + warnings.warn(f"While closing process group: {e}.") + + def __call__(self, obj): + if not torch.distributed.is_available(): + return unittest.skipIf(True, "Skipping distributed tests because not torch.distributed.is_available()")(obj) + if torch.cuda.is_available() and torch.cuda.device_count() < self.nproc_per_node: + return unittest.skipIf( + True, + f"Skipping distributed tests because it requires {self.nnodes} devices " + f"but got {torch.cuda.device_count()}", + )(obj) + + _cache_original_func(obj) + + @functools.wraps(obj) + def _wrapper(*args, **kwargs): + tmp = torch.multiprocessing.get_context(self.method) + processes = [] + results = tmp.Queue() + func = _call_original_func + args = [obj.__name__, obj.__module__] + list(args) + for proc_rank in range(self.nproc_per_node): + p = tmp.Process( + target=self.run_process, args=(func, proc_rank, args, kwargs, results), daemon=self.daemon + ) + p.start() + processes.append(p) + for p in processes: + p.join() + assert results.get(), "Distributed call failed." + _del_original_func(obj) + + return _wrapper + + +class TimedCall: + """ + Wrap a test case so that it will run in a new process, raises a TimeoutError if the decorated method takes + more than `seconds` to finish. It is designed to be used with `tests.utils.DistTestCase`. + """ + + def __init__( + self, + seconds: float = 60.0, + daemon: bool | None = None, + method: str | None = "spawn", + force_quit: bool = True, + skip_timing=False, + ): + """ + + Args: + seconds: timeout seconds. + daemon: the process’s daemon flag. + When daemon=None, the initial value is inherited from the creating process. + method: set the method which should be used to start a child process. 
+ method can be 'fork', 'spawn' or 'forkserver'. + force_quit: whether to terminate the child process when `seconds` elapsed. + skip_timing: whether to skip the timing constraint. + this is useful to include some system conditions such as + `torch.cuda.is_available()`. + """ + self.timeout_seconds = seconds + self.daemon = daemon + self.force_quit = force_quit + self.skip_timing = skip_timing + self.method = method + + @staticmethod + def run_process(func, args, kwargs, results): + try: + output = func(*args, **kwargs) + results.put(output) + except Exception as e: + e.traceback = traceback.format_exc() + results.put(e) + + def __call__(self, obj): + + if self.skip_timing: + return obj + + _cache_original_func(obj) + + @functools.wraps(obj) + def _wrapper(*args, **kwargs): + tmp = torch.multiprocessing.get_context(self.method) + func = _call_original_func + args = [obj.__name__, obj.__module__] + list(args) + results = tmp.Queue() + p = tmp.Process(target=TimedCall.run_process, args=(func, args, kwargs, results), daemon=self.daemon) + p.start() + + p.join(timeout=self.timeout_seconds) + + timeout_error = None + try: + if p.is_alive(): + # create an Exception + timeout_error = torch.multiprocessing.TimeoutError( + f"'{obj.__name__}' in '{obj.__module__}' did not finish in {self.timeout_seconds}s." + ) + if self.force_quit: + p.terminate() + else: + warnings.warn( + f"TimedCall: deadline ({self.timeout_seconds}s) " + f"reached but waiting for {obj.__name__} to finish." 
+ ) + finally: + p.join() + + _del_original_func(obj) + res = None + try: + res = results.get(block=False) + except queue.Empty: # no result returned, took too long + pass + if isinstance(res, Exception): # other errors from obj + if hasattr(res, "traceback"): + raise RuntimeError(res.traceback) from res + raise res + if timeout_error: # no force_quit finished + raise timeout_error + return res + + return _wrapper + + +_original_funcs = {} + + +def _cache_original_func(obj) -> None: + """cache the original function by name, so that the decorator doesn't shadow it.""" + _original_funcs[obj.__name__] = obj + + +def _del_original_func(obj): + """pop the original function from cache.""" + _original_funcs.pop(obj.__name__, None) + if torch.cuda.is_available(): # clean up the cached function + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def _call_original_func(name, module, *args, **kwargs): + if name not in _original_funcs: + _original_module = importlib.import_module(module) # reimport, refresh _original_funcs + if not hasattr(_original_module, name): + # refresh module doesn't work + raise RuntimeError(f"Could not recover the original {name} from {module}: {_original_funcs}.") + f = _original_funcs[name] + return f(*args, **kwargs) + + +class NumpyImageTestCase2D(unittest.TestCase): + im_shape = (128, 64) + input_channels = 1 + output_channels = 4 + num_classes = 3 + + def setUp(self): + im, msk = create_test_image_2d( + self.im_shape[0], self.im_shape[1], num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=self.num_classes + ) + + self.imt = im[None, None] + self.seg1 = (msk[None, None] > 0).astype(np.float32) + self.segn = msk[None, None] + + +class TorchImageTestCase2D(NumpyImageTestCase2D): + def setUp(self): + NumpyImageTestCase2D.setUp(self) + self.imt = torch.tensor(self.imt) + self.seg1 = torch.tensor(self.seg1) + self.segn = torch.tensor(self.segn) + + +class NumpyImageTestCase3D(unittest.TestCase): + im_shape = (64, 48, 80) + input_channels = 
1 + output_channels = 4 + num_classes = 3 + + def setUp(self): + im, msk = create_test_image_3d( + self.im_shape[0], + self.im_shape[1], + self.im_shape[2], + num_objs=4, + rad_max=20, + noise_max=0.0, + num_seg_classes=self.num_classes, + ) + + self.imt = im[None, None] + self.seg1 = (msk[None, None] > 0).astype(np.float32) + self.segn = msk[None, None] + + +class TorchImageTestCase3D(NumpyImageTestCase3D): + def setUp(self): + NumpyImageTestCase3D.setUp(self) + self.imt = torch.tensor(self.imt) + self.seg1 = torch.tensor(self.seg1) + self.segn = torch.tensor(self.segn) + + +def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0): + """ + Test the ability to save `net` as a Torchscript object, reload it, and apply inference. The value `inputs` is + forward-passed through the original and loaded copy of the network and their results returned. + The forward pass for both is done without gradient accumulation. + + The test will be performed with CUDA if available, else CPU. + """ + # TODO: would be nice to use GPU if available, but it currently causes CI failures. + device = "cpu" + try: + with tempfile.TemporaryDirectory() as tempdir: + convert_to_torchscript( + model=net, + filename_or_obj=os.path.join(tempdir, "model.ts"), + verify=True, + inputs=inputs, + device=device, + rtol=rtol, + atol=atol, + ) + except (RuntimeError, AttributeError): + if sys.version_info.major == 3 and sys.version_info.minor == 11: + warnings.warn("skipping py 3.11") + return + + +def download_url_or_skip_test(*args, **kwargs): + """``download_url`` and skip the tests if any downloading error occurs.""" + with skip_if_downloading_fails(): + download_url(*args, **kwargs) + + +def query_memory(n=2): + """ + Find best n idle devices and return a string of device ids using the `nvidia-smi` command. 
+ """ + bash_string = "nvidia-smi --query-gpu=power.draw,temperature.gpu,memory.used --format=csv,noheader,nounits" + + try: + p1 = Popen(bash_string.split(), stdout=PIPE) + output, error = p1.communicate() + free_memory = [x.split(",") for x in output.decode("utf-8").split("\n")[:-1]] + free_memory = np.asarray(free_memory, dtype=float).T + free_memory[1] += free_memory[0] # combine 0/1 column measures + ids = np.lexsort(free_memory)[:n] + except (TypeError, ValueError, IndexError, OSError): + ids = range(n) if isinstance(n, int) else [] + return ",".join(f"{int(x)}" for x in ids) + + +def test_local_inversion(invertible_xform, to_invert, im, dict_key=None): + """test that invertible_xform can bring to_invert back to im""" + im_item = im if dict_key is None else im[dict_key] + if not isinstance(im_item, MetaTensor): + return + im_ref = copy.deepcopy(im) + im_inv = invertible_xform.inverse(to_invert) + if dict_key: + im_inv = im_inv[dict_key] + im_ref = im_ref[dict_key] + np.testing.assert_array_equal(im_inv.applied_operations, []) + assert_allclose(im_inv.shape, im_ref.shape) + assert_allclose(im_inv.affine, im_ref.affine, atol=1e-3, rtol=1e-3) + + +def command_line_tests(cmd, copy_env=True): + test_env = os.environ.copy() if copy_env else os.environ + print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) + try: + normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True) + print(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t")) + except subprocess.CalledProcessError as e: + output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t") + errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t") + raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e + + +TEST_TORCH_TENSORS: tuple = (torch.as_tensor,) +if torch.cuda.is_available(): + gpu_tensor: Callable = partial(torch.as_tensor, device="cuda") + TEST_TORCH_TENSORS = TEST_TORCH_TENSORS + (gpu_tensor,) + 
+DEFAULT_TEST_AFFINE = torch.tensor( + [[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]] +) +_metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": DEFAULT_TEST_AFFINE}) +TEST_NDARRAYS_NO_META_TENSOR: tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS # type: ignore +TEST_NDARRAYS: tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,) # type: ignore +TEST_TORCH_AND_META_TENSORS: tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # type: ignore +# alias for branch tests +TEST_NDARRAYS_ALL = TEST_NDARRAYS + +TEST_DEVICES = [[torch.device("cpu")]] +if torch.cuda.is_available(): + TEST_DEVICES.append([torch.device("cuda")]) + +if __name__ == "__main__": + print("\n", query_memory(), sep="\n") # print to stdout + sys.exit(0)