diff --git a/.circleci/cimodel/data/caffe2_build_definitions.py b/.circleci/cimodel/data/caffe2_build_definitions.py index 6a712002e28dc..c58492a8e1a79 100644 --- a/.circleci/cimodel/data/caffe2_build_definitions.py +++ b/.circleci/cimodel/data/caffe2_build_definitions.py @@ -14,7 +14,7 @@ DOCKER_IMAGE_PATH_BASE = "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/" -DOCKER_IMAGE_VERSION = 315 +DOCKER_IMAGE_VERSION = 325 @dataclass diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index ef328d1668014..9ebb2334dba5d 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -8,7 +8,7 @@ (None, [ XImportant("2.7.9"), X("2.7"), - X("3.5"), + XImportant("3.5"), # Not run on all PRs, but should be included on [test all] X("nightly"), ]), ("gcc", [ diff --git a/.circleci/config.yml b/.circleci/config.yml index fcc1c5df7eb95..5f5bf784d70f6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -551,12 +551,7 @@ jobs: # Reinitialize path (see man page for path_helper(8)) eval `/usr/libexec/path_helper -s` - # Use Homebrew Python if configured to do so - if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then - export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH - fi - - pip -q install numpy + export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH # Install Anaconda if we need to if [ -n "${CAFFE2_USE_ANACONDA}" ]; then @@ -569,6 +564,8 @@ jobs: source ${TMPDIR}/anaconda/bin/activate fi + pip -q install numpy + # Install sccache sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache @@ -606,7 +603,6 @@ jobs: if which sccache > /dev/null; then sccache --show-stats fi - binary_linux_build: <<: *binary_linux_build_params steps: @@ -1333,7 +1329,7 @@ jobs: export PATH="~/anaconda/bin:${PATH}" source ~/anaconda/bin/activate # Install dependencies - conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests + conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes # sync submodules cd ${PROJ_ROOT} git submodule sync @@ -1518,11 +1514,6 @@ workflows: name: pytorch_linux_xenial_py3_5_build requires: - setup - filters: - branches: - only: - - master - - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.5-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347" - pytorch_linux_test: @@ -1530,11 +1521,6 @@ workflows: requires: - setup - pytorch_linux_xenial_py3_5_build - filters: - branches: - only: - - master - - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.5-test" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347" resource_class: large @@ -1944,7 +1930,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:325" - caffe2_linux_test: name: caffe2_py2_gcc4_8_ubuntu14_04_test requires: @@ -1956,7 +1942,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:325" 
resource_class: large - caffe2_linux_build: name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_build @@ -1968,7 +1954,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:325" - caffe2_linux_test: name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_test requires: @@ -1981,14 +1967,14 @@ workflows: - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:325" resource_class: gpu.medium - caffe2_linux_build: name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build requires: - setup build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:325" - caffe2_linux_test: name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_test requires: @@ -1996,14 +1982,14 @@ workflows: - caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:325" resource_class: gpu.medium - caffe2_linux_build: name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build requires: - setup build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:325" - caffe2_linux_test: name: caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_test requires: @@ -2011,35 +1997,35 @@ workflows: - caffe2_py3_5_cuda10_1_cudnn7_ubuntu16_04_build build_environment: "caffe2-py3.5-cuda10.1-cudnn7-ubuntu16.04-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.5-cuda10.1-cudnn7-ubuntu16.04:325" resource_class: gpu.medium - caffe2_linux_build: name: caffe2_py2_mkl_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-mkl-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:325" - caffe2_linux_test: name: caffe2_py2_mkl_ubuntu16_04_test requires: - setup - caffe2_py2_mkl_ubuntu16_04_build build_environment: "caffe2-py2-mkl-ubuntu16.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:325" resource_class: large - caffe2_linux_build: name: caffe2_onnx_py2_gcc5_ubuntu16_04_build requires: - setup build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-build" - docker_image: 
"308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:325" - caffe2_linux_test: name: caffe2_onnx_py2_gcc5_ubuntu16_04_test requires: - setup - caffe2_onnx_py2_gcc5_ubuntu16_04_build build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:325" resource_class: large - caffe2_linux_build: name: caffe2_py2_clang3_8_ubuntu16_04_build @@ -2051,7 +2037,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-clang3.8-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:325" build_only: "1" - caffe2_linux_build: name: caffe2_py2_clang3_9_ubuntu16_04_build @@ -2063,35 +2049,35 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-clang3.9-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:325" build_only: "1" - caffe2_linux_build: name: caffe2_py2_clang7_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-clang7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:325" build_only: "1" - caffe2_linux_build: name: caffe2_onnx_py3_6_clang7_ubuntu16_04_build requires: - setup build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:325" - caffe2_linux_test: name: caffe2_onnx_py3_6_clang7_ubuntu16_04_test requires: - setup - caffe2_onnx_py3_6_clang7_ubuntu16_04_build build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:325" resource_class: large - caffe2_linux_build: name: caffe2_py2_android_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-android-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:325" build_only: "1" - caffe2_linux_build: name: caffe2_py2_cuda9_0_cudnn7_centos7_build @@ -2103,7 +2089,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:325" - caffe2_linux_test: name: caffe2_py2_cuda9_0_cudnn7_centos7_test requires: @@ -2116,7 +2102,7 @@ workflows: - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:315" + docker_image: 
"308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:325" resource_class: gpu.medium - caffe2_macos_build: name: caffe2_py2_ios_macos10_13_build diff --git a/.circleci/scripts/binary_ios_build.sh b/.circleci/scripts/binary_ios_build.sh index c15813b5c5d71..900df30ec60a5 100644 --- a/.circleci/scripts/binary_ios_build.sh +++ b/.circleci/scripts/binary_ios_build.sh @@ -13,7 +13,7 @@ chmod +x ~/Downloads/conda.sh export PATH="~/anaconda/bin:${PATH}" source ~/anaconda/bin/activate # Install dependencies -conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests +conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} # sync submodules cd ${PROJ_ROOT} diff --git a/.circleci/verbatim-sources/caffe2-job-specs.yml b/.circleci/verbatim-sources/caffe2-job-specs.yml index 4e46a269e523c..09536c51394ac 100644 --- a/.circleci/verbatim-sources/caffe2-job-specs.yml +++ b/.circleci/verbatim-sources/caffe2-job-specs.yml @@ -146,12 +146,7 @@ # Reinitialize path (see man page for path_helper(8)) eval `/usr/libexec/path_helper -s` - # Use Homebrew Python if configured to do so - if [ "${PYTHON_INSTALLATION}" == "homebrew" ]; then - export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH - fi - - pip -q install numpy + export PATH=/usr/local/opt/python/libexec/bin:/usr/local/bin:$PATH # Install Anaconda if we need to if [ -n "${CAFFE2_USE_ANACONDA}" ]; then @@ -164,6 +159,8 @@ source ${TMPDIR}/anaconda/bin/activate fi + pip -q install numpy + # Install sccache sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache @@ -201,4 +198,3 @@ if which sccache > /dev/null; then sccache --show-stats fi - diff --git a/.circleci/verbatim-sources/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs-custom.yml index 7f089a760f35e..d0cb06a8ce092 100644 --- a/.circleci/verbatim-sources/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs-custom.yml @@ -429,7 +429,7 @@ export PATH="~/anaconda/bin:${PATH}" source ~/anaconda/bin/activate # Install dependencies - conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests + conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes # sync submodules cd ${PROJ_ROOT} git submodule sync diff --git a/.flake8 b/.flake8 index bb1ecb6fb4bd9..d5d0a4544676f 100644 --- a/.flake8 +++ b/.flake8 @@ -11,4 +11,4 @@ ignore = B007,B008, # these ignores are from flake8-comprehensions; please fix! 
C400,C401,C402,C403,C404,C405,C407,C411, -exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,tools/amd_build/pyHIPIFY,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi +exclude = docs/src,venv,third_party,caffe2,scripts,docs/caffe2,torch/lib/include,torch/lib/tmp_install,build,torch/include,*.pyi diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index 8d6764ec1ae09..d727dcd57272f 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -158,7 +158,9 @@ fi function pip_install() { # retry 3 times - pip install --progress-bar off "$@" || pip install --progress-bar off "$@" || pip install --progress-bar off "$@" + # old versions of pip don't have the "--progress-bar" flag + pip install --progress-bar off "$@" || pip install --progress-bar off "$@" || pip install --progress-bar off "$@" ||\ + pip install "$@" || pip install "$@" || pip install "$@" } function pip_uninstall() { diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index e5cec7aeb6e07..ec2b05e2c62e3 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -32,6 +32,12 @@ if [ -n "${IN_CIRCLECI}" ]; then fi fi +if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then + # TODO: Move this to Docker + sudo apt-get -qq update + sudo apt-get -qq install --no-install-recommends libsndfile1 +fi + # --user breaks ppc64le builds and these packages are already in ppc64le docker if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then # JIT C++ extensions require ninja. diff --git a/CMakeLists.txt b/CMakeLists.txt index 909d6b914a26b..4fac53ec586ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,7 +140,7 @@ option(USE_METAL "Use Metal for iOS build" ON) option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option( USE_NCCL "Use NCCL" ON - "USE_CUDA;UNIX;NOT APPLE" OFF) + "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) cmake_dependent_option( USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1274d24c82af9..548f16f9c0210 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -218,6 +218,8 @@ pip install -r requirements.txt # `katex` must also be available in your PATH. # If you are using Ubuntu or Debian, you can install it with: # sudo apt install katex +# If you are using MacOS, you can install it through npm (install Node.js first): +# npm install -g katex ``` 3. Generate the documentation HTML files. The generated files will be in `docs/build/html`. diff --git a/README.md b/README.md index 0d99d8c73777c..8f8aa4ecf5962 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ You can write new neural network layers in Python using the torch API [or your favorite NumPy-based libraries such as SciPy](https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html). If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and with minimal boilerplate. -There is no wrapper code that needs to be written. You can see [a tutorial here](https://pytorch.org/tutorials/advanced/cpp_extension.html) and [an example here](https://github.com/pytorch/extension-cpp). +No wrapper code needs to be written. You can see [a tutorial here](https://pytorch.org/tutorials/advanced/cpp_extension.html) and [an example here](https://github.com/pytorch/extension-cpp). 
## Installation @@ -145,7 +145,7 @@ Python wheels for NVIDIA's Jetson Nano, Jetson TX2, and Jetson AGX Xavier are av - Python 2.7: https://nvidia.box.com/v/torch-weekly-cp27-jetson-jp42 - Python 3.6: https://nvidia.box.com/v/torch-weekly-cp36-jetson-jp42 -They requires JetPack 4.2 and above and are maintained by @dusty-nv +They require JetPack 4.2 and above, and @dusty-nv maintains them ### From Source @@ -175,7 +175,7 @@ conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing On Linux ```bash # Add LAPACK support for the GPU if needed -conda install -c pytorch magma-cuda90 # or [magma-cuda92 | magma-cuda100 ] depending on your cuda version +conda install -c pytorch magma-cuda90 # or [magma-cuda92 | magma-cuda100 | magma-cuda101 ] depending on your cuda version ``` #### Get the PyTorch Source @@ -234,7 +234,7 @@ set FORCE_PY27_BUILD=1 :: Note: This value is useless if Ninja is detected. However, you can force that by using `set USE_NINJA=OFF`. set CMAKE_GENERATOR=Visual Studio 15 2017 -:: Read the content in the previous section carefully before you preceed. +:: Read the content in the previous section carefully before you proceed. :: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block. :: "Visual Studio 2017 Developer Command Prompt" will be run automatically. :: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator. diff --git a/android/gradle.properties b/android/gradle.properties index ec9e4008fa562..ff63986f2d6de 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -1,6 +1,6 @@ ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 -VERSION_NAME=0.0.7-SNAPSHOT +VERSION_NAME=1.4.0-SNAPSHOT GROUP=org.pytorch MAVEN_GROUP=org.pytorch POM_URL=https://github.com/pytorch/pytorch/tree/master/android diff --git a/android/gradle/gradle_maven_push.gradle b/android/gradle/gradle_maven_push.gradle index c1660c7fddf90..5fdd8fbc6a037 100644 --- a/android/gradle/gradle_maven_push.gradle +++ b/android/gradle/gradle_maven_push.gradle @@ -25,6 +25,18 @@ def getRepositoryPassword() { return hasProperty('SONATYPE_NEXUS_PASSWORD') ? 
SONATYPE_NEXUS_PASSWORD : "" } +def getHttpProxyHost() { + return project.properties['systemProp.http.proxyHost'] +} + +def getHttpProxyPort() { + return project.properties['systemProp.http.proxyPort'] +} + +def needProxy() { + return (getHttpProxyHost() != null) && (getHttpProxyPort() != null) +} + afterEvaluate { project -> uploadArchives { repositories { @@ -37,9 +49,15 @@ afterEvaluate { project -> repository(url: getReleaseRepositoryUrl()) { authentication(userName: getRepositoryUsername(), password: getRepositoryPassword()) + if (needProxy()) { + proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http') + } } snapshotRepository(url: getSnapshotRepositoryUrl()) { authentication(userName: getRepositoryUsername(), password: getRepositoryPassword()) + if (needProxy()) { + proxy(host: getHttpProxyHost(), port: getHttpProxyPort() as Integer, type: 'http') + } } pom.project { diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchHostTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchHostTests.java new file mode 100644 index 0000000000000..47367180b6bc0 --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchHostTests.java @@ -0,0 +1,30 @@ +package org.pytorch; + +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.util.Objects; + +public class PytorchHostTests extends PytorchTestBase { + @BeforeClass + public static void setUpClass() { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); + } + } + + @Override + protected String assetFilePath(String assetName) throws IOException { + Path tempFile = Files.createTempFile("test", ".pt"); + try (InputStream resource = Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) { + Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING); + } + return tempFile.toAbsolutePath().toString(); + } +} diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java index 7eed05c4b1fb6..43a6c4f366777 100644 --- a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java @@ -2,8 +2,6 @@ import android.content.Context; -import org.junit.Before; -import org.junit.Test; import org.junit.runner.RunWith; import java.io.File; @@ -11,294 +9,15 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.util.HashMap; -import java.util.Map; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.platform.app.InstrumentationRegistry; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - @RunWith(AndroidJUnit4.class) -public class PytorchInstrumentedTests { - - private static final String TEST_MODULE_ASSET_NAME = "test.pt"; - - @Before - public void setUp() { - System.loadLibrary("pytorch"); - } - - @Test - public void testForwardNull() throws IOException { - final Module module = 
Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - final IValue input = - IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1))); - assertTrue(input.isTensor()); - final IValue output = module.forward(input); - assertTrue(output.isNull()); - } - - @Test - public void testEqBool() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - for (boolean value : new boolean[] {false, true}) { - final IValue input = IValue.bool(value); - assertTrue(input.isBool()); - assertTrue(value == input.getBool()); - final IValue output = module.runMethod("eqBool", input); - assertTrue(output.isBool()); - assertTrue(value == output.getBool()); - } - } - - @Test - public void testEqInt() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) { - final IValue input = IValue.long64(value); - assertTrue(input.isLong()); - assertTrue(value == input.getLong()); - final IValue output = module.runMethod("eqInt", input); - assertTrue(output.isLong()); - assertTrue(value == output.getLong()); - } - } - - @Test - public void testEqFloat() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - double[] values = - new double[] { - -Double.MAX_VALUE, - Double.MAX_VALUE, - -Double.MIN_VALUE, - Double.MIN_VALUE, - -Math.exp(1.d), - -Math.sqrt(2.d), - -3.1415f, - 3.1415f, - -1, - 0, - 1, - }; - for (double value : values) { - final IValue input = IValue.double64(value); - assertTrue(input.isDouble()); - assertTrue(value == input.getDouble()); - final IValue output = module.runMethod("eqFloat", input); - assertTrue(output.isDouble()); - assertTrue(value == output.getDouble()); - } - } - - @Test - public void testEqTensor() throws IOException { - final long[] inputTensorShape = new long[] {1, 3, 224, 224}; - final long numElements = Tensor.numel(inputTensorShape); - final float[] inputTensorData = new float[(int) numElements]; - for (int i = 0; i < numElements; ++i) { - inputTensorData[i] = i; - } - final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData); - - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - final IValue input = IValue.tensor(inputTensor); - assertTrue(input.isTensor()); - assertTrue(inputTensor == input.getTensor()); - final IValue output = module.runMethod("eqTensor", input); - assertTrue(output.isTensor()); - final Tensor outputTensor = output.getTensor(); - assertNotNull(outputTensor); - assertArrayEquals(inputTensorShape, outputTensor.shape); - float[] outputData = outputTensor.getDataAsFloatArray(); - for (int i = 0; i < numElements; i++) { - assertTrue(inputTensorData[i] == outputData[i]); - } - } - - @Test - public void testEqDictIntKeyIntValue() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - final Map inputMap = new HashMap<>(); - - inputMap.put(Long.MIN_VALUE, IValue.long64(-Long.MIN_VALUE)); - inputMap.put(Long.MAX_VALUE, IValue.long64(-Long.MAX_VALUE)); - inputMap.put(0l, IValue.long64(0l)); - inputMap.put(1l, IValue.long64(-1l)); - inputMap.put(-1l, IValue.long64(1l)); - - final IValue input = IValue.dictLongKey(inputMap); - assertTrue(input.isDictLongKey()); - - final IValue output = module.runMethod("eqDictIntKeyIntValue", input); - assertTrue(output.isDictLongKey()); - - final Map outputMap = output.getDictLongKey(); - assertTrue(inputMap.size() 
== outputMap.size()); - for (Map.Entry entry : inputMap.entrySet()) { - assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong()); - } - } - - @Test - public void testEqDictStrKeyIntValue() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - final Map inputMap = new HashMap<>(); - - inputMap.put("long_min_value", IValue.long64(Long.MIN_VALUE)); - inputMap.put("long_max_value", IValue.long64(Long.MAX_VALUE)); - inputMap.put("long_0", IValue.long64(0l)); - inputMap.put("long_1", IValue.long64(1l)); - inputMap.put("long_-1", IValue.long64(-1l)); - - final IValue input = IValue.dictStringKey(inputMap); - assertTrue(input.isDictStringKey()); - - final IValue output = module.runMethod("eqDictStrKeyIntValue", input); - assertTrue(output.isDictStringKey()); - - final Map outputMap = output.getDictStringKey(); - assertTrue(inputMap.size() == outputMap.size()); - for (Map.Entry entry : inputMap.entrySet()) { - assertTrue(outputMap.get(entry.getKey()).getLong() == entry.getValue().getLong()); - } - } - - @Test - public void testListIntSumReturnTuple() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - - for (int n : new int[] {0, 1, 128}) { - long[] a = new long[n]; - long sum = 0; - for (int i = 0; i < n; i++) { - a[i] = i; - sum += a[i]; - } - final IValue input = IValue.longList(a); - assertTrue(input.isLongList()); - - final IValue output = module.runMethod("listIntSumReturnTuple", input); - - assertTrue(output.isTuple()); - assertTrue(2 == output.getTuple().length); - - IValue output0 = output.getTuple()[0]; - IValue output1 = output.getTuple()[1]; - - assertArrayEquals(a, output0.getLongList()); - assertTrue(sum == output1.getLong()); - } - } - - @Test - public void testOptionalIntIsNone() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - - assertFalse(module.runMethod("optionalIntIsNone", IValue.long64(1l)).getBool()); - assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).getBool()); - } - - @Test - public void testIntEq0None() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - - assertTrue(module.runMethod("intEq0None", IValue.long64(0l)).isNull()); - assertTrue(module.runMethod("intEq0None", IValue.long64(1l)).getLong() == 1l); - } - - @Test(expected = IllegalArgumentException.class) - public void testRunUndefinedMethod() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - module.runMethod("test_undefined_method_throws_exception"); - } - - @Test - public void testTensorMethods() { - long[] shape = new long[] {1, 3, 224, 224}; - final int numel = (int) Tensor.numel(shape); - int[] ints = new int[numel]; - float[] floats = new float[numel]; - - byte[] bytes = new byte[numel]; - for (int i = 0; i < numel; i++) { - bytes[i] = (byte) ((i % 255) - 128); - ints[i] = i; - floats[i] = i / 1000.f; - } - - Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes); - assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8); - assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); - - Tensor tensorInts = Tensor.newInt32Tensor(shape, ints); - assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32); - assertArrayEquals(ints, tensorInts.getDataAsIntArray()); - - Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats); - assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); - float[] floatsOut = 
tensorFloats.getDataAsFloatArray(); - assertTrue(floatsOut.length == numel); - for (int i = 0; i < numel; i++) { - assertTrue(floats[i] == floatsOut[i]); - } - } - - @Test(expected = IllegalStateException.class) - public void testTensorIllegalStateOnWrongType() { - long[] shape = new long[] {1, 3, 224, 224}; - final int numel = (int) Tensor.numel(shape); - float[] floats = new float[numel]; - Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats); - assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); - tensorFloats.getDataAsByteArray(); - } - - - @Test - public void testEqString() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - String[] values = - new String[] { - "smoketest", - "проверка не латинских символов", // not latin symbols check - "#@$!@#)($*!@#$)(!@*#$" - }; - for (String value : values) { - final IValue input = IValue.string(value); - assertTrue(input.isString()); - assertTrue(value.equals(input.getString())); - final IValue output = module.runMethod("eqStr", input); - assertTrue(output.isString()); - assertTrue(value.equals(output.getString())); - } - } - - @Test - public void testStr3Concat() throws IOException { - final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); - String[] values = - new String[] { - "smoketest", - "проверка не латинских символов", // not latin symbols check - "#@$!@#)($*!@#$)(!@*#$" - }; - for (String value : values) { - final IValue input = IValue.string(value); - assertTrue(input.isString()); - assertTrue(value.equals(input.getString())); - final IValue output = module.runMethod("str3Concat", input); - assertTrue(output.isString()); - String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString(); - assertTrue(expectedOutput.equals(output.getString())); - } - } +public class PytorchInstrumentedTests extends PytorchTestBase { - private static String assetFilePath(String assetName) throws IOException { + @Override + protected String assetFilePath(String assetName) throws IOException { final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext(); File file = new File(appContext.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java new file mode 100644 index 0000000000000..a1cca0adddda2 --- /dev/null +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchTestBase.java @@ -0,0 +1,290 @@ +package org.pytorch; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public abstract class PytorchTestBase { + private static final String TEST_MODULE_ASSET_NAME = "test.pt"; + + @Before + public void setUp() { + System.loadLibrary("pytorch"); + } + + @Test + public void testForwardNull() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + final IValue input = + IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1})); + assertTrue(input.isTensor()); + final IValue output = module.forward(input); + assertTrue(output.isNull()); + } + + @Test + public void testEqBool() throws IOException { + final 
Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + for (boolean value : new boolean[] {false, true}) { + final IValue input = IValue.from(value); + assertTrue(input.isBool()); + assertTrue(value == input.toBool()); + final IValue output = module.runMethod("eqBool", input); + assertTrue(output.isBool()); + assertTrue(value == output.toBool()); + } + } + + @Test + public void testEqInt() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) { + final IValue input = IValue.from(value); + assertTrue(input.isLong()); + assertTrue(value == input.toLong()); + final IValue output = module.runMethod("eqInt", input); + assertTrue(output.isLong()); + assertTrue(value == output.toLong()); + } + } + + @Test + public void testEqFloat() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + double[] values = + new double[] { + -Double.MAX_VALUE, + Double.MAX_VALUE, + -Double.MIN_VALUE, + Double.MIN_VALUE, + -Math.exp(1.d), + -Math.sqrt(2.d), + -3.1415f, + 3.1415f, + -1, + 0, + 1, + }; + for (double value : values) { + final IValue input = IValue.from(value); + assertTrue(input.isDouble()); + assertTrue(value == input.toDouble()); + final IValue output = module.runMethod("eqFloat", input); + assertTrue(output.isDouble()); + assertTrue(value == output.toDouble()); + } + } + + @Test + public void testEqTensor() throws IOException { + final long[] inputTensorShape = new long[] {1, 3, 224, 224}; + final long numElements = Tensor.numel(inputTensorShape); + final float[] inputTensorData = new float[(int) numElements]; + for (int i = 0; i < numElements; ++i) { + inputTensorData[i] = i; + } + final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape); + + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + final IValue input = IValue.from(inputTensor); + assertTrue(input.isTensor()); + assertTrue(inputTensor == input.toTensor()); + final IValue output = module.runMethod("eqTensor", input); + assertTrue(output.isTensor()); + final Tensor outputTensor = output.toTensor(); + assertNotNull(outputTensor); + assertArrayEquals(inputTensorShape, outputTensor.shape()); + float[] outputData = outputTensor.getDataAsFloatArray(); + for (int i = 0; i < numElements; i++) { + assertTrue(inputTensorData[i] == outputData[i]); + } + } + + @Test + public void testEqDictIntKeyIntValue() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + final Map inputMap = new HashMap<>(); + + inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE)); + inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE)); + inputMap.put(0l, IValue.from(0l)); + inputMap.put(1l, IValue.from(-1l)); + inputMap.put(-1l, IValue.from(1l)); + + final IValue input = IValue.dictLongKeyFrom(inputMap); + assertTrue(input.isDictLongKey()); + + final IValue output = module.runMethod("eqDictIntKeyIntValue", input); + assertTrue(output.isDictLongKey()); + + final Map outputMap = output.toDictLongKey(); + assertTrue(inputMap.size() == outputMap.size()); + for (Map.Entry entry : inputMap.entrySet()) { + assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong()); + } + } + + @Test + public void testEqDictStrKeyIntValue() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + final Map inputMap = new HashMap<>(); + + 
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE)); + inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE)); + inputMap.put("long_0", IValue.from(0l)); + inputMap.put("long_1", IValue.from(1l)); + inputMap.put("long_-1", IValue.from(-1l)); + + final IValue input = IValue.dictStringKeyFrom(inputMap); + assertTrue(input.isDictStringKey()); + + final IValue output = module.runMethod("eqDictStrKeyIntValue", input); + assertTrue(output.isDictStringKey()); + + final Map outputMap = output.toDictStringKey(); + assertTrue(inputMap.size() == outputMap.size()); + for (Map.Entry entry : inputMap.entrySet()) { + assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong()); + } + } + + @Test + public void testListIntSumReturnTuple() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + + for (int n : new int[] {0, 1, 128}) { + long[] a = new long[n]; + long sum = 0; + for (int i = 0; i < n; i++) { + a[i] = i; + sum += a[i]; + } + final IValue input = IValue.listFrom(a); + assertTrue(input.isLongList()); + + final IValue output = module.runMethod("listIntSumReturnTuple", input); + + assertTrue(output.isTuple()); + assertTrue(2 == output.toTuple().length); + + IValue output0 = output.toTuple()[0]; + IValue output1 = output.toTuple()[1]; + + assertArrayEquals(a, output0.toLongList()); + assertTrue(sum == output1.toLong()); + } + } + + @Test + public void testOptionalIntIsNone() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + + assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool()); + assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool()); + } + + @Test + public void testIntEq0None() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + + assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull()); + assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l); + } + + @Test(expected = IllegalArgumentException.class) + public void testRunUndefinedMethod() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + module.runMethod("test_undefined_method_throws_exception"); + } + + @Test + public void testTensorMethods() { + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); + int[] ints = new int[numel]; + float[] floats = new float[numel]; + + byte[] bytes = new byte[numel]; + for (int i = 0; i < numel; i++) { + bytes[i] = (byte) ((i % 255) - 128); + ints[i] = i; + floats[i] = i / 1000.f; + } + + Tensor tensorBytes = Tensor.fromBlob(bytes, shape); + assertTrue(tensorBytes.dtype() == DType.INT8); + assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); + + Tensor tensorInts = Tensor.fromBlob(ints, shape); + assertTrue(tensorInts.dtype() == DType.INT32); + assertArrayEquals(ints, tensorInts.getDataAsIntArray()); + + Tensor tensorFloats = Tensor.fromBlob(floats, shape); + assertTrue(tensorFloats.dtype() == DType.FLOAT32); + float[] floatsOut = tensorFloats.getDataAsFloatArray(); + assertTrue(floatsOut.length == numel); + for (int i = 0; i < numel; i++) { + assertTrue(floats[i] == floatsOut[i]); + } + } + + @Test(expected = IllegalStateException.class) + public void testTensorIllegalStateOnWrongType() { + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); + float[] floats = new float[numel]; + Tensor tensorFloats = Tensor.fromBlob(floats, 
shape); + assertTrue(tensorFloats.dtype() == DType.FLOAT32); + tensorFloats.getDataAsByteArray(); + } + + + @Test + public void testEqString() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + String[] values = + new String[] { + "smoketest", + "проверка не латинских символов", // not latin symbols check + "#@$!@#)($*!@#$)(!@*#$" + }; + for (String value : values) { + final IValue input = IValue.from(value); + assertTrue(input.isString()); + assertTrue(value.equals(input.toStr())); + final IValue output = module.runMethod("eqStr", input); + assertTrue(output.isString()); + assertTrue(value.equals(output.toStr())); + } + } + + @Test + public void testStr3Concat() throws IOException { + final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); + String[] values = + new String[] { + "smoketest", + "проверка не латинских символов", // not latin symbols check + "#@$!@#)($*!@#$)(!@*#$" + }; + for (String value : values) { + final IValue input = IValue.from(value); + assertTrue(input.isString()); + assertTrue(value.equals(input.toStr())); + final IValue output = module.runMethod("str3Concat", input); + assertTrue(output.isString()); + String expectedOutput = new StringBuilder().append(value).append(value).append(value).toString(); + assertTrue(expectedOutput.equals(output.toStr())); + } + } + + protected abstract String assetFilePath(String assetName) throws IOException; +} diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp index bb00e47087422..7f09b51a5151b 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp @@ -10,6 +10,8 @@ namespace pytorch_jni { +// NOTE: Codes must be kept in sync with DType.java. +// NOTE: Never serialize these, because they can change between releases. 
constexpr static int kTensorDTypeUInt8 = 1; constexpr static int kTensorDTypeInt8 = 2; constexpr static int kTensorDTypeInt32 = 3; @@ -164,7 +166,7 @@ class JTensor : public facebook::jni::JavaClass { static at::Tensor newAtTensorFromJTensor( facebook::jni::alias_ref jtensor) { static const auto dtypeMethod = - JTensor::javaClassStatic()->getMethod("dtype"); + JTensor::javaClassStatic()->getMethod("dtypeJniCode"); jint jdtype = dtypeMethod(jtensor); static const auto shapeField = @@ -216,7 +218,7 @@ class JIValue : public facebook::jni::JavaClass { static auto jMethodTensor = JIValue::javaClassStatic() ->getStaticMethod( - facebook::jni::local_ref)>("tensor"); + facebook::jni::local_ref)>("from"); return jMethodTensor( JIValue::javaClassStatic(), JTensor::newJTensorFromAtTensor(ivalue.toTensor())); @@ -224,26 +226,26 @@ class JIValue : public facebook::jni::JavaClass { static auto jMethodBool = JIValue::javaClassStatic() ->getStaticMethod(jboolean)>( - "bool"); + "from"); return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool()); } else if (ivalue.isInt()) { static auto jMethodInt = JIValue::javaClassStatic() ->getStaticMethod(jlong)>( - "long64"); + "from"); return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt()); } else if (ivalue.isDouble()) { static auto jMethodDouble = JIValue::javaClassStatic() ->getStaticMethod(jdouble)>( - "double64"); + "from"); return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble()); } else if (ivalue.isString()) { static auto jMethodString = JIValue::javaClassStatic() ->getStaticMethod( facebook::jni::alias_ref< - facebook::jni::JString::javaobject>)>("string"); + facebook::jni::JString::javaobject>)>("from"); return jMethodString( JIValue::javaClassStatic(), facebook::jni::make_jstring(ivalue.toStringRef())); @@ -253,7 +255,7 @@ class JIValue : public facebook::jni::JavaClass { JIValue::javaClassStatic() ->getStaticMethod( facebook::jni::alias_ref::javaobject>)>("tuple"); + JIValue::javaobject>::javaobject>)>("tupleFrom"); auto jElementsArray = facebook::jni::JArrayClass::newArray( elementsVec.size()); @@ -267,7 +269,7 @@ class JIValue : public facebook::jni::JavaClass { static auto jMethodBoolListArr = JIValue::javaClassStatic() ->getStaticMethod( - facebook::jni::alias_ref)>("boolList"); + facebook::jni::alias_ref)>("listFrom"); size_t n = list.size(); auto jArray = facebook::jni::make_boolean_array(n); auto jArrayPinned = jArray->pin(); @@ -281,7 +283,7 @@ class JIValue : public facebook::jni::JavaClass { static auto jMethodLongListArr = JIValue::javaClassStatic() ->getStaticMethod( - facebook::jni::alias_ref)>("longList"); + facebook::jni::alias_ref)>("listFrom"); size_t n = list.size(); auto jArray = facebook::jni::make_long_array(n); auto jArrayPinned = jArray->pin(); @@ -295,7 +297,7 @@ class JIValue : public facebook::jni::JavaClass { static auto jMethoDoubleListArr = JIValue::javaClassStatic() ->getStaticMethod( - facebook::jni::alias_ref)>("doubleList"); + facebook::jni::alias_ref)>("listFrom"); size_t n = list.size(); auto jArray = facebook::jni::make_double_array(n); auto jArrayPinned = jArray->pin(); @@ -310,7 +312,7 @@ class JIValue : public facebook::jni::JavaClass { JIValue::javaClassStatic() ->getStaticMethod( facebook::jni::alias_ref::javaobject>)>("tensorList"); + JTensor::javaobject>::javaobject>)>("listFrom"); auto jArray = facebook::jni::JArrayClass::newArray( list.size()); auto index = 0; @@ -324,7 +326,7 @@ class JIValue : public facebook::jni::JavaClass { JIValue::javaClassStatic() ->getStaticMethod( 
facebook::jni::alias_ref::javaobject>)>("list"); + JIValue::javaobject>::javaobject>)>("listFrom"); auto jArray = facebook::jni::JArrayClass::newArray( list.size()); auto index = 0; @@ -351,7 +353,7 @@ class JIValue : public facebook::jni::JavaClass { facebook::jni::alias_ref< facebook::jni::JString::javaobject>, facebook::jni::alias_ref>>)>( - "dictStringKey"); + "dictStringKeyFrom"); auto jmap = JHashMap< facebook::jni::alias_ref, @@ -370,7 +372,7 @@ class JIValue : public facebook::jni::JavaClass { facebook::jni::alias_ref< facebook::jni::JLong::javaobject>, facebook::jni::alias_ref>>)>( - "dictLongKey"); + "dictLongKeyFrom"); auto jmap = JHashMap< facebook::jni::alias_ref, facebook::jni::alias_ref>::create(); @@ -404,32 +406,32 @@ class JIValue : public facebook::jni::JavaClass { static const auto jMethodGetTensor = JIValue::javaClassStatic() ->getMethod()>( - "getTensor"); + "toTensor"); return JTensor::newAtTensorFromJTensor(jMethodGetTensor(jivalue)); } else if (JIValue::kTypeCodeBool == typeCode) { static const auto jMethodGetBool = - JIValue::javaClassStatic()->getMethod("getBool"); + JIValue::javaClassStatic()->getMethod("toBool"); // explicit cast to bool as jboolean is defined as uint8_t, IValue ctor // for int will be called for jboolean bool b = jMethodGetBool(jivalue); return at::IValue{b}; } else if (JIValue::kTypeCodeLong == typeCode) { static const auto jMethodGetLong = - JIValue::javaClassStatic()->getMethod("getLong"); + JIValue::javaClassStatic()->getMethod("toLong"); return at::IValue{jMethodGetLong(jivalue)}; } else if (JIValue::kTypeCodeDouble == typeCode) { static const auto jMethodGetDouble = - JIValue::javaClassStatic()->getMethod("getDouble"); + JIValue::javaClassStatic()->getMethod("toDouble"); return at::IValue{jMethodGetDouble(jivalue)}; } else if (JIValue::kTypeCodeString == typeCode) { static const auto jMethodGetString = - JIValue::javaClassStatic()->getMethod("getString"); + JIValue::javaClassStatic()->getMethod("toStr"); return at::IValue{jMethodGetString(jivalue)->toStdString()}; } else if (JIValue::kTypeCodeTuple == typeCode) { static const auto jMethodGetTuple = JIValue::javaClassStatic() ->getMethod::javaobject()>("getTuple"); + JIValue::javaobject>::javaobject()>("toTuple"); auto jarray = jMethodGetTuple(jivalue); size_t n = jarray->size(); @@ -443,7 +445,7 @@ class JIValue : public facebook::jni::JavaClass { return c10::ivalue::Tuple::create(std::move(elements)); } else if (JIValue::kTypeCodeBoolList == typeCode) { static const auto jMethodGetBoolList = - JIValue::javaClassStatic()->getMethod("getBoolList"); + JIValue::javaClassStatic()->getMethod("toBoolList"); auto jArray = jMethodGetBoolList(jivalue); auto jArrayPinned = jArray->pin(); size_t n = jArrayPinned.size(); @@ -455,7 +457,7 @@ class JIValue : public facebook::jni::JavaClass { return at::IValue{std::move(list)}; } else if (JIValue::kTypeCodeLongList == typeCode) { static const auto jMethodGetLongList = - JIValue::javaClassStatic()->getMethod("getLongList"); + JIValue::javaClassStatic()->getMethod("toLongList"); auto jArray = jMethodGetLongList(jivalue); auto jArrayPinned = jArray->pin(); size_t n = jArrayPinned.size(); @@ -468,7 +470,7 @@ class JIValue : public facebook::jni::JavaClass { } else if (JIValue::kTypeCodeDoubleList == typeCode) { static const auto jMethodGetDoubleList = JIValue::javaClassStatic()->getMethod( - "getDoubleList"); + "toDoubleList"); auto jArray = jMethodGetDoubleList(jivalue); auto jArrayPinned = jArray->pin(); size_t n = jArrayPinned.size(); @@ -482,7 +484,7 
@@ class JIValue : public facebook::jni::JavaClass { static const auto jMethodGetTensorList = JIValue::javaClassStatic() ->getMethod::javaobject()>("getTensorList"); + JTensor::javaobject>::javaobject()>("toTensorList"); auto jArray = jMethodGetTensorList(jivalue); size_t n = jArray->size(); c10::List list{}; @@ -495,7 +497,7 @@ class JIValue : public facebook::jni::JavaClass { static const auto jMethodGetList = JIValue::javaClassStatic() ->getMethod::javaobject()>("getList"); + JIValue::javaobject>::javaobject()>("toList"); auto jarray = jMethodGetList(jivalue); size_t n = jarray->size(); if (n == 0) { @@ -518,7 +520,7 @@ class JIValue : public facebook::jni::JavaClass { static const auto jMethodGetDictStringKey = JIValue::javaClassStatic() ->getMethod:: - javaobject()>("getDictStringKey"); + javaobject()>("toDictStringKey"); auto jmap = jMethodGetDictStringKey(jivalue); auto it = jmap->begin(); if (it == jmap->end()) { @@ -541,7 +543,7 @@ class JIValue : public facebook::jni::JavaClass { JIValue::javaClassStatic() ->getMethod::javaobject()>("getDictLongKey"); + JIValue::javaobject>::javaobject()>("toDictLongKey"); auto jmap = jMethodGetDictLongKey(jivalue); auto it = jmap->begin(); if (it == jmap->end()) { diff --git a/android/pytorch_android/src/main/java/org/pytorch/DType.java b/android/pytorch_android/src/main/java/org/pytorch/DType.java new file mode 100644 index 0000000000000..2278b16e0bb96 --- /dev/null +++ b/android/pytorch_android/src/main/java/org/pytorch/DType.java @@ -0,0 +1,29 @@ +package org.pytorch; + +/** + * Codes representing tensor data types. + */ +public enum DType { + // NOTE: "jniCode" must be kept in sync with pytorch_jni.cpp. + // NOTE: Never serialize "jniCode", because it can change between releases. + + /** Code for dtype torch.uint8. {@link Tensor#dtype()} */ + UINT8(1), + /** Code for dtype torch.int8. {@link Tensor#dtype()} */ + INT8(2), + /** Code for dtype torch.int32. {@link Tensor#dtype()} */ + INT32(3), + /** Code for dtype torch.float32. {@link Tensor#dtype()} */ + FLOAT32(4), + /** Code for dtype torch.int64. {@link Tensor#dtype()} */ + INT64(5), + /** Code for dtype torch.float64. {@link Tensor#dtype()} */ + FLOAT64(6), + ; + + final int jniCode; + + DType(int jniCode) { + this.jniCode = jniCode; + } +} diff --git a/android/pytorch_android/src/main/java/org/pytorch/IValue.java b/android/pytorch_android/src/main/java/org/pytorch/IValue.java index 5868721be543a..99007cad3469f 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/IValue.java +++ b/android/pytorch_android/src/main/java/org/pytorch/IValue.java @@ -4,10 +4,21 @@ import java.util.Map; /** - * Java representation of a torchscript variable, which is implemented as tagged union that can be - * one of the supported types: https://pytorch.org/docs/stable/jit.html#types. + * Java representation of a TorchScript value, which is implemented as tagged union that can be + * one of the supported types: https://pytorch.org/docs/stable/jit.html#types . *
- * Calling getters for inappropriate types will throw IllegalStateException. + * Calling {@code toX} methods for inappropriate types will throw {@link IllegalStateException}. + *
+ * {@code IValue} objects are constructed with {@code IValue.from(value)}, + * {@code IValue.tupleFrom(value1, value2, ...)}, {@code IValue.listFrom(value1, value2, ...)}, + * or one of the {@code dict} methods, depending on the key type. + *
+ * Data is retrieved from {@code IValue} objects with the {@code toX()} methods. Note that + * {@code str}-type IValues must be extracted with {@link #toStr()}, + * rather than {@link #toString()}. + *
+ * {@code IValue} objects may retain references to objects passed into their constructors, + * and may return references to their internal state from {@code toX()}. */ public class IValue { private static final int TYPE_CODE_NULL = 1; @@ -91,95 +102,98 @@ public boolean isDictLongKey() { return TYPE_CODE_DICT_LONG_KEY == this.mTypeCode; } + /** + * Creates a new {@code IValue} of type {@code Optional} that contains no value. + */ public static IValue optionalNull() { return new IValue(TYPE_CODE_NULL); } /** - * Creates a new IValue instance of torchscript Tensor type. + * Creates a new {@code IValue} of type {@code Tensor}. */ - public static IValue tensor(Tensor tensor) { + public static IValue from(Tensor tensor) { final IValue iv = new IValue(TYPE_CODE_TENSOR); iv.mData = tensor; return iv; } /** - * Creates a new IValue instance of torchscript bool type. + * Creates a new {@code IValue} of type {@code bool}. */ - public static IValue bool(boolean value) { + public static IValue from(boolean value) { final IValue iv = new IValue(TYPE_CODE_BOOL); iv.mData = value; return iv; } /** - * Creates a new IValue instance of torchscript int type. + * Creates a new {@code IValue} of type {@code int}. */ - public static IValue long64(long value) { + public static IValue from(long value) { final IValue iv = new IValue(TYPE_CODE_LONG); iv.mData = value; return iv; } /** - * Creates a new IValue instance of torchscript float type. + * Creates a new {@code IValue} of type {@code float}. */ - public static IValue double64(double value) { + public static IValue from(double value) { final IValue iv = new IValue(TYPE_CODE_DOUBLE); iv.mData = value; return iv; } /** - * Creates new IValue instance of torchscript str type. + * Creates a new {@code IValue} of type {@code str}. */ - public static IValue string(String value) { + public static IValue from(String value) { final IValue iv = new IValue(TYPE_CODE_STRING); iv.mData = value; return iv; } /** - * Creates a new IValue instance of torchscript List[bool] type. + * Creates a new {@code IValue} of type {@code List[bool]}. */ - public static IValue boolList(boolean... list) { + public static IValue listFrom(boolean... list) { final IValue iv = new IValue(TYPE_CODE_BOOL_LIST); iv.mData = list; return iv; } /** - * Creates a new IValue instance of torchscript List[int] type. + * Creates a new {@code IValue} of type {@code List[int]}. */ - public static IValue longList(long... list) { + public static IValue listFrom(long... list) { final IValue iv = new IValue(TYPE_CODE_LONG_LIST); iv.mData = list; return iv; } /** - * Creates a new IValue instance of torchscript List[float] type. + * Creates a new {@code IValue} of type {@code List[float]}. */ - public static IValue doubleList(double... list) { + public static IValue listFrom(double... list) { final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST); iv.mData = list; return iv; } /** - * Creates a new IValue instance of torchscript List[Tensor] type. + * Creates a new {@code IValue} of type {@code List[Tensor]}. */ - public static IValue tensorList(Tensor... list) { + public static IValue listFrom(Tensor... list) { final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST); iv.mData = list; return iv; } /** - * Creates a new IValue instance of torchscript List[T] type. All elements must have the same type. + * Creates a new {@code IValue} of type {@code List[T]}. All elements must have the same type. */ - public static IValue list(IValue... array) { + public static IValue listFrom(IValue... 
array) { final int size = array.length; if (size > 0) { final int typeCode0 = array[0].mTypeCode; @@ -196,93 +210,93 @@ public static IValue list(IValue... array) { } /** - * Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type. + * Creates a new {@code IValue} of type {@code Tuple[T0, T1, ...]}. */ - public static IValue tuple(IValue... array) { + public static IValue tupleFrom(IValue... array) { final IValue iv = new IValue(TYPE_CODE_TUPLE); iv.mData = array; return iv; } /** - * Creates a new IValue instance oftorchscript Dict[Str, V] type. + * Creates a new {@code IValue} of type {@code Dict[str, V]}. */ - public static IValue dictStringKey(Map map) { + public static IValue dictStringKeyFrom(Map map) { final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY); iv.mData = map; return iv; } /** - * Creates a new IValue instance of torchscript Dict[int, V] type. + * Creates a new {@code IValue} of type {@code Dict[int, V]}. */ - public static IValue dictLongKey(Map map) { + public static IValue dictLongKeyFrom(Map map) { final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY); iv.mData = map; return iv; } - public Tensor getTensor() { + public Tensor toTensor() { preconditionType(TYPE_CODE_TENSOR, mTypeCode); return (Tensor) mData; } - public boolean getBool() { + public boolean toBool() { preconditionType(TYPE_CODE_BOOL, mTypeCode); return (boolean) mData; } - public long getLong() { + public long toLong() { preconditionType(TYPE_CODE_LONG, mTypeCode); return (long) mData; } - public double getDouble() { + public double toDouble() { preconditionType(TYPE_CODE_DOUBLE, mTypeCode); return (double) mData; } - public String getString() { + public String toStr() { preconditionType(TYPE_CODE_STRING, mTypeCode); return (String) mData; } - public boolean[] getBoolList() { + public boolean[] toBoolList() { preconditionType(TYPE_CODE_BOOL_LIST, mTypeCode); return (boolean[]) mData; } - public long[] getLongList() { + public long[] toLongList() { preconditionType(TYPE_CODE_LONG_LIST, mTypeCode); return (long[]) mData; } - public double[] getDoubleList() { + public double[] toDoubleList() { preconditionType(TYPE_CODE_DOUBLE_LIST, mTypeCode); return (double[]) mData; } - public Tensor[] getTensorList() { + public Tensor[] toTensorList() { preconditionType(TYPE_CODE_TENSOR_LIST, mTypeCode); return (Tensor[]) mData; } - public IValue[] getList() { + public IValue[] toList() { preconditionType(TYPE_CODE_LIST, mTypeCode); return (IValue[]) mData; } - public IValue[] getTuple() { + public IValue[] toTuple() { preconditionType(TYPE_CODE_TUPLE, mTypeCode); return (IValue[]) mData; } - public Map getDictStringKey() { + public Map toDictStringKey() { preconditionType(TYPE_CODE_DICT_STRING_KEY, mTypeCode); return (Map) mData; } - public Map getDictLongKey() { + public Map toDictLongKey() { preconditionType(TYPE_CODE_DICT_LONG_KEY, mTypeCode); return (Map) mData; } diff --git a/android/pytorch_android/src/main/java/org/pytorch/Module.java b/android/pytorch_android/src/main/java/org/pytorch/Module.java index 04147d3c26960..4ca47a4491af7 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Module.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Module.java @@ -5,21 +5,20 @@ import com.facebook.jni.HybridData; /** - * Java holder for torch::jit::script::Module which owns it on jni side. + * Java wrapper for torch::jit::script::Module. */ public class Module { private NativePeer mNativePeer; /** - * Loads serialized torchscript module from the specified absolute path on the disk. 
+ * Loads a serialized TorchScript module from the specified path on the disk. * - * @param modelAbsolutePath absolute path to file that contains the serialized torchscript module. - * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module on jni - * side. + * @param modelPath path to file that contains the serialized TorchScript module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module. */ - public static Module load(final String modelAbsolutePath) { - return new Module(modelAbsolutePath); + public static Module load(final String modelPath) { + return new Module(modelPath); } private Module(final String moduleAbsolutePath) { @@ -27,35 +26,32 @@ private Module(final String moduleAbsolutePath) { } /** - * Runs 'forward' method of loaded torchscript module with specified arguments. + * Runs the 'forward' method of this module with the specified arguments. * - * @param inputs arguments for torchscript module 'forward' method. - * @return result of torchscript module 'forward' method evaluation + * @param inputs arguments for the TorchScript module's 'forward' method. + * @return return value from the 'forward' method. */ public IValue forward(IValue... inputs) { return mNativePeer.forward(inputs); } /** - * Runs specified method of loaded torchscript module with specified arguments. + * Runs the specified method of this module with the specified arguments. * - * @param methodName torchscript module method to run - * @param inputs arguments that will be specified to torchscript module method call - * @return result of torchscript module specified method evaluation + * @param methodName name of the TorchScript method to run. + * @param inputs arguments that will be passed to TorchScript method. + * @return return value from the method. */ public IValue runMethod(String methodName, IValue... inputs) { return mNativePeer.runMethod(methodName, inputs); } /** - * Explicitly destructs native part. Current instance can not be used after this call. This - * method may be called multiple times safely. As fbjni library destructs native part - * automatically when current instance will be - * collected by Java GC, the instance will not leak if this method is not called, - * but timing of deletion and the thread will be at the whim of the Java GC. - * If you want to control the thread and timing of the destructor, you should call this method - * explicitly. - * {@link com.facebook.jni.HybridData#resetNative} + * Explicitly destroys the native torch::jit::script::Module. + * Calling this method is not required, as the native object will be destroyed + * when this object is garbage-collected. However, the timing of garbage collection + * is not guaranteed, so proactively calling {@code destroy} can free memory more quickly. + * See {@link com.facebook.jni.HybridData#resetNative}. */ public void destroy() { mNativePeer.mHybridData.resetNative(); diff --git a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java index 4178a9adfb098..4fdcf448a6855 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java @@ -11,24 +11,24 @@ import java.util.Locale; /** - * Representation of Tensor. Tensor shape is stored in {@link Tensor#shape}, elements are stored as - * {@link java.nio.DirectByteBuffer} of one of the supported types. + * Representation of a Tensor. 
Behavior is similar to PyTorch's tensor objects. + *

+ * Most tensors will be constructed as {@code Tensor.fromBlob(data, shape)}, + * where {@code data} can be an array or a direct {@link Buffer} (of the proper subclass). + * Helper methods are provided to allocate buffers properly. + *

+ * To access Tensor data, see {@link #dtype()}, {@link #shape()}, + * and various {@code getDataAs*} methods. + *

+ * When constructing {@code Tensor} objects with {@code data} as an array, + * it is not specified whether this data is copied or retained as a reference, + * so it is recommended not to modify it after construction. {@code data} passed as a + * {@link Buffer} is not copied, so it can be modified between {@link Module} calls + * to avoid reallocation. Data retrieved from {@code Tensor} objects may be copied or + * may be a reference to the {@code Tensor}'s internal data buffer. + * {@code shape} is always copied. */ public abstract class Tensor { - - /** Code for dtype torch.uint8. {@link Tensor#dtype()} */ - public static final int DTYPE_UINT8 = 1; - /** Code for dtype torch.int8. {@link Tensor#dtype()} */ - public static final int DTYPE_INT8 = 2; - /** Code for dtype torch.int32. {@link Tensor#dtype()} */ - public static final int DTYPE_INT32 = 3; - /** Code for dtype torch.float32. {@link Tensor#dtype()} */ - public static final int DTYPE_FLOAT32 = 4; - /** Code for dtype torch.int64. {@link Tensor#dtype()} */ - public static final int DTYPE_INT64 = 5; - /** Code for dtype torch.float64. {@link Tensor#dtype()} */ - public static final int DTYPE_FLOAT64 = 6; - private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null"; @@ -39,8 +39,7 @@ public abstract class Tensor { private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; - /** Shape of current tensor. */ - public final long[] shape; + final long[] shape; private static final int INT_SIZE_BYTES = 4; private static final int FLOAT_SIZE_BYTES = 4; @@ -49,8 +48,8 @@ public abstract class Tensor { /** * Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified - * capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link - * Tensor#newUInt8Tensor(long[], ByteBuffer)}. + * capacity that can be used in {@link Tensor#fromBlob(ByteBuffer, long[])}, + * {@link Tensor#fromBlobUnsigned(ByteBuffer, long[])}. * * @param numElements capacity (number of elements) of result buffer. */ @@ -58,6 +57,12 @@ public static ByteBuffer allocateByteBuffer(int numElements) { return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); } + /** + * Allocates a new direct {@link java.nio.IntBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#fromBlob(IntBuffer, long[])}. + * + * @param numElements capacity (number of elements) of result buffer. + */ public static IntBuffer allocateIntBuffer(int numElements) { return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) .order(ByteOrder.nativeOrder()) @@ -66,7 +71,7 @@ public static IntBuffer allocateIntBuffer(int numElements) { /** * Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified - * capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}. + * capacity that can be used in {@link Tensor#fromBlob(FloatBuffer, long[])}. * * @param numElements capacity (number of elements) of result buffer. */ @@ -78,7 +83,7 @@ public static FloatBuffer allocateFloatBuffer(int numElements) { /** * Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified - * capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}. 
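The class comment above notes that data passed as a direct Buffer is used without copying and may be refilled between Module calls; before the diff continues with the remaining allocate helpers, here is an illustrative sketch of that reuse pattern built on allocateFloatBuffer and fromBlob(FloatBuffer, long[]). The model path, input shape, and fill values are placeholders, not part of the diff.

import java.nio.FloatBuffer;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class BufferReuseExample {
  public static void main(String[] args) {
    // Placeholder path and shape; substitute your own model and input size.
    Module module = Module.load("/data/local/tmp/model.pt");
    long[] shape = new long[] {1, 3, 224, 224};
    int numElements = 1 * 3 * 224 * 224;

    // Allocate one direct, native-order buffer up front with the Tensor helper.
    FloatBuffer input = Tensor.allocateFloatBuffer(numElements);
    Tensor inputTensor = Tensor.fromBlob(input, shape);

    for (int step = 0; step < 10; step++) {
      // Refill the same buffer in place; no per-step allocation is needed
      // because the tensor reads from the buffer directly.
      for (int i = 0; i < numElements; i++) {
        input.put(i, step * 0.01f);
      }
      Tensor output = module.forward(IValue.from(inputTensor)).toTensor();
      float[] scores = output.getDataAsFloatArray();
    }
  }
}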
+ * capacity that can be used in {@link Tensor#fromBlob(LongBuffer, long[])}. * * @param numElements capacity (number of elements) of result buffer. */ @@ -90,7 +95,7 @@ public static LongBuffer allocateLongBuffer(int numElements) { /** * Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified - * capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}. + * capacity that can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}. * * @param numElements capacity (number of elements) of result buffer. */ @@ -104,10 +109,10 @@ public static DoubleBuffer allocateDoubleBuffer(int numElements) { * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of * bytes. * - * @param shape Tensor shape * @param data Tensor elements + * @param shape Tensor shape */ - public static Tensor newUInt8Tensor(long[] shape, byte[] data) { + public static Tensor fromBlobUnsigned(byte[] data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -121,10 +126,10 @@ public static Tensor newUInt8Tensor(long[] shape, byte[] data) { * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of * bytes. * - * @param shape Tensor shape * @param data Tensor elements + * @param shape Tensor shape */ - public static Tensor newInt8Tensor(long[] shape, byte[] data) { + public static Tensor fromBlob(byte[] data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -138,10 +143,10 @@ public static Tensor newInt8Tensor(long[] shape, byte[] data) { * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of * ints. * - * @param shape Tensor shape * @param data Tensor elements + * @param shape Tensor shape */ - public static Tensor newInt32Tensor(long[] shape, int[] data) { + public static Tensor fromBlob(int[] data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -155,10 +160,10 @@ public static Tensor newInt32Tensor(long[] shape, int[] data) { * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array * of floats. * - * @param shape Tensor shape * @param data Tensor elements + * @param shape Tensor shape */ - public static Tensor newFloat32Tensor(long[] shape, float[] data) { + public static Tensor fromBlob(float[] data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -172,10 +177,10 @@ public static Tensor newFloat32Tensor(long[] shape, float[] data) { * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of * longs. 
* - * @param shape Tensor shape * @param data Tensor elements + * @param shape Tensor shape */ - public static Tensor newInt64Tensor(long[] shape, long[] data) { + public static Tensor fromBlob(long[] data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -192,7 +197,7 @@ public static Tensor newInt64Tensor(long[] shape, long[] data) { * @param shape Tensor shape * @param data Tensor elements */ - public static Tensor newFloat64Tensor(long[] shape, double[] data) { + public static Tensor fromBlob(long[] shape, double[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -205,12 +210,12 @@ public static Tensor newFloat64Tensor(long[] shape, double[] data) { /** * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. * - * @param shape Tensor shape - * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} * elements. The buffer is used directly without copying, and changes to its content will * change the tensor. + * @param shape Tensor shape */ - public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) { + public static Tensor fromBlobUnsigned(ByteBuffer data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -225,12 +230,12 @@ public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) { /** * Creates a new Tensor instance with dtype torch.int8 with specified shape and data. * - * @param shape Tensor shape - * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} * elements. The buffer is used directly without copying, and changes to its content will * change the tensor. + * @param shape Tensor shape */ - public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) { + public static Tensor fromBlob(ByteBuffer data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -245,12 +250,12 @@ public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) { /** * Creates a new Tensor instance with dtype torch.int32 with specified shape and data. * - * @param shape Tensor shape - * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} * elements. The buffer is used directly without copying, and changes to its content will * change the tensor. + * @param shape Tensor shape */ - public static Tensor newInt32Tensor(long[] shape, IntBuffer data) { + public static Tensor fromBlob(IntBuffer data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -265,12 +270,12 @@ public static Tensor newInt32Tensor(long[] shape, IntBuffer data) { /** * Creates a new Tensor instance with dtype torch.float32 with specified shape and data. 
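A short illustrative sketch of the array-based factories renamed above; the 8x8 pixel values are fabricated, and fromBlobUnsigned is used because Java's byte is signed while the resulting tensor dtype is torch.uint8.

import org.pytorch.Tensor;

public class UnsignedTensorExample {
  public static void main(String[] args) {
    // Fabricated 8x8 single-channel "image" with values in 0..255.
    byte[] pixels = new byte[8 * 8];
    for (int i = 0; i < pixels.length; i++) {
      pixels[i] = (byte) ((i * 4) & 0xFF); // stored as signed bytes, interpreted as uint8
    }

    // dtype torch.uint8, shape [1, 1, 8, 8]; the overloads differ only in dtype.
    Tensor uint8Tensor = Tensor.fromBlobUnsigned(pixels, new long[] {1, 1, 8, 8});

    // The int32 overload works the same way with an int[] payload.
    int[] ids = new int[] {7, 11, 13};
    Tensor int32Tensor = Tensor.fromBlob(ids, new long[] {3});
  }
}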
* - * @param shape Tensor shape - * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} * elements. The buffer is used directly without copying, and changes to its content will * change the tensor. + * @param shape Tensor shape */ - public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) { + public static Tensor fromBlob(FloatBuffer data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -285,12 +290,12 @@ public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) { /** * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. * - * @param shape Tensor shape - * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} * elements. The buffer is used directly without copying, and changes to its content will * change the tensor. + * @param shape Tensor shape */ - public static Tensor newInt64Tensor(long[] shape, LongBuffer data) { + public static Tensor fromBlob(LongBuffer data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -305,12 +310,12 @@ public static Tensor newInt64Tensor(long[] shape, LongBuffer data) { /** * Creates a new Tensor instance with dtype torch.float64 with specified shape and data. * - * @param shape Tensor shape - * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} * elements. The buffer is used directly without copying, and changes to its content will * change the tensor. + * @param shape Tensor shape */ - public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) { + public static Tensor fromBlob(DoubleBuffer data, long[] shape) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); checkShape(shape); @@ -327,12 +332,14 @@ private Tensor(long[] shape) { this.shape = Arrays.copyOf(shape, shape.length); } - /** Calculates number of elements in current tensor instance. */ + /** Returns the number of elements in this tensor. */ public long numel() { return numel(this.shape); } - /** Calculates number of elements in tensor with specified shape. */ + /** + * Calculates the number of elements in a tensor with the specified shape. + */ public static long numel(long[] shape) { checkShape(shape); int result = 1; @@ -343,14 +350,24 @@ public static long numel(long[] shape) { } /** - * Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link - * Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link - * Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}. + * Returns the shape of this tensor. (The array is a fresh copy.) */ - public abstract int dtype(); + public long[] shape() { + return Arrays.copyOf(shape, shape.length); + } + + /** + * @return data type of this tensor. + */ + public abstract DType dtype(); + + // Called from native + int dtypeJniCode() { + return dtype().jniCode; + } /** - * Returns newly allocated java byte array that contains a copy of tensor data. 
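A brief illustrative sketch of the accessors introduced just above: shape() hands back a defensive copy and dtype() now returns the DType enum rather than a raw int code. The org.pytorch.DType import is an assumption here, since the enum's own file is not shown in this diff.

import org.pytorch.DType; // assumed package for the DType enum referenced by these changes
import org.pytorch.Tensor;

public class TensorInspectionExample {
  public static void main(String[] args) {
    Tensor t = Tensor.fromBlob(new float[] {0.5f, 1.5f, 2.5f, 3.5f}, new long[] {2, 2});

    // shape() returns a fresh copy, so callers cannot corrupt the tensor's own shape.
    long[] shape = t.shape();

    // dtype() now returns the DType enum instead of an integer code.
    if (t.dtype() == DType.FLOAT32) {
      float[] values = t.getDataAsFloatArray();
      System.out.println("numel=" + t.numel() + ", first=" + values[0]);
    }
  }
}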
+ * @return a Java byte array that contains the tensor data. This may be a copy or reference. * * @throws IllegalStateException if it is called for a non-int8 tensor. */ @@ -360,7 +377,7 @@ public byte[] getDataAsByteArray() { } /** - * Returns newly allocated java byte array that contains a copy of tensor data. + * @return a Java byte array that contains the tensor data. This may be a copy or reference. * * @throws IllegalStateException if it is called for a non-uint8 tensor. */ @@ -370,7 +387,7 @@ public byte[] getDataAsUnsignedByteArray() { } /** - * Returns newly allocated java byte array that contains a copy of tensor data. + * @return a Java int array that contains the tensor data. This may be a copy or reference. * * @throws IllegalStateException if it is called for a non-int32 tensor. */ @@ -380,7 +397,7 @@ public int[] getDataAsIntArray() { } /** - * Returns newly allocated java byte array that contains a copy of tensor data. + * @return a Java float array that contains the tensor data. This may be a copy or reference. * * @throws IllegalStateException if it is called for a non-float32 tensor. */ @@ -390,7 +407,7 @@ public float[] getDataAsFloatArray() { } /** - * Returns newly allocated java byte array that contains a copy of tensor data. + * @return a Java long array that contains the tensor data. This may be a copy or reference. * * @throws IllegalStateException if it is called for a non-int64 tensor. */ @@ -400,7 +417,7 @@ public long[] getDataAsLongArray() { } /** - * Returns newly allocated java byte array that contains a copy of tensor data. + * @return a Java double array that contains the tensor data. This may be a copy or reference. * * @throws IllegalStateException if it is called for a non-float64 tensor. */ @@ -423,8 +440,8 @@ private Tensor_uint8(ByteBuffer data, long[] shape) { } @Override - public int dtype() { - return DTYPE_UINT8; + public DType dtype() { + return DType.UINT8; } @Override @@ -455,8 +472,8 @@ private Tensor_int8(ByteBuffer data, long[] shape) { } @Override - public int dtype() { - return DTYPE_INT8; + public DType dtype() { + return DType.INT8; } @Override @@ -487,8 +504,8 @@ private Tensor_int32(IntBuffer data, long[] shape) { } @Override - public int dtype() { - return DTYPE_INT32; + public DType dtype() { + return DType.INT32; } @Override @@ -527,8 +544,8 @@ public float[] getDataAsFloatArray() { } @Override - public int dtype() { - return DTYPE_FLOAT32; + public DType dtype() { + return DType.FLOAT32; } @Override @@ -551,8 +568,8 @@ private Tensor_int64(LongBuffer data, long[] shape) { } @Override - public int dtype() { - return DTYPE_INT64; + public DType dtype() { + return DType.INT64; } @Override @@ -583,8 +600,8 @@ private Tensor_float64(DoubleBuffer data, long[] shape) { } @Override - public int dtype() { - return DTYPE_FLOAT64; + public DType dtype() { + return DType.FLOAT64; } @Override @@ -634,17 +651,17 @@ private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[ // Called from native private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) { - if (DTYPE_FLOAT32 == dtype) { + if (DType.FLOAT32.jniCode == dtype) { return new Tensor_float32(data.asFloatBuffer(), shape); - } else if (DTYPE_INT32 == dtype) { + } else if (DType.INT32.jniCode == dtype) { return new Tensor_int32(data.asIntBuffer(), shape); - } else if (DTYPE_INT64 == dtype) { + } else if (DType.INT64.jniCode == dtype) { return new Tensor_int64(data.asLongBuffer(), shape); - } else if (DTYPE_FLOAT64 == dtype) { + } else if 
(DType.FLOAT64.jniCode == dtype) { return new Tensor_float64(data.asDoubleBuffer(), shape); - } else if (DTYPE_UINT8 == dtype) { + } else if (DType.UINT8.jniCode == dtype) { return new Tensor_uint8(data, shape); - } else if (DTYPE_INT8 == dtype) { + } else if (DType.INT8.jniCode == dtype) { return new Tensor_int8(data, shape); } throw new IllegalArgumentException("Unknown Tensor dtype"); diff --git a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java index 5f512e8193852..4aa740c2c5ee5 100644 --- a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java +++ b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java @@ -7,6 +7,7 @@ import org.pytorch.Tensor; import java.nio.ByteBuffer; +import java.nio.FloatBuffer; import java.util.Locale; /** @@ -26,7 +27,7 @@ public final class TensorImageUtils { * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order */ public static Tensor bitmapToFloat32Tensor( - final Bitmap bitmap, float[] normMeanRGB, float normStdRGB[]) { + final Bitmap bitmap, final float[] normMeanRGB, final float normStdRGB[]) { checkNormMeanArg(normMeanRGB); checkNormStdArg(normStdRGB); @@ -35,8 +36,9 @@ public static Tensor bitmapToFloat32Tensor( } /** - * Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap}, - * normalized with specified in parameters mean and std. + * Writes tensor content from the specified {@link android.graphics.Bitmap}, + * normalized with the mean and std given in the parameters, to the specified + * {@link java.nio.FloatBuffer} at the specified offset. 
* * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data * @param x - x coordinate of top left corner of bitmap's area @@ -46,21 +48,23 @@ public static Tensor bitmapToFloat32Tensor( * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order */ - public static Tensor bitmapToFloat32Tensor( + public static void bitmapToFloatBuffer( final Bitmap bitmap, - int x, - int y, - int width, - int height, - float[] normMeanRGB, - float[] normStdRGB) { + final int x, + final int y, + final int width, + final int height, + final float[] normMeanRGB, + final float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset) { + checkOutBufferCapacity(outBuffer, outBufferOffset, width, height); checkNormMeanArg(normMeanRGB); checkNormStdArg(normStdRGB); final int pixelsCount = height * width; final int[] pixels = new int[pixelsCount]; bitmap.getPixels(pixels, 0, width, x, y, width, height); - final float[] floatArray = new float[3 * pixelsCount]; final int offset_g = pixelsCount; final int offset_b = 2 * pixelsCount; for (int i = 0; i < pixelsCount; i++) { @@ -68,11 +72,41 @@ public static Tensor bitmapToFloat32Tensor( float r = ((c >> 16) & 0xff) / 255.0f; float g = ((c >> 8) & 0xff) / 255.0f; float b = ((c) & 0xff) / 255.0f; - floatArray[i] = (r - normMeanRGB[0]) / normStdRGB[0]; - floatArray[offset_g + i] = (g - normMeanRGB[1]) / normStdRGB[1]; - floatArray[offset_b + i] = (b - normMeanRGB[2]) / normStdRGB[2]; + float rF = (r - normMeanRGB[0]) / normStdRGB[0]; + float gF = (g - normMeanRGB[1]) / normStdRGB[1]; + float bF = (b - normMeanRGB[2]) / normStdRGB[2]; + outBuffer.put(outBufferOffset + i, rF); + outBuffer.put(outBufferOffset + offset_g + i, gF); + outBuffer.put(outBufferOffset + offset_b + i, bF); } - return Tensor.newFloat32Tensor(new long[]{1, 3, height, width}, floatArray); + } + + /** + * Creates new {@link org.pytorch.Tensor} from specified area of {@link android.graphics.Bitmap}, + * normalized with specified in parameters mean and std. 
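The split into bitmapToFloatBuffer plus Tensor.fromBlob lets one preallocated buffer be reused per frame; a hedged sketch follows, assuming a Bitmap obtained elsewhere and illustrative normalization constants that are not part of this diff.

import android.graphics.Bitmap;

import java.nio.FloatBuffer;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

public class BitmapBufferExample {
  // Illustrative normalization constants; use whatever your model expects.
  private static final float[] MEAN_RGB = new float[] {0.485f, 0.456f, 0.406f};
  private static final float[] STD_RGB = new float[] {0.229f, 0.224f, 0.225f};

  static Tensor preprocess(Bitmap bitmap, FloatBuffer reusableBuffer, int width, int height) {
    // reusableBuffer is expected to come from Tensor.allocateFloatBuffer(3 * width * height)
    // and can be written into on every frame without reallocation.
    TensorImageUtils.bitmapToFloatBuffer(
        bitmap, 0, 0, width, height, MEAN_RGB, STD_RGB, reusableBuffer, 0);
    // Wrap the buffer without copying; shape is NCHW, matching the usage in this diff.
    return Tensor.fromBlob(reusableBuffer, new long[] {1, 3, height, width});
  }

  static float[] run(Module module, Bitmap bitmap, FloatBuffer reusableBuffer, int w, int h) {
    Tensor input = preprocess(bitmap, reusableBuffer, w, h);
    return module.forward(IValue.from(input)).toTensor().getDataAsFloatArray();
  }
}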
+ * + * @param bitmap {@link android.graphics.Bitmap} as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order + */ + public static Tensor bitmapToFloat32Tensor( + final Bitmap bitmap, + int x, + int y, + int width, + int height, + float[] normMeanRGB, + float[] normStdRGB) { + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + + final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * width * height); + bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0); + return Tensor.fromBlob(floatBuffer, new long[]{1, 3, height, width}); } /** @@ -105,6 +139,52 @@ public static Tensor imageYUV420CenterCropToFloat32Tensor( checkRotateCWDegrees(rotateCWDegrees); checkTensorSize(tensorWidth, tensorHeight); + final FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(3 * tensorWidth * tensorHeight); + imageYUV420CenterCropToFloatBuffer( + image, + rotateCWDegrees, + tensorWidth, + tensorHeight, + normMeanRGB, normStdRGB, floatBuffer, 0); + return Tensor.fromBlob(floatBuffer, new long[]{1, 3, tensorHeight, tensorWidth}); + } + + /** + * Writes tensor content from specified {@link android.media.Image}, doing optional rotation, + * scaling (nearest) and center cropping to specified {@link java.nio.FloatBuffer} with specified offset. + * + * @param image {@link android.media.Image} as a source for Tensor data + * @param rotateCWDegrees Clockwise angle through which the input image needs to be rotated to be + * upright. 
Range of valid values: 0, 90, 180, 270 + * @param tensorWidth return tensor width, must be positive + * @param tensorHeight return tensor height, must be positive + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB order + * @param outBuffer Output buffer, where tensor content will be written + * @param outBufferOffset Output buffer offset with which tensor content will be written + */ + public static void imageYUV420CenterCropToFloatBuffer( + final Image image, + int rotateCWDegrees, + final int tensorWidth, + final int tensorHeight, + float[] normMeanRGB, + float[] normStdRGB, + final FloatBuffer outBuffer, + final int outBufferOffset) { + checkOutBufferCapacity(outBuffer, outBufferOffset, tensorWidth, tensorHeight); + + if (image.getFormat() != ImageFormat.YUV_420_888) { + throw new IllegalArgumentException( + String.format( + Locale.US, "Image format %d != ImageFormat.YUV_420_888", image.getFormat())); + } + + checkNormMeanArg(normMeanRGB); + checkNormStdArg(normStdRGB); + checkRotateCWDegrees(rotateCWDegrees); + checkTensorSize(tensorWidth, tensorHeight); + final int widthBeforeRotation = image.getWidth(); final int heightBeforeRotation = image.getHeight(); @@ -158,7 +238,6 @@ public static Tensor imageYUV420CenterCropToFloat32Tensor( final int channelSize = tensorHeight * tensorWidth; final int tensorInputOffsetG = channelSize; final int tensorInputOffsetB = 2 * channelSize; - final float[] floatArray = new float[3 * channelSize]; for (int x = 0; x < tensorWidth; x++) { for (int y = 0; y < tensorHeight; y++) { @@ -198,13 +277,22 @@ public static Tensor imageYUV420CenterCropToFloat32Tensor( int r = clamp((a0 + a1) >> 10, 0, 255); int g = clamp((a0 - a2 - a3) >> 10, 0, 255); int b = clamp((a0 + a4) >> 10, 0, 255); - final int offset = y * tensorWidth + x; - floatArray[offset] = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0]; - floatArray[tensorInputOffsetG + offset] = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1]; - floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2]; + final int offset = outBufferOffset + y * tensorWidth + x; + float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0]; + float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1]; + float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2]; + + outBuffer.put(offset, rF); + outBuffer.put(offset + tensorInputOffsetG, gF); + outBuffer.put(offset + tensorInputOffsetB, bF); } } - return Tensor.newFloat32Tensor(new long[]{1, 3, tensorHeight, tensorWidth}, floatArray); + } + + private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) { + if (outBufferOffset + 3 * tensorWidth * tensorHeight > outBuffer.capacity()) { + throw new IllegalStateException("Buffer underflow"); + } } private static void checkTensorSize(int tensorWidth, int tensorHeight) { diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index 77ed7da438129..944491cf21ea7 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -55,11 +55,15 @@ add_subdirectory(src/THNN) IF(USE_ROCM) include(LoadHIP) if (NOT PYTORCH_FOUND_HIP) - MESSAGE(FATAL_ERROR - "Could not find HIP installation") + set(USE_ROCM OFF) endif() ENDIF() +# Both CUDA and ROCM are enabled and found. Report an error. +if(USE_CUDA AND USE_ROCM) + message(FATAL_ERROR "Both CUDA and ROCm are enabled and found. PyTorch can only be built with either of them. 
Please turn one off by using either USE_CUDA=OFF or USE_ROCM=OFF.") +endif() + IF(MSVC) # we want to respect the standard, and we are bored of those **** . ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index dc150b645b737..b6fe232896069 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -82,7 +82,7 @@ FILE(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip") FILE(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp") add_subdirectory(quantized) -set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp}) +set(all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp} ${native_sparse_cpp} ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp} ${generated_cpp} ${core_generated_cpp} ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${cpu_kernel_cpp}) if(AT_MKL_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp}) endif() @@ -114,6 +114,7 @@ endif() filter_list(generated_h generated_cpp "\\.h$") filter_list(cuda_generated_h cuda_generated_cpp "\\.h$") +filter_list(core_generated_h core_generated_cpp "\\.h$") # TODO: When we have hip_generated_cpp #filter_list(hip_generated_h hip_generated_cpp "\\.h$") @@ -459,6 +460,12 @@ FOREACH(HEADER ${generated_h} ${cuda_generated_h}) INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen) ENDFOREACH() +message("AT_INSTALL_INCLUDE_DIR ${AT_INSTALL_INCLUDE_DIR}/ATen/core") +FOREACH(HEADER ${core_generated_h}) + message("core header install: ${HEADER}") + INSTALL(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen/core) +ENDFOREACH() + INSTALL(FILES ${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml DESTINATION ${AT_INSTALL_SHARE_DIR}/ATen) diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 5e9cf532ba723..4da3e87a68f85 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -762,20 +762,6 @@ output: True - THTensor* self ]] -[[ - name: _th_log10 - cname: log10 - types: - - floating_point - backends: - - CUDA - variants: function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] [[ name: _th_log1p cname: log1p diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index f7e51c83fd829..9dff0ee7f1d39 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -179,8 +179,9 @@ void set_num_threads(int nthreads) { TORCH_CHECK(nthreads > 0, "Expected positive number of threads"); int no_value = NOT_SET; TORCH_CHECK(num_intraop_threads.compare_exchange_strong(no_value, nthreads), - "Error: cannot set number of interop threads " - "after parallel work has started or after set_num_threads call"); + "Error: cannot set number of intraop threads " + "after parallel work has started or after set_num_threads call " + "when using native parallel backend"); #else TORCH_CHECK(false, "set_num_threads is not supported for mobile."); #endif // C10_MOBILE diff --git a/aten/src/ATen/ParallelNativeTBB.cpp b/aten/src/ATen/ParallelNativeTBB.cpp index aa002120240f6..2fd6c0b363e30 100644 --- a/aten/src/ATen/ParallelNativeTBB.cpp +++ b/aten/src/ATen/ParallelNativeTBB.cpp @@ -21,11 +21,11 @@ namespace at { namespace { static thread_local tbb::task_scheduler_init tbb_init_(intraop_default_num_threads()); -std::atomic 
num_intraop_threads_{-1}; static thread_local tbb::task_group tg_; std::mutex global_thread_mutex_; std::shared_ptr global_thread_limit_ = nullptr; +std::atomic num_intraop_threads_{-1}; void _internal_set_num_threads(int nthreads) { TORCH_INTERNAL_ASSERT(nthreads > 0); @@ -33,6 +33,7 @@ void _internal_set_num_threads(int nthreads) { std::unique_lock lk(global_thread_mutex_); global_thread_limit_ = std::make_shared( tbb::global_control::max_allowed_parallelism, nthreads); + num_intraop_threads_.store(nthreads); } if (tbb_init_.is_active()) { tbb_init_.terminate(); @@ -59,14 +60,8 @@ void init_num_threads() { void set_num_threads(int nthreads) { TORCH_CHECK(nthreads > 0); - int no_value = -1; - if (num_intraop_threads_.compare_exchange_strong(no_value, nthreads)) { - _internal_set_num_threads(nthreads); - return; - } - TORCH_CHECK(false, - "Error: cannot set number of interop threads " - "after parallel work has started or after set_num_threads call"); + + _internal_set_num_threads(nthreads); } int get_num_threads() { diff --git a/aten/src/ATen/ParallelOpenMP.h b/aten/src/ATen/ParallelOpenMP.h index e1e4b9bf44332..fecb9858d37d8 100644 --- a/aten/src/ATen/ParallelOpenMP.h +++ b/aten/src/ATen/ParallelOpenMP.h @@ -25,7 +25,13 @@ inline void parallel_for( #ifdef _OPENMP std::atomic_flag err_flag = ATOMIC_FLAG_INIT; std::exception_ptr eptr; -#pragma omp parallel if (!omp_in_parallel() && ((end - begin) >= grain_size)) + // choose number of tasks based on grain size and number of threads + int64_t num_threads = omp_in_parallel() ? 1 : omp_get_max_threads(); + if (grain_size > 0) { + num_threads = std::min(num_threads, divup((end - begin), grain_size)); + } + +#pragma omp parallel num_threads(num_threads) { int64_t num_threads = omp_get_num_threads(); int64_t tid = omp_get_thread_num(); diff --git a/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp b/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp deleted file mode 100644 index d74609e0e988b..0000000000000 --- a/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp +++ /dev/null @@ -1,1492 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include - -// @generated by aten/src/ATen/gen.py - -// TODO Once all ATen ops are moved to c10, this file should be removed - -namespace at { - -namespace { -struct OpNameEquals final { - bool operator()(const std::pair& lhs, const std::pair& rhs) const { - return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second); - } -}; - -struct OpNameHash final { - size_t operator()(const std::pair& p) const { - // use std::hash because std::hash would hash pointers and not pointed-to strings - return std::hash()(p.first) ^ (~ std::hash()(p.second)); - } -}; -} - -bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) { - static std::unordered_set, OpNameHash, OpNameEquals> ops { - {"aten::_cast_Byte", ""}, - {"aten::_cast_Char", ""}, - {"aten::_cast_Double", ""}, - {"aten::_cast_Float", ""}, - {"aten::_cast_Int", ""}, - {"aten::_cast_Long", ""}, - {"aten::_cast_Short", ""}, - {"aten::_cast_Half", ""}, - {"aten::data", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::align_as", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::align_tensors", ""}, - #endif - {"aten::_cudnn_ctc_loss", ""}, - {"aten::_cudnn_rnn_flatten_weight", ""}, - {"aten::_debug_has_internal_overlap", ""}, - {"aten::_fused_dropout", ""}, - {"aten::_masked_scale", ""}, - {"aten::_sobol_engine_ff_", ""}, - {"aten::_sobol_engine_scramble_", ""}, - {"aten::_sobol_engine_initialize_state_", ""}, - {"aten::_reshape_from_tensor", 
""}, - {"aten::_shape_as_tensor", ""}, - {"aten::dropout", ""}, - {"aten::dropout_", ""}, - {"aten::feature_dropout", ""}, - {"aten::feature_dropout_", ""}, - {"aten::alpha_dropout", ""}, - {"aten::alpha_dropout_", ""}, - {"aten::feature_alpha_dropout", ""}, - {"aten::feature_alpha_dropout_", ""}, - {"aten::abs", ""}, - {"aten::abs_", ""}, - {"aten::acos", ""}, - {"aten::acos_", ""}, - {"aten::avg_pool1d", ""}, - {"aten::adaptive_avg_pool1d", ""}, - {"aten::adaptive_max_pool1d", ""}, - {"aten::add", "Tensor"}, - {"aten::add_", "Tensor"}, - {"aten::add", "Scalar"}, - {"aten::add_", "Scalar"}, - {"aten::addmv", ""}, - {"aten::addmv_", ""}, - {"aten::addr", ""}, - {"aten::addr_", ""}, - {"aten::affine_grid_generator", ""}, - {"aten::affine_grid_generator_backward", ""}, - {"aten::all", "dim"}, - {"aten::allclose", ""}, - {"aten::any", "dim"}, - {"aten::_dim_arange", ""}, - {"aten::argmax", ""}, - {"aten::argmin", ""}, - {"aten::as_strided", ""}, - {"aten::as_strided_", ""}, - {"aten::asin", ""}, - {"aten::asin_", ""}, - {"aten::atan", ""}, - {"aten::atan_", ""}, - {"aten::baddbmm", ""}, - {"aten::baddbmm_", ""}, - {"aten::_baddbmm_mkl_", ""}, - {"aten::bernoulli", ""}, - {"aten::bernoulli_", "Tensor"}, - {"aten::bernoulli_", "float"}, - {"aten::bernoulli", "p"}, - {"aten::bitwise_not", ""}, - {"aten::bitwise_not_", ""}, - {"aten::logical_not", ""}, - {"aten::logical_not_", ""}, - {"aten::logical_xor", ""}, - {"aten::logical_xor_", ""}, - {"aten::bmm", ""}, - {"aten::broadcast_tensors", ""}, - {"aten::cat", ""}, - {"aten::ceil", ""}, - {"aten::ceil_", ""}, - {"aten::chain_matmul", ""}, - {"aten::chunk", ""}, - {"aten::clamp", ""}, - {"aten::clamp_", ""}, - {"aten::clamp_max", ""}, - {"aten::clamp_max_", ""}, - {"aten::clamp_min", ""}, - {"aten::clamp_min_", ""}, - {"aten::cudnn_is_acceptable", ""}, - {"aten::constant_pad_nd", ""}, - {"aten::conv_tbc", ""}, - {"aten::conv_tbc_backward", ""}, - {"aten::copy_", ""}, - {"aten::_copy_from", ""}, - {"aten::cos", ""}, - {"aten::cos_", ""}, - {"aten::cosh", ""}, - {"aten::cosh_", ""}, - {"aten::cosine_embedding_loss", ""}, - {"aten::cudnn_affine_grid_generator", ""}, - {"aten::cudnn_affine_grid_generator_backward", ""}, - {"aten::cudnn_convolution_backward_input", ""}, - {"aten::cudnn_convolution_backward", ""}, - {"aten::cudnn_convolution_backward_bias", ""}, - {"aten::cudnn_convolution_backward_weight", ""}, - {"aten::cudnn_convolution_transpose_backward", ""}, - {"aten::cudnn_convolution_transpose_backward_bias", ""}, - {"aten::cudnn_convolution_transpose_backward_input", ""}, - {"aten::cudnn_convolution_transpose_backward_weight", ""}, - {"aten::cudnn_grid_sampler", ""}, - {"aten::cudnn_grid_sampler_backward", ""}, - {"aten::ctc_loss", "IntList"}, - {"aten::ctc_loss", "Tensor"}, - {"aten::_ctc_loss", ""}, - {"aten::_ctc_loss_backward", ""}, - {"aten::det", ""}, - {"aten::diag_embed", ""}, - {"aten::diagflat", ""}, - {"aten::diagonal", ""}, - {"aten::fill_diagonal_", ""}, - {"aten::div", "Tensor"}, - {"aten::div_", "Tensor"}, - {"aten::div", "Scalar"}, - {"aten::div_", "Scalar"}, - {"aten::dot", ""}, - {"aten::einsum", ""}, - {"aten::embedding", ""}, - {"aten::embedding_backward", ""}, - {"aten::embedding_dense_backward", ""}, - {"aten::embedding_renorm_", ""}, - {"aten::embedding_sparse_backward", ""}, - {"aten::_embedding_bag_per_sample_weights_backward", ""}, - {"aten::resize_", ""}, - {"aten::empty_like", ""}, - {"aten::erf", ""}, - {"aten::erf_", ""}, - {"aten::erfc", ""}, - {"aten::erfc_", ""}, - {"aten::exp", ""}, - {"aten::exp_", ""}, - 
{"aten::expm1", ""}, - {"aten::expm1_", ""}, - {"aten::expand", ""}, - {"aten::expand_as", ""}, - {"aten::flatten", "using_ints"}, - {"aten::fill_", "Scalar"}, - {"aten::fill_", "Tensor"}, - {"aten::floor", ""}, - {"aten::floor_", ""}, - {"aten::frac", ""}, - {"aten::frac_", ""}, - {"aten::full_like", ""}, - {"aten::grid_sampler", ""}, - {"aten::grid_sampler_2d", ""}, - {"aten::grid_sampler_2d_backward", ""}, - {"aten::grid_sampler_3d", ""}, - {"aten::grid_sampler_3d_backward", ""}, - {"aten::hinge_embedding_loss", ""}, - {"aten::ger", ""}, - {"aten::fft", ""}, - {"aten::ifft", ""}, - {"aten::rfft", ""}, - {"aten::irfft", ""}, - {"aten::_fft_with_size", ""}, - {"aten::_cufft_get_plan_cache_size", ""}, - {"aten::_cufft_get_plan_cache_max_size", ""}, - {"aten::index_copy_", ""}, - {"aten::index_copy", ""}, - {"aten::inverse", ""}, - {"aten::_inverse_helper", ""}, - {"aten::isclose", ""}, - {"aten::isnan", ""}, - {"aten::is_distributed", ""}, - {"aten::is_floating_point", ""}, - {"aten::is_complex", ""}, - {"aten::is_nonzero", ""}, - {"aten::is_same_size", ""}, - {"aten::is_signed", ""}, - {"aten::kl_div", ""}, - {"aten::kl_div_backward", ""}, - {"aten::kthvalue", ""}, - {"aten::fbgemm_linear_int8_weight_fp32_activation", ""}, - {"aten::fbgemm_linear_int8_weight", ""}, - {"aten::fbgemm_linear_quantize_weight", ""}, - {"aten::fbgemm_pack_gemm_matrix_fp16", ""}, - {"aten::fbgemm_linear_fp16_weight_fp32_activation", ""}, - {"aten::fbgemm_linear_fp16_weight", ""}, - {"aten::fbgemm_pack_quantized_matrix", ""}, - {"aten::fbgemm_pack_quantized_matrix", "KN"}, - {"aten::log", ""}, - {"aten::log_", ""}, - {"aten::log10", ""}, - {"aten::log10_", ""}, - {"aten::log1p", ""}, - {"aten::log1p_", ""}, - {"aten::log2", ""}, - {"aten::log2_", ""}, - {"aten::logdet", ""}, - {"aten::_log_softmax", ""}, - {"aten::_log_softmax_backward_data", ""}, - {"aten::logsumexp", ""}, - {"aten::margin_ranking_loss", ""}, - {"aten::matmul", ""}, - {"aten::matrix_rank", "tol"}, - {"aten::matrix_rank", ""}, - {"aten::matrix_power", ""}, - {"aten::max", "dim"}, - {"aten::max_values", ""}, - {"aten::max_pool1d_with_indices", ""}, - {"aten::max_pool1d", ""}, - {"aten::max_pool2d", ""}, - {"aten::mkldnn_max_pool2d", ""}, - {"aten::quantized_max_pool2d", ""}, - {"aten::max_pool3d", ""}, - {"aten::median", "dim"}, - {"aten::min", "dim"}, - {"aten::min_values", ""}, - {"aten::mkldnn_convolution_backward_input", ""}, - {"aten::mkldnn_convolution_backward_weights", ""}, - {"aten::mkldnn_convolution_backward", ""}, - {"aten::miopen_convolution_backward_input", ""}, - {"aten::miopen_convolution_backward", ""}, - {"aten::miopen_convolution_backward_bias", ""}, - {"aten::miopen_convolution_backward_weight", ""}, - {"aten::miopen_convolution_transpose_backward", ""}, - {"aten::miopen_convolution_transpose_backward_input", ""}, - {"aten::miopen_convolution_transpose_backward_weight", ""}, - {"aten::miopen_depthwise_convolution_backward_input", ""}, - {"aten::miopen_depthwise_convolution_backward", ""}, - {"aten::miopen_depthwise_convolution_backward_weight", ""}, - {"aten::mm", ""}, - {"aten::_sparse_mm", ""}, - {"aten::mode", ""}, - {"aten::mul", "Tensor"}, - {"aten::mul_", "Tensor"}, - {"aten::mul", "Scalar"}, - {"aten::mul_", "Scalar"}, - {"aten::mv", ""}, - {"aten::mvlgamma", ""}, - {"aten::mvlgamma_", ""}, - {"aten::narrow_copy", ""}, - {"aten::narrow", ""}, - {"aten::batch_norm_stats", ""}, - {"aten::_nnpack_available", ""}, - {"aten::_nnpack_spatial_convolution_backward", ""}, - {"aten::_nnpack_spatial_convolution_backward_input", 
""}, - {"aten::_nnpack_spatial_convolution_backward_weight", ""}, - {"aten::ones_like", ""}, - {"aten::pairwise_distance", ""}, - {"aten::cdist", ""}, - {"aten::_cdist_backward", ""}, - {"aten::pdist", ""}, - {"aten::_pdist_forward", ""}, - {"aten::_pdist_backward", ""}, - {"aten::cosine_similarity", ""}, - {"aten::permute", ""}, - {"aten::numpy_T", ""}, - {"aten::pixel_shuffle", ""}, - {"aten::is_pinned", ""}, - {"aten::pin_memory", ""}, - {"aten::pinverse", ""}, - {"aten::poisson_nll_loss", ""}, - {"aten::rand_like", ""}, - {"aten::randint_like", ""}, - {"aten::randint_like", "low"}, - {"aten::randn_like", ""}, - {"aten::reciprocal", ""}, - {"aten::reciprocal_", ""}, - {"aten::neg", ""}, - {"aten::neg_", ""}, - {"aten::repeat", ""}, - {"aten::repeat_interleave", "Tensor"}, - {"aten::repeat_interleave", "self_Tensor"}, - {"aten::repeat_interleave", "self_int"}, - {"aten::reshape", ""}, - {"aten::_mkldnn_reshape", ""}, - {"aten::reshape_as", ""}, - {"aten::round", ""}, - {"aten::round_", ""}, - {"aten::rrelu", ""}, - {"aten::rrelu_", ""}, - {"aten::relu", ""}, - {"aten::relu_", ""}, - {"aten::prelu", ""}, - {"aten::prelu_backward", ""}, - {"aten::gelu", ""}, - {"aten::gelu_backward", ""}, - {"aten::hardshrink", ""}, - {"aten::hardshrink_backward", ""}, - {"aten::rsqrt", ""}, - {"aten::rsqrt_", ""}, - {"aten::select", "int"}, - {"aten::selu", ""}, - {"aten::selu_", ""}, - {"aten::celu", ""}, - {"aten::celu_", ""}, - {"aten::sigmoid", ""}, - {"aten::sigmoid_", ""}, - {"aten::sin", ""}, - {"aten::sin_", ""}, - {"aten::sinh", ""}, - {"aten::sinh_", ""}, - {"aten::detach", ""}, - {"aten::detach_", ""}, - {"aten::size", "int"}, - {"aten::slice", "Tensor"}, - {"aten::slogdet", ""}, - {"aten::smm", ""}, - {"aten::_softmax", ""}, - {"aten::_softmax_backward_data", ""}, - {"aten::split", "Tensor"}, - {"aten::split_with_sizes", ""}, - {"aten::squeeze", ""}, - {"aten::squeeze", "dim"}, - {"aten::squeeze_", ""}, - {"aten::squeeze_", "dim"}, - {"aten::sspaddmm", ""}, - {"aten::stack", ""}, - {"aten::stride", "int"}, - {"aten::sum_to_size", ""}, - {"aten::sqrt", ""}, - {"aten::sqrt_", ""}, - {"aten::std", ""}, - {"aten::std", "dim"}, - {"aten::std_mean", ""}, - {"aten::std_mean", "dim"}, - {"aten::t", ""}, - {"aten::t_", ""}, - {"aten::tan", ""}, - {"aten::tan_", ""}, - {"aten::tanh", ""}, - {"aten::tanh_", ""}, - {"aten::tensordot", ""}, - {"aten::threshold", ""}, - {"aten::threshold_", ""}, - {"aten::threshold_backward", ""}, - {"aten::transpose", "int"}, - {"aten::_mkldnn_transpose", ""}, - {"aten::transpose_", ""}, - {"aten::_mkldnn_transpose_", ""}, - {"aten::one_hot", ""}, - {"aten::flip", ""}, - {"aten::roll", ""}, - {"aten::rot90", ""}, - {"aten::trapz", "x"}, - {"aten::trapz", "dx"}, - {"aten::_trilinear", ""}, - {"aten::triplet_margin_loss", ""}, - {"aten::trunc", ""}, - {"aten::trunc_", ""}, - {"aten::type_as", ""}, - {"aten::_has_compatible_shallow_copy_type", ""}, - {"aten::_unique", ""}, - {"aten::unique_dim", ""}, - {"aten::unique_consecutive", ""}, - {"aten::unique_dim_consecutive", ""}, - {"aten::_unique2", ""}, - {"aten::_unsafe_view", ""}, - {"aten::unsqueeze", ""}, - {"aten::unsqueeze_", ""}, - {"aten::var", ""}, - {"aten::var", "dim"}, - {"aten::var_mean", ""}, - {"aten::var_mean", "dim"}, - {"aten::view_as", ""}, - {"aten::where", "self"}, - {"aten::where", ""}, - {"aten::_s_where", ""}, - {"aten::norm_except_dim", ""}, - {"aten::_weight_norm", ""}, - {"aten::_weight_norm_cuda_interface", ""}, - {"aten::_weight_norm_cuda_interface_backward", ""}, - 
{"aten::_weight_norm_differentiable_backward", ""}, - {"aten::zeros_like", ""}, - {"aten::_standard_gamma_grad", ""}, - {"aten::_standard_gamma", ""}, - {"aten::_dirichlet_grad", ""}, - {"aten::_sample_dirichlet", ""}, - {"aten::poisson", ""}, - {"aten::native_norm", ""}, - {"aten::_sparse_sum", ""}, - {"aten::_sparse_sum", "dim"}, - {"aten::_sparse_sum_backward", ""}, - {"aten::norm", "Scalar"}, - {"aten::norm", "ScalarOpt_dim"}, - {"aten::frobenius_norm", ""}, - {"aten::frobenius_norm", "dim"}, - {"aten::nuclear_norm", ""}, - {"aten::nuclear_norm", "dim"}, - {"aten::clone", ""}, - {"aten::resize_as_", ""}, - {"aten::pow", "Tensor_Scalar"}, - {"aten::zero_", ""}, - {"aten::sub", "Tensor"}, - {"aten::sub_", "Tensor"}, - {"aten::sub", "Scalar"}, - {"aten::sub_", "Scalar"}, - {"aten::rsub", "Tensor"}, - {"aten::rsub", "Scalar"}, - {"aten::_sparse_addmm", ""}, - {"aten::addmm", ""}, - {"aten::addmm_", ""}, - {"aten::sparse_resize_", ""}, - {"aten::sparse_resize_and_clear_", ""}, - {"aten::sparse_mask", ""}, - {"aten::to_dense", ""}, - {"aten::to_dense_backward", ""}, - {"aten::sparse_dim", ""}, - {"aten::_dimI", ""}, - {"aten::dense_dim", ""}, - {"aten::_dimV", ""}, - {"aten::_nnz", ""}, - {"aten::coalesce", ""}, - {"aten::is_coalesced", ""}, - {"aten::_indices", ""}, - {"aten::_values", ""}, - {"aten::_coalesced_", ""}, - {"aten::indices", ""}, - {"aten::values", ""}, - {"aten::hspmm", ""}, - {"aten::copy_sparse_to_sparse_", ""}, - {"aten::numel", ""}, - {"aten::unbind", "int"}, - {"aten::to_sparse", "sparse_dim"}, - {"aten::to_sparse", ""}, - {"aten::to_mkldnn", ""}, - {"aten::mkldnn_reorder_conv2d_weight", ""}, - {"aten::to_mkldnn_backward", ""}, - {"aten::dequantize", ""}, - {"aten::q_scale", ""}, - {"aten::q_zero_point", ""}, - {"aten::q_per_channel_scales", ""}, - {"aten::q_per_channel_zero_points", ""}, - {"aten::int_repr", ""}, - {"aten::_make_per_tensor_quantized_tensor", ""}, - {"aten::_make_per_channel_quantized_tensor", ""}, - {"aten::fake_quantize_per_tensor_affine", ""}, - {"aten::fake_quantize_per_tensor_affine_backward", ""}, - {"aten::to", "other"}, - {"aten::meshgrid", ""}, - {"aten::cartesian_prod", ""}, - {"aten::combinations", ""}, - {"aten::item", ""}, - {"aten::_local_scalar_dense", ""}, - {"aten::_thnn_fused_gru_cell_backward", ""}, - {"aten::lstm", "input"}, - {"aten::lstm", "data"}, - {"aten::gru", "input"}, - {"aten::gru", "data"}, - {"aten::rnn_tanh", "input"}, - {"aten::rnn_tanh", "data"}, - {"aten::rnn_relu", "input"}, - {"aten::rnn_relu", "data"}, - {"aten::quantized_gru", "input"}, - {"aten::quantized_gru", "data"}, - {"aten::quantized_lstm_cell", ""}, - {"aten::quantized_gru_cell", ""}, - {"aten::quantized_rnn_relu_cell", ""}, - {"aten::quantized_rnn_tanh_cell", ""}, - {"aten::_pack_padded_sequence", ""}, - {"aten::_pack_padded_sequence_backward", ""}, - {"aten::_pad_packed_sequence", ""}, - {"aten::set_", "source_Tensor"}, - {"aten::set_", ""}, - {"aten::is_set_to", ""}, - {"aten::masked_fill_", "Scalar"}, - {"aten::masked_fill", "Scalar"}, - {"aten::masked_fill_", "Tensor"}, - {"aten::masked_fill", "Tensor"}, - {"aten::masked_scatter_", ""}, - {"aten::masked_scatter", ""}, - {"aten::view", ""}, - {"aten::put_", ""}, - {"aten::index_add_", ""}, - {"aten::index_add", ""}, - {"aten::index_fill_", "Scalar"}, - {"aten::index_fill", "Scalar"}, - {"aten::index_fill_", "Tensor"}, - {"aten::index_fill", "Tensor"}, - {"aten::scatter_", "src"}, - {"aten::scatter", "src"}, - {"aten::scatter_", "value"}, - {"aten::scatter", "value"}, - {"aten::scatter_add_", ""}, - 
{"aten::scatter_add", ""}, - {"aten::lt_", "Scalar"}, - {"aten::lt_", "Tensor"}, - {"aten::gt_", "Scalar"}, - {"aten::gt_", "Tensor"}, - {"aten::le_", "Scalar"}, - {"aten::le_", "Tensor"}, - {"aten::ge_", "Scalar"}, - {"aten::ge_", "Tensor"}, - {"aten::eq_", "Scalar"}, - {"aten::eq_", "Tensor"}, - {"aten::ne_", "Scalar"}, - {"aten::ne_", "Tensor"}, - {"aten::__and__", "Scalar"}, - {"aten::__and__", "Tensor"}, - {"aten::__iand__", "Scalar"}, - {"aten::__iand__", "Tensor"}, - {"aten::__or__", "Scalar"}, - {"aten::__or__", "Tensor"}, - {"aten::__ior__", "Scalar"}, - {"aten::__ior__", "Tensor"}, - {"aten::__xor__", "Scalar"}, - {"aten::__xor__", "Tensor"}, - {"aten::__ixor__", "Scalar"}, - {"aten::__ixor__", "Tensor"}, - {"aten::__lshift__", "Scalar"}, - {"aten::__lshift__", "Tensor"}, - {"aten::__ilshift__", "Scalar"}, - {"aten::__ilshift__", "Tensor"}, - {"aten::__rshift__", "Scalar"}, - {"aten::__rshift__", "Tensor"}, - {"aten::__irshift__", "Scalar"}, - {"aten::__irshift__", "Tensor"}, - {"aten::lgamma_", ""}, - {"aten::atan2_", ""}, - {"aten::tril_", ""}, - {"aten::triu_", ""}, - {"aten::digamma_", ""}, - {"aten::polygamma_", ""}, - {"aten::renorm_", ""}, - {"aten::pow_", "Scalar"}, - {"aten::pow_", "Tensor"}, - {"aten::lerp_", "Scalar"}, - {"aten::lerp_", "Tensor"}, - {"aten::fmod_", "Scalar"}, - {"aten::fmod_", "Tensor"}, - {"aten::remainder_", "Scalar"}, - {"aten::remainder_", "Tensor"}, - {"aten::addbmm_", ""}, - {"aten::addbmm", ""}, - {"aten::addcdiv_", ""}, - {"aten::random_", "from"}, - {"aten::random_", "to"}, - {"aten::random_", ""}, - {"aten::uniform_", ""}, - {"aten::normal_", ""}, - {"aten::cauchy_", ""}, - {"aten::log_normal_", ""}, - {"aten::exponential_", ""}, - {"aten::geometric_", ""}, - {"aten::diag", ""}, - {"aten::cross", ""}, - {"aten::triu", ""}, - {"aten::tril", ""}, - {"aten::trace", ""}, - {"aten::ne", "Scalar"}, - {"aten::ne", "Tensor"}, - {"aten::eq", "Scalar"}, - {"aten::eq", "Tensor"}, - {"aten::ge", "Scalar"}, - {"aten::ge", "Tensor"}, - {"aten::le", "Scalar"}, - {"aten::le", "Tensor"}, - {"aten::gt", "Scalar"}, - {"aten::gt", "Tensor"}, - {"aten::lt", "Scalar"}, - {"aten::lt", "Tensor"}, - {"aten::take", ""}, - {"aten::index_select", ""}, - {"aten::masked_select", ""}, - {"aten::nonzero", ""}, - {"aten::nonzero_numpy", ""}, - {"aten::gather", ""}, - {"aten::_gather_sparse_backward", ""}, - {"aten::addcmul", ""}, - {"aten::addcmul_", ""}, - {"aten::addcdiv", ""}, - {"aten::lstsq", ""}, - {"aten::triangular_solve", ""}, - {"aten::_triangular_solve_helper", ""}, - {"aten::symeig", ""}, - {"aten::_symeig_helper", ""}, - {"aten::eig", ""}, - {"aten::svd", ""}, - {"aten::_svd_helper", ""}, - {"aten::cholesky", ""}, - {"aten::_cholesky_helper", ""}, - {"aten::cholesky_solve", ""}, - {"aten::_cholesky_solve_helper", ""}, - {"aten::solve", ""}, - {"aten::_solve_helper", ""}, - {"aten::cholesky_inverse", ""}, - {"aten::qr", ""}, - {"aten::_qr_helper", ""}, - {"aten::geqrf", ""}, - {"aten::orgqr", ""}, - {"aten::ormqr", ""}, - {"aten::_lu_with_info", ""}, - {"aten::lu_solve", ""}, - {"aten::_lu_solve_helper", ""}, - {"aten::multinomial", ""}, - {"aten::_multinomial_alias_setup", ""}, - {"aten::_multinomial_alias_draw", ""}, - {"aten::lgamma", ""}, - {"aten::digamma", ""}, - {"aten::polygamma", ""}, - {"aten::erfinv", ""}, - {"aten::erfinv_", ""}, - {"aten::sign", ""}, - {"aten::sign_", ""}, - {"aten::dist", ""}, - {"aten::atan2", ""}, - {"aten::lerp", "Scalar"}, - {"aten::lerp", "Tensor"}, - {"aten::histc", ""}, - {"aten::fmod", "Scalar"}, - {"aten::fmod", "Tensor"}, 
- {"aten::remainder", "Scalar"}, - {"aten::remainder", "Tensor"}, - {"aten::min", "other"}, - {"aten::min", ""}, - {"aten::max", "other"}, - {"aten::max", ""}, - {"aten::median", ""}, - {"aten::sort", ""}, - {"aten::argsort", ""}, - {"aten::topk", ""}, - {"aten::all", ""}, - {"aten::any", ""}, - {"aten::renorm", ""}, - {"aten::unfold", ""}, - {"aten::equal", ""}, - {"aten::pow", "Tensor_Tensor"}, - {"aten::pow", "Scalar"}, - {"aten::normal", "Tensor_float"}, - {"aten::normal", "float_Tensor"}, - {"aten::normal", "Tensor_Tensor"}, - {"aten::alias", ""}, - {"aten::_addr", ""}, - {"aten::_addr_", ""}, - {"aten::_index_copy_", ""}, - {"aten::_cumsum", ""}, - {"aten::_cumprod", ""}, - {"aten::_var", ""}, - {"aten::_std", ""}, - {"aten::_cat", ""}, - {"aten::_mode", ""}, - {"aten::_max", ""}, - {"aten::_min", ""}, - {"aten::mse_loss", ""}, - {"aten::mse_loss_backward", ""}, - {"aten::l1_loss", ""}, - {"aten::l1_loss_backward", ""}, - {"aten::multilabel_margin_loss", ""}, - {"aten::multilabel_margin_loss_forward", ""}, - {"aten::multilabel_margin_loss_backward", ""}, - {"aten::smooth_l1_loss", ""}, - {"aten::smooth_l1_loss_backward", ""}, - {"aten::soft_margin_loss", ""}, - {"aten::soft_margin_loss_backward", ""}, - {"aten::elu", ""}, - {"aten::elu_backward", ""}, - {"aten::elu_", ""}, - {"aten::glu", ""}, - {"aten::glu_backward", ""}, - {"aten::hardtanh", ""}, - {"aten::hardtanh_backward", ""}, - {"aten::hardtanh_", ""}, - {"aten::leaky_relu", ""}, - {"aten::leaky_relu_backward", ""}, - {"aten::leaky_relu_", ""}, - {"aten::log_sigmoid", ""}, - {"aten::log_sigmoid_forward", ""}, - {"aten::log_sigmoid_backward", ""}, - {"aten::rrelu_with_noise", ""}, - {"aten::rrelu_with_noise_backward", ""}, - {"aten::rrelu_with_noise_", ""}, - {"aten::softplus", ""}, - {"aten::softplus_backward", ""}, - {"aten::softshrink", ""}, - {"aten::softshrink_backward", ""}, - {"aten::adaptive_avg_pool2d", ""}, - {"aten::mkldnn_adaptive_avg_pool2d", ""}, - {"aten::_adaptive_avg_pool2d", ""}, - {"aten::_adaptive_avg_pool2d_backward", ""}, - {"aten::adaptive_avg_pool3d", ""}, - {"aten::adaptive_avg_pool3d_backward", ""}, - {"aten::adaptive_max_pool2d", ""}, - {"aten::adaptive_max_pool2d_backward", ""}, - {"aten::adaptive_max_pool3d", ""}, - {"aten::adaptive_max_pool3d_backward", ""}, - {"aten::avg_pool2d", ""}, - {"aten::avg_pool2d_backward", ""}, - {"aten::avg_pool3d", ""}, - {"aten::avg_pool3d_backward", ""}, - {"aten::fractional_max_pool2d", ""}, - {"aten::fractional_max_pool2d_backward", ""}, - {"aten::fractional_max_pool3d", ""}, - {"aten::fractional_max_pool3d_backward", ""}, - {"aten::max_pool2d_with_indices", ""}, - {"aten::max_pool2d_with_indices_backward", ""}, - {"aten::max_pool3d_with_indices", ""}, - {"aten::max_pool3d_with_indices_backward", ""}, - {"aten::max_unpool2d", ""}, - {"aten::max_unpool2d_backward", ""}, - {"aten::max_unpool3d", ""}, - {"aten::max_unpool3d_backward", ""}, - {"aten::reflection_pad1d", ""}, - {"aten::reflection_pad1d_backward", ""}, - {"aten::reflection_pad2d", ""}, - {"aten::reflection_pad2d_backward", ""}, - {"aten::replication_pad1d", ""}, - {"aten::replication_pad1d_backward", ""}, - {"aten::replication_pad2d", ""}, - {"aten::replication_pad2d_backward", ""}, - {"aten::replication_pad3d", ""}, - {"aten::replication_pad3d_backward", ""}, - {"aten::upsample_linear1d", ""}, - {"aten::upsample_linear1d_backward", ""}, - {"aten::upsample_bilinear2d", ""}, - {"aten::upsample_bilinear2d_backward", ""}, - {"aten::upsample_bicubic2d", ""}, - {"aten::upsample_bicubic2d_backward", ""}, - 
{"aten::upsample_trilinear3d", ""}, - {"aten::upsample_trilinear3d_backward", ""}, - {"aten::upsample_nearest1d", ""}, - {"aten::upsample_nearest1d_backward", ""}, - {"aten::upsample_nearest2d", ""}, - {"aten::upsample_nearest2d_backward", ""}, - {"aten::upsample_nearest3d", ""}, - {"aten::upsample_nearest3d_backward", ""}, - {"aten::sigmoid_backward", ""}, - {"aten::tanh_backward", ""}, - {"aten::slow_conv_transpose2d_backward", "output_mask"}, - {"aten::slow_conv_transpose3d_backward", "output_mask"}, - {"aten::thnn_conv2d_backward", "output_mask"}, - {"aten::thnn_conv_depthwise2d_backward", "output_mask"}, - {"aten::thnn_conv3d_backward", "output_mask"}, - {"aten::slow_conv_dilated2d_backward", ""}, - {"aten::slow_conv_dilated3d_backward", ""}, - {"aten::col2im", ""}, - {"aten::col2im_backward", ""}, - {"aten::im2col", ""}, - {"aten::im2col_backward", ""}, - {"", ""} - }; - return ops.count(std::make_pair(opName.name.c_str(), opName.overload_name.c_str())) != 0; -} - -bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) { - static std::unordered_set, OpNameHash, OpNameEquals> ops { - {"aten::backward", ""}, - {"aten::set_data", ""}, - {"aten::is_leaf", ""}, - {"aten::output_nr", ""}, - {"aten::_version", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::rename_", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::rename", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::align_to", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::refine_names", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::unflatten", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::unflatten", ""}, - #endif - {"aten::_cudnn_rnn", ""}, - {"aten::_cudnn_rnn_backward", ""}, - {"aten::_cudnn_init_dropout_state", ""}, - {"aten::_sobol_engine_draw", ""}, - {"aten::abs", "out"}, - {"aten::acos", "out"}, - {"aten::add", "out"}, - {"aten::addmv", "out"}, - {"aten::addr", "out"}, - {"aten::all", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::all", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::all", "dimname_out"}, - #endif - {"aten::any", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::any", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::any", "dimname_out"}, - #endif - {"aten::arange", ""}, - {"aten::arange", "start"}, - {"aten::arange", "start_step"}, - {"aten::arange", "out"}, - {"aten::arange", "start_out"}, - {"aten::asin", "out"}, - {"aten::atan", "out"}, - {"aten::baddbmm", "out"}, - {"aten::bartlett_window", ""}, - {"aten::bartlett_window", "periodic"}, - {"aten::batch_norm", ""}, - {"aten::_batch_norm_impl_index", ""}, - {"aten::_batch_norm_impl_index_backward", ""}, - {"aten::bernoulli", "out"}, - {"aten::bilinear", ""}, - {"aten::binary_cross_entropy_with_logits", ""}, - {"aten::binary_cross_entropy_with_logits_backward", ""}, - {"aten::bincount", ""}, - {"aten::bitwise_not", "out"}, - {"aten::logical_not", "out"}, - {"aten::logical_xor", "out"}, - {"aten::blackman_window", ""}, - {"aten::blackman_window", "periodic"}, - {"aten::bmm", "out"}, - {"aten::cat", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::cat", "names"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::cat", "names_out"}, - #endif - {"aten::ceil", "out"}, - {"aten::clamp", "out"}, - {"aten::clamp_max", "out"}, - {"aten::clamp_min", "out"}, - {"aten::contiguous", ""}, - {"aten::convolution", ""}, - {"aten::convolution_overrideable", ""}, - {"aten::convolution_backward_overrideable", ""}, - {"aten::_convolution", ""}, - {"aten::_convolution_nogroup", ""}, - {"aten::_convolution_double_backward", ""}, - {"aten::conv1d", 
""}, - {"aten::conv2d", ""}, - {"aten::conv3d", ""}, - {"aten::conv_transpose1d", ""}, - {"aten::conv_transpose2d", "input"}, - {"aten::conv_transpose3d", "input"}, - {"aten::cos", "out"}, - {"aten::cosh", "out"}, - {"aten::cudnn_batch_norm", ""}, - {"aten::cudnn_batch_norm_backward", ""}, - {"aten::cudnn_convolution", ""}, - {"aten::cudnn_convolution_transpose", ""}, - {"aten::cumsum", ""}, - {"aten::cumsum", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::cumsum", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::cumsum", "dimname_out"}, - #endif - {"aten::cumprod", ""}, - {"aten::cumprod", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::cumprod", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::cumprod", "dimname_out"}, - #endif - {"aten::div", "out"}, - {"aten::dot", "out"}, - {"aten::embedding_bag", ""}, - {"aten::_embedding_bag", ""}, - {"aten::_embedding_bag_backward", ""}, - {"aten::_embedding_bag_sparse_backward", ""}, - {"aten::_embedding_bag_dense_backward", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::empty", "names"}, - #endif - {"aten::empty", "memory_format"}, - {"aten::new_empty", ""}, - {"aten::new_full", ""}, - {"aten::_empty_affine_quantized", ""}, - {"aten::_empty_per_channel_affine_quantized", ""}, - {"aten::empty", "out"}, - {"aten::empty_like", "dtype"}, - {"aten::empty_strided", ""}, - {"aten::erf", "out"}, - {"aten::erfc", "out"}, - {"aten::exp", "out"}, - {"aten::expm1", "out"}, - {"aten::eye", ""}, - {"aten::eye", "m"}, - {"aten::eye", "out"}, - {"aten::eye", "m_out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::flatten", "named_out_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::flatten", "using_names"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::flatten", "DimnameList"}, - #endif - {"aten::floor", "out"}, - {"aten::frac", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::full", "names"}, - #endif - {"aten::full", ""}, - {"aten::full", "out"}, - {"aten::full_like", "dtype"}, - {"aten::from_file", ""}, - {"aten::hann_window", ""}, - {"aten::hann_window", "periodic"}, - {"aten::hamming_window", ""}, - {"aten::hamming_window", "periodic"}, - {"aten::hamming_window", "periodic_alpha"}, - {"aten::hamming_window", "periodic_alpha_beta"}, - {"aten::ger", "out"}, - {"aten::group_norm", ""}, - {"aten::_cufft_set_plan_cache_max_size", ""}, - {"aten::_cufft_clear_plan_cache", ""}, - {"aten::index", "Tensor"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::index_copy_", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::index_copy", "dimname"}, - #endif - {"aten::index_put_", ""}, - {"aten::index_put", ""}, - {"aten::_index_put_impl_", ""}, - {"aten::instance_norm", ""}, - {"aten::inverse", "out"}, - {"aten::kthvalue", "values"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::kthvalue", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::kthvalue", "dimname_out"}, - #endif - {"aten::layer_norm", ""}, - {"aten::native_layer_norm", ""}, - {"aten::native_layer_norm_backward", ""}, - {"aten::native_layer_norm_double_backward", ""}, - {"aten::linear", ""}, - {"aten::mkldnn_linear", ""}, - {"aten::linspace", ""}, - {"aten::linspace", "out"}, - {"aten::log", "out"}, - {"aten::log10", "out"}, - {"aten::log1p", "out"}, - {"aten::log2", "out"}, - {"aten::logspace", ""}, - {"aten::logspace", "out"}, - {"aten::log_softmax", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::log_softmax", ""}, - #endif - {"aten::logsumexp", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::logsumexp", "names"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::logsumexp", "names_out"}, - #endif - {"aten::matmul", "out"}, - 
{"aten::max", "dim_max"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::max", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::max", "names_dim_max"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::max_values", "names"}, - #endif - {"aten::mean", ""}, - {"aten::mean", "dim"}, - {"aten::mean", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::mean", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::mean", "names_out"}, - #endif - {"aten::median", "dim_values"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::median", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::median", "names_dim_values"}, - #endif - {"aten::min", "dim_min"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::min", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::min", "names_dim_min"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::min_values", "names"}, - #endif - {"aten::mkldnn_convolution", ""}, - {"aten::miopen_batch_norm", ""}, - {"aten::miopen_batch_norm_backward", ""}, - {"aten::miopen_convolution", ""}, - {"aten::miopen_convolution_transpose", ""}, - {"aten::miopen_depthwise_convolution", ""}, - {"aten::miopen_rnn", ""}, - {"aten::miopen_rnn_backward", ""}, - {"aten::mm", "out"}, - {"aten::mode", "values"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::mode", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::mode", "dimname_out"}, - #endif - {"aten::mul", "out"}, - {"aten::mv", "out"}, - {"aten::native_batch_norm", ""}, - {"aten::batch_norm_elemt", ""}, - {"aten::batch_norm_gather_stats", ""}, - {"aten::batch_norm_gather_stats_with_counts", ""}, - {"aten::native_batch_norm_backward", ""}, - {"aten::batch_norm_backward_reduce", ""}, - {"aten::batch_norm_backward_elemt", ""}, - {"aten::batch_norm_update_stats", ""}, - {"aten::_nnpack_spatial_convolution", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::ones", "names"}, - #endif - {"aten::ones", ""}, - {"aten::ones", "out"}, - {"aten::ones_like", "dtype"}, - {"aten::scalar_tensor", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::rand", "names"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::rand", "generator_with_names"}, - #endif - {"aten::rand", ""}, - {"aten::rand", "generator"}, - {"aten::rand", "out"}, - {"aten::rand", "generator_out"}, - {"aten::rand_like", "dtype"}, - {"aten::randint", ""}, - {"aten::randint", "generator"}, - {"aten::randint", "low"}, - {"aten::randint", "low_generator"}, - {"aten::randint", "out"}, - {"aten::randint", "generator_out"}, - {"aten::randint", "low_out"}, - {"aten::randint", "low_generator_out"}, - {"aten::randint_like", "dtype"}, - {"aten::randint_like", "low_dtype"}, - {"aten::randn", ""}, - {"aten::randn", "generator"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::randn", "names"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::randn", "generator_with_names"}, - #endif - {"aten::randn", "out"}, - {"aten::randn", "generator_out"}, - {"aten::randn_like", "dtype"}, - {"aten::randperm", ""}, - {"aten::randperm", "generator"}, - {"aten::randperm", "out"}, - {"aten::randperm", "generator_out"}, - {"aten::range", "step"}, - {"aten::range", ""}, - {"aten::range", "out"}, - {"aten::reciprocal", "out"}, - {"aten::neg", "out"}, - {"aten::round", "out"}, - {"aten::rsqrt", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::select", "Dimname"}, - #endif - {"aten::sigmoid", "out"}, - {"aten::sin", "out"}, - {"aten::sinh", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::size", "Dimname"}, - #endif - {"aten::softmax", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::softmax", ""}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::squeeze", "dimname"}, - #endif - #ifdef 
BUILD_NAMEDTENSOR - {"aten::squeeze_", "dimname"}, - #endif - {"aten::sspaddmm", "out"}, - {"aten::stack", "out"}, - {"aten::stft", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::stride", "Dimname"}, - #endif - {"aten::sum", ""}, - {"aten::sum", "dim_IntList"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::sum", "dim_DimnameList"}, - #endif - {"aten::sum", "IntList_out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::sum", "DimnameList_out"}, - #endif - {"aten::sqrt", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::std_mean", "names_dim"}, - #endif - {"aten::std", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::std", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::std", "names_out"}, - #endif - {"aten::prod", ""}, - {"aten::prod", "dim_int"}, - {"aten::prod", "int_out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::prod", "dim_Dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::prod", "Dimname_out"}, - #endif - {"aten::tan", "out"}, - {"aten::tanh", "out"}, - {"aten::threshold", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::transpose", "Dimname"}, - #endif - {"aten::trunc", "out"}, - {"aten::var", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::var", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::var", "names_out"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::var_mean", "names_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::zeros", "names"}, - #endif - {"aten::zeros", ""}, - {"aten::zeros", "out"}, - {"aten::zeros_like", "dtype"}, - {"aten::_sparse_sum", "dtype"}, - {"aten::_sparse_sum", "dim_dtype"}, - {"aten::norm", "ScalarOpt_dtype"}, - {"aten::norm", "ScalarOpt_dim_dtype"}, - {"aten::norm", "dtype_out"}, - {"aten::norm", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::norm", "names_ScalarOpt_dim_dtype"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::norm", "names_ScalarOpt_dim"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::norm", "names_dtype_out"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::norm", "names_out"}, - #endif - {"aten::frobenius_norm", "out"}, - {"aten::nuclear_norm", "out"}, - {"aten::nuclear_norm", "dim_out"}, - {"aten::pow", "Tensor_Scalar_out"}, - {"aten::sub", "out"}, - {"aten::addmm", "out"}, - {"aten::sparse_coo_tensor", "size"}, - {"aten::sparse_coo_tensor", "indices"}, - {"aten::sparse_coo_tensor", "indices_size"}, - {"aten::_sparse_coo_tensor_unsafe", ""}, - {"aten::_sparse_coo_tensor_with_dims", ""}, - {"aten::_sparse_coo_tensor_with_dims_and_tensors", ""}, - {"aten::hspmm", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::unbind", "Dimname"}, - #endif - {"aten::quantize_per_tensor", ""}, - {"aten::quantize_per_channel", ""}, - {"aten::q_per_channel_axis", ""}, - {"aten::qscheme", ""}, - {"aten::fake_quantize_per_channel_affine", ""}, - {"aten::fake_quantize_per_channel_affine_backward", ""}, - {"aten::to", "dtype_layout"}, - {"aten::to", "device"}, - {"aten::to", "dtype"}, - {"aten::result_type", "Tensor"}, - {"aten::result_type", "Scalar"}, - {"aten::result_type", "Scalar_Tensor"}, - {"aten::result_type", "Scalar_Scalar"}, - {"aten::can_cast", ""}, - {"aten::promote_types", ""}, - {"aten::_thnn_fused_lstm_cell", ""}, - {"aten::_thnn_fused_lstm_cell_backward", ""}, - {"aten::_thnn_differentiable_lstm_cell_backward", ""}, - {"aten::_thnn_fused_gru_cell", ""}, - {"aten::_thnn_differentiable_gru_cell_backward", ""}, - {"aten::lstm_cell", ""}, - {"aten::gru_cell", ""}, - {"aten::rnn_tanh_cell", ""}, - {"aten::rnn_relu_cell", ""}, - {"aten::quantized_lstm", ""}, - {"aten::set_", "source_Storage"}, - {"aten::set_", "source_Storage_storage_offset"}, - 
{"aten::set_quantizer_", ""}, - #ifdef BUILD_NAMEDTENSOR - {"aten::index_add", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::index_fill_", "dimname_Scalar"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::index_fill_", "dimname_Scalar"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::index_fill", "dimname_Scalar"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::index_fill", "dimname_Tensor"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::scatter", "dimname_src"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::scatter", "dimname_value"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::scatter_add", "dimname"}, - #endif - {"aten::addbmm", "out"}, - {"aten::diag", "out"}, - {"aten::cross", "out"}, - {"aten::triu", "out"}, - {"aten::tril", "out"}, - {"aten::tril_indices", ""}, - {"aten::triu_indices", ""}, - {"aten::ne", "Scalar_out"}, - {"aten::ne", "Tensor_out"}, - {"aten::eq", "Scalar_out"}, - {"aten::eq", "Tensor_out"}, - {"aten::ge", "Scalar_out"}, - {"aten::ge", "Tensor_out"}, - {"aten::le", "Scalar_out"}, - {"aten::le", "Tensor_out"}, - {"aten::gt", "Scalar_out"}, - {"aten::gt", "Tensor_out"}, - {"aten::lt", "Scalar_out"}, - {"aten::lt", "Tensor_out"}, - {"aten::take", "out"}, - {"aten::index_select", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::index_select", "dimname_out"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::index_select", "dimname"}, - #endif - {"aten::masked_select", "out"}, - {"aten::nonzero", "out"}, - {"aten::gather", "out"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::gather", "dimname_out"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::gather", "dimname"}, - #endif - {"aten::addcmul", "out"}, - {"aten::addcdiv", "out"}, - {"aten::lstsq", "X"}, - {"aten::triangular_solve", "X"}, - {"aten::symeig", "e"}, - {"aten::eig", "e"}, - {"aten::svd", "U"}, - {"aten::cholesky", "out"}, - {"aten::cholesky_solve", "out"}, - {"aten::solve", "solution"}, - {"aten::cholesky_inverse", "out"}, - {"aten::qr", "Q"}, - {"aten::geqrf", "a"}, - {"aten::orgqr", "out"}, - {"aten::ormqr", "out"}, - {"aten::lu_solve", "out"}, - {"aten::multinomial", "out"}, - {"aten::lgamma", "out"}, - {"aten::digamma", "out"}, - {"aten::polygamma", "out"}, - {"aten::erfinv", "out"}, - {"aten::sign", "out"}, - {"aten::atan2", "out"}, - {"aten::lerp", "Scalar_out"}, - {"aten::lerp", "Tensor_out"}, - {"aten::histc", "out"}, - {"aten::fmod", "Scalar_out"}, - {"aten::fmod", "Tensor_out"}, - {"aten::remainder", "Scalar_out"}, - {"aten::remainder", "Tensor_out"}, - {"aten::min", "out"}, - {"aten::max", "out"}, - {"aten::sort", "values"}, - #ifdef BUILD_NAMEDTENSOR - {"aten::sort", "dimname_values"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::sort", "dimname"}, - #endif - #ifdef BUILD_NAMEDTENSOR - {"aten::argsort", "dimname"}, - #endif - {"aten::topk", "values"}, - {"aten::renorm", "out"}, - {"aten::pow", "Tensor_Tensor_out"}, - {"aten::pow", "Scalar_out"}, - {"aten::normal", "Tensor_float_out"}, - {"aten::normal", "float_Tensor_out"}, - {"aten::normal", "Tensor_Tensor_out"}, - {"aten::normal", "float_float"}, - {"aten::normal", "float_float_out"}, - {"aten::_addr", "out"}, - {"aten::_cumsum", "out"}, - {"aten::_cumprod", "out"}, - {"aten::_cat", "out"}, - {"aten::_mode", "values"}, - {"aten::_max", "max"}, - {"aten::_min", "min"}, - {"aten::binary_cross_entropy", "out"}, - {"aten::binary_cross_entropy", ""}, - {"aten::binary_cross_entropy_backward", "grad_input"}, - {"aten::binary_cross_entropy_backward", ""}, - {"aten::mse_loss", "out"}, - {"aten::mse_loss_backward", "grad_input"}, - {"aten::l1_loss", "out"}, - 
{"aten::l1_loss_backward", "grad_input"}, - {"aten::multi_margin_loss", "out"}, - {"aten::multi_margin_loss", ""}, - {"aten::multi_margin_loss_backward", "grad_input"}, - {"aten::multi_margin_loss_backward", ""}, - {"aten::multilabel_margin_loss", "out"}, - {"aten::multilabel_margin_loss_forward", "output"}, - {"aten::multilabel_margin_loss_backward", "grad_input"}, - {"aten::nll_loss", "out"}, - {"aten::nll_loss", ""}, - {"aten::nll_loss_forward", "output"}, - {"aten::nll_loss_forward", ""}, - {"aten::nll_loss_backward", "grad_input"}, - {"aten::nll_loss_backward", ""}, - {"aten::nll_loss2d", "out"}, - {"aten::nll_loss2d", ""}, - {"aten::nll_loss2d_forward", "output"}, - {"aten::nll_loss2d_forward", ""}, - {"aten::nll_loss2d_backward", "grad_input"}, - {"aten::nll_loss2d_backward", ""}, - {"aten::smooth_l1_loss", "out"}, - {"aten::smooth_l1_loss_backward", "grad_input"}, - {"aten::soft_margin_loss", "out"}, - {"aten::soft_margin_loss_backward", "grad_input"}, - {"aten::elu", "out"}, - {"aten::elu_backward", "grad_input"}, - {"aten::glu", "out"}, - {"aten::glu_backward", "grad_input"}, - {"aten::hardtanh", "out"}, - {"aten::hardtanh_backward", "grad_input"}, - {"aten::leaky_relu", "out"}, - {"aten::leaky_relu_backward", "grad_input"}, - {"aten::log_sigmoid", "out"}, - {"aten::log_sigmoid_forward", "output"}, - {"aten::log_sigmoid_backward", "grad_input"}, - {"aten::rrelu_with_noise", "out"}, - {"aten::rrelu_with_noise_backward", "grad_input"}, - {"aten::softplus", "out"}, - {"aten::softplus_backward", "grad_input"}, - {"aten::softshrink", "out"}, - {"aten::softshrink_backward", "grad_input"}, - {"aten::adaptive_avg_pool2d", "out"}, - {"aten::adaptive_avg_pool3d", "out"}, - {"aten::adaptive_avg_pool3d_backward", "grad_input"}, - {"aten::adaptive_max_pool2d", "out"}, - {"aten::adaptive_max_pool2d_backward", "grad_input"}, - {"aten::adaptive_max_pool3d", "out"}, - {"aten::adaptive_max_pool3d_backward", "grad_input"}, - {"aten::avg_pool2d", "out"}, - {"aten::avg_pool2d_backward", "grad_input"}, - {"aten::avg_pool3d", "out"}, - {"aten::avg_pool3d_backward", "grad_input"}, - {"aten::fractional_max_pool2d", "output"}, - {"aten::fractional_max_pool2d_backward", "grad_input"}, - {"aten::fractional_max_pool3d", "output"}, - {"aten::fractional_max_pool3d_backward", "grad_input"}, - {"aten::max_pool2d_with_indices", "out"}, - {"aten::max_pool2d_with_indices_backward", "grad_input"}, - {"aten::max_pool3d_with_indices", "out"}, - {"aten::max_pool3d_with_indices_backward", "grad_input"}, - {"aten::max_unpool2d", "out"}, - {"aten::max_unpool2d_backward", "grad_input"}, - {"aten::max_unpool3d", "out"}, - {"aten::max_unpool3d_backward", "grad_input"}, - {"aten::reflection_pad1d", "out"}, - {"aten::reflection_pad1d_backward", "grad_input"}, - {"aten::reflection_pad2d", "out"}, - {"aten::reflection_pad2d_backward", "grad_input"}, - {"aten::replication_pad1d", "out"}, - {"aten::replication_pad1d_backward", "grad_input"}, - {"aten::replication_pad2d", "out"}, - {"aten::replication_pad2d_backward", "grad_input"}, - {"aten::replication_pad3d", "out"}, - {"aten::replication_pad3d_backward", "grad_input"}, - {"aten::upsample_linear1d", "out"}, - {"aten::upsample_linear1d_backward", "grad_input"}, - {"aten::upsample_bilinear2d", "out"}, - {"aten::upsample_bilinear2d_backward", "grad_input"}, - {"aten::upsample_bicubic2d", "out"}, - {"aten::upsample_bicubic2d_backward", "grad_input"}, - {"aten::upsample_trilinear3d", "out"}, - {"aten::upsample_trilinear3d_backward", "grad_input"}, - {"aten::upsample_nearest1d", 
"out"}, - {"aten::upsample_nearest1d_backward", "grad_input"}, - {"aten::upsample_nearest2d", "out"}, - {"aten::upsample_nearest2d_backward", "grad_input"}, - {"aten::upsample_nearest3d", "out"}, - {"aten::upsample_nearest3d_backward", "grad_input"}, - {"aten::sigmoid_backward", "grad_input"}, - {"aten::tanh_backward", "grad_input"}, - {"aten::slow_conv_transpose2d", "out"}, - {"aten::slow_conv_transpose2d", ""}, - {"aten::slow_conv_transpose2d_backward", "grad_output"}, - {"aten::slow_conv_transpose3d", "out"}, - {"aten::slow_conv_transpose3d", ""}, - {"aten::slow_conv_transpose3d_backward", "grad_output"}, - {"aten::thnn_conv2d", "out"}, - {"aten::thnn_conv2d", ""}, - {"aten::thnn_conv2d_forward", "output"}, - {"aten::thnn_conv2d_forward", ""}, - {"aten::thnn_conv2d_backward", "grad_input"}, - {"aten::thnn_conv_depthwise2d", "out"}, - {"aten::thnn_conv_depthwise2d", ""}, - {"aten::thnn_conv_depthwise2d_forward", "out"}, - {"aten::thnn_conv_depthwise2d_forward", ""}, - {"aten::thnn_conv_depthwise2d_backward", "grad_input"}, - {"aten::thnn_conv3d", "out"}, - {"aten::thnn_conv3d", ""}, - {"aten::thnn_conv3d_forward", "output"}, - {"aten::thnn_conv3d_forward", ""}, - {"aten::thnn_conv3d_backward", "grad_input"}, - {"aten::slow_conv_dilated2d", ""}, - {"aten::slow_conv_dilated3d", ""}, - {"aten::col2im", "out"}, - {"aten::col2im_backward", "grad_input"}, - {"aten::im2col", "out"}, - {"aten::im2col_backward", "grad_input"}, - {"", ""} - }; - return ops.count(std::make_pair(opName.name.c_str(), opName.overload_name.c_str())) != 0; -} - -} diff --git a/aten/src/ATen/core/PhiloxRNGEngine.h b/aten/src/ATen/core/PhiloxRNGEngine.h index 9a283f1d06d66..1e597e8fc1f64 100644 --- a/aten/src/ATen/core/PhiloxRNGEngine.h +++ b/aten/src/ATen/core/PhiloxRNGEngine.h @@ -180,15 +180,15 @@ class philox_engine { #endif } - C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 key) { + C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) { uint32_t hi0; uint32_t hi1; uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0); uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1); detail::UINT4 ret; - ret[0] = hi1 ^ ctr[1] ^ key[0]; + ret[0] = hi1 ^ ctr[1] ^ in_key[0]; ret[1] = lo1; - ret[2] = hi0 ^ ctr[3] ^ key[1]; + ret[2] = hi0 ^ ctr[3] ^ in_key[1]; ret[3] = lo0; return ret; } diff --git a/aten/src/ATen/core/TensorBody.h b/aten/src/ATen/core/TensorBody.h deleted file mode 100644 index bdda41eff145d..0000000000000 --- a/aten/src/ATen/core/TensorBody.h +++ /dev/null @@ -1,1017 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace caffe2 { -class Tensor; -} -namespace c10{ -struct TensorOptions; -} -namespace at { -struct Generator; -struct Type; -class DeprecatedTypeProperties; -class Tensor; -} // namespace at - -namespace at { - -class Tensor; -using TensorList = ArrayRef; - -struct Quantizer; -// This is temporary typedef to enable Quantizer in aten native function API -// we'll remove them when we are actually exposing Quantizer class -// to frontend -using QuantizerPtr = c10::intrusive_ptr; -using ConstQuantizerPtr = const c10::intrusive_ptr&; - -// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which -// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr. 
-// -// For example: -// -// void func(Tensor a) { -// Tensor b = a; -// ... -// } -// -// In this example, when we say Tensor b = a, we are creating a new object that points to the -// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the -// destructor decrements the reference count by calling release() on the TensorImpl it points to. -// The existing constructors, operator overloads, etc. take care to implement the correct semantics. -// -// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and -// special care must be taken to handle this. -class CAFFE2_API Tensor { - public: - Tensor(){}; - // This constructor should not be used by end users and is an implementation - // detail invoked by autogenerated code. - explicit Tensor( - c10::intrusive_ptr tensor_impl) - : impl_(std::move(tensor_impl)) { - if (impl_.get() == nullptr) { - throw std::runtime_error("TensorImpl with nullptr is not supported"); - } - } - Tensor(const Tensor&) = default; - Tensor(Tensor&&) = default; - - - public: - // Creates a new wrapper from TensorImpl. Intentionally a free method because - // it should be used with care. Checks necessary invariants - static Tensor wrap_tensor_impl( - c10::intrusive_ptr tensor_impl) { - Tensor r(std::move(tensor_impl)); - r.enforce_invariants(); - return r; - } - - int64_t dim() const { - return impl_->dim(); - } - int64_t storage_offset() const { - return impl_->storage_offset(); - } - - TensorImpl * unsafeGetTensorImpl() const { - return impl_.get(); - } - TensorImpl * unsafeReleaseTensorImpl() { - return impl_.release(); - } - const c10::intrusive_ptr& getIntrusivePtr() const { - return impl_; - } - - bool defined() const { - return impl_; - } - - void reset() { - impl_.reset(); - } - - // The following overloads are very intruiging. Consider the following - // program: - // - // x[1] = 3; - // - // We would expect that the first entry of x is written to 3. But how can we - // actually achieve this? x[1] evaluates to a tensor... - // - // The answer is, using a ref-qualifier. x[1] is an rvalue, which cannot be - // (profitably) assigned to in the traditional sense, so we overload - // assignment to mean, "Actually, copy 3 into the tensor data." This is done - // with an rvalue-reference ref-qualified overload (the methods with && at the - // end of their type.) - // - // There's one more fly in the ointment: We also want - // - // Tensor x = y; - // - // to work, and we want it NOT to copy. So we need a traditional operator= - // overload. But we MUST specify a mutable lvalue ref-qualifier, to - // disambiguate the traditional overload from the rvalue-reference - // ref-qualified overload. Otherwise, it will be ambiguous, because - // a non ref-qualified method is eligible for all situations. 
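The reference-count and ref-qualifier behaviour described in the comments above can be made concrete with a minimal sketch. This is illustrative only and not part of the diff; it assumes the ATen C++ API declared in this header plus the standard at::ones factory.

#include <ATen/ATen.h>
#include <iostream>

void example() {
  at::Tensor a = at::ones({3});  // one underlying TensorImpl, refcount == 1
  at::Tensor b = a;              // b shares a's TensorImpl; refcount bumps to 2
  std::cout << a.use_count() << "\n";  // prints 2; no element data was copied

  // operator[](int64_t) returns a temporary Tensor, so the rvalue (&&)
  // ref-qualified operator=(Scalar) applies: it writes 3 into the underlying
  // storage instead of rebinding the temporary.
  a[1] = 3;

  // The lvalue (&) ref-qualified operator=(const Tensor&) applies here: it
  // rebinds b to a's TensorImpl without copying any element data.
  b = a;
}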
- - // Unfortunately, we have to write these constructors out manually - // to work around an MSVC bug: - // error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &': - // multiple versions of a defaulted special member functions are not allowed - // Tensor& operator=(const Tensor&) & = default; - // Tensor& operator=(Tensor&&) & = default; - Tensor& operator=(const Tensor& x) & { - impl_ = x.impl_; - return *this; - } - Tensor& operator=(Tensor&& x) & { - impl_ = std::move(x.impl_); - return *this; - } - - Tensor& operator=(Scalar v) &&; - Tensor& operator=(const Tensor&) &&; - Tensor& operator=(Tensor&&) &&; - - bool is_same(const Tensor& other) const noexcept { - return impl_ == other.impl_; - } - size_t use_count() const noexcept { - return impl_.use_count(); - } - size_t weak_use_count() const noexcept { - return impl_.weak_use_count(); - } - - std::string toString() const; - - IntArrayRef sizes() const { - return impl_->sizes(); - } - IntArrayRef strides() const { - return impl_->strides(); - } -#ifdef BUILD_NAMEDTENSOR - // See impl::get_opt_names in ATen/NamedTensor.h for docs. - optional opt_names() const { - return impl::get_opt_names(unsafeGetTensorImpl()); - } - // See impl::get_names in ATen/NamedTensor.h for docs. - DimnameList names() const { - return impl::get_names(unsafeGetTensorImpl()); - } -#endif - int64_t ndimension() const { - return dim(); - } - bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { - return impl_->is_contiguous(memory_format); - } - - at::MemoryFormat suggest_memory_format() const { - if (impl_->is_strides_like_channels_last()) { - return at::MemoryFormat::ChannelsLast; - } - return at::MemoryFormat::Contiguous; - } - - // Total bytes consumed by the "view" of elements of the array. Does not - // include size of metadata. The number reported here does not necessarily - // correspond to the true physical memory consumed by a tensor; instead, - // it reports the memory the tensor would take *if* it were contiguous. - // Defined to be numel() * itemsize() - size_t nbytes() const { - return impl_->numel() * impl_->itemsize(); - } - - // Length of one array element in bytes. This is the traditional - // Numpy naming. - size_t itemsize() const { - return impl_->itemsize(); - } - - // Same as itemsize(). This is the PyTorch naming. - size_t element_size() const { - return impl_->itemsize(); - } - - DeprecatedTypeProperties & type() const { - return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( - tensorTypeIdToBackend(legacyExtractTypeId(type_set())), - scalar_type(), - is_variable()); - } - TensorTypeSet type_set() const { - return impl_->type_set(); - } - ScalarType scalar_type() const { - return typeMetaToScalarType(impl_->dtype()); - } - bool has_storage() const { - return defined() && impl_->has_storage(); - } - const Storage& storage() const { - return impl_->storage(); - } - bool is_alias_of(const at::Tensor& other) const{ - return impl_->storage().is_alias_of(other.storage()); - } - Tensor toType(ScalarType t) const; - Tensor toBackend(Backend b) const; - - /// Returns true if the `Tensor` is actually a `torch::autograd::Variable`. - /// Defined in Type.h because of include order issues. - bool is_variable() const noexcept { - return impl_->is_variable(); - } - - /// Returns a `Tensor`'s layout. Defined in Type.h - Layout layout() const noexcept; - - /// Returns a `Tensor`'s dtype (`TypeMeta`). 
Defined in TensorMethods.h - caffe2::TypeMeta dtype() const noexcept; - - /// Returns a `Tensor`'s device. - Device device() const; - - /// Returns a `Tensor`'s device index. - int64_t get_device() const; - - /// Returns if a `Tensor` has CUDA backend. - bool is_cuda() const; - - /// Returns if a `Tensor` has HIP backend. - bool is_hip() const; - - /// Returns if a `Tensor` has sparse backend. - bool is_sparse() const; - - /// Returns if a `Tensor` is mkldnn tensor. - bool is_mkldnn() const; - - /// Returns if a `Tensor` has quantized backend. - bool is_quantized() const; - - /// If a tensor is a quantized tensor, returns its quantizer - /// TODO: it's not in native_functions.yaml yet as it's not exposed to python - QuantizerPtr quantizer() const; - -#ifdef BUILD_NAMEDTENSOR - /// Returns if a `Tensor` has any dimension names - bool has_names() const; - - /// Returns a `Tensor`'s dimension names data structure - const NamedTensorMeta* get_named_tensor_meta() const; - NamedTensorMeta* get_named_tensor_meta(); -#endif - - /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in - /// TensorOptions.h. - TensorOptions options() const; - - void* data_ptr() const { - return this->unsafeGetTensorImpl()->data(); - } - - template - T * data_ptr() const; - - template - C10_DEPRECATED_MESSAGE("Tensor.data() is deprecated. Please use Tensor.data_ptr() instead.") - T * data() const { - return data_ptr(); - } - - template - T item() const; - - // Purposely not defined here to avoid inlining - void print() const; - - // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and - // dimension. - template - TensorAccessor accessor() const& { - static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); - TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return TensorAccessor(data_ptr(),sizes().data(),strides().data()); - } - template - TensorAccessor accessor() && = delete; - - // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and - // dimension. You can optionally specify RestrictPtrTraits as a template parameter to - // cast the data pointer to a __restrict__ pointer. - // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor - // as an argument. 
- template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - GenericPackedTensorAccessor generic_packed_accessor() const& { - static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); - TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); - } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - GenericPackedTensorAccessor generic_packed_accessor() && = delete; - - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor32 packed_accessor32() const& { - return generic_packed_accessor(); - } - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor32 packed_accessor32() && = delete; - - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor64 packed_accessor64() const& { - return generic_packed_accessor(); - } - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor64 packed_accessor64() && = delete; - - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") - GenericPackedTensorAccessor packed_accessor() const & { - return generic_packed_accessor(); - } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") - GenericPackedTensorAccessor packed_accessor() && = delete; - - Tensor operator-() const; - Tensor& operator+=(const Tensor & other); - Tensor& operator+=(Scalar other); - Tensor& operator-=(const Tensor & other); - Tensor& operator-=(Scalar other); - Tensor& operator*=(const Tensor & other); - Tensor& operator*=(Scalar other); - Tensor& operator/=(const Tensor & other); - Tensor& operator/=(Scalar other); - Tensor operator[](Scalar index) const; - Tensor operator[](Tensor index) const; - Tensor operator[](int64_t index) const; - - Tensor cpu() const; - Tensor cuda() const; - Tensor hip() const; - - // ~~~~~ Autograd API ~~~~~ - - Tensor& set_requires_grad(bool requires_grad) { - impl_->set_requires_grad(requires_grad); - return *this; - } - bool requires_grad() const { - return impl_->requires_grad(); - } - - Tensor& grad() { - return impl_->grad(); - } - const Tensor& grad() const { - return impl_->grad(); - } - - // STOP. Thinking of adding a method here, which only makes use - // of other ATen methods? Define it in native_functions.yaml. 
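The accessor comments above (TensorAccessor for CPU tensors, packed accessors passed by value to CUDA kernels) can be illustrated with a short sketch. It is not part of the diff and makes several assumptions: ATen headers are available, the tensor is a contiguous 2-D float tensor, the device path is compiled with nvcc, and the toy launch configuration assumes size(1) fits within a single block.

#include <ATen/ATen.h>

float sum_cpu(const at::Tensor& t) {
  // Host path: scalar type and rank are template parameters; dim() is checked at runtime.
  auto acc = t.accessor<float, 2>();
  float s = 0;
  for (int64_t i = 0; i < acc.size(0); ++i)
    for (int64_t j = 0; j < acc.size(1); ++j)
      s += acc[i][j];
  return s;
}

#ifdef __CUDACC__
// Device path: the packed accessor is passed to the kernel by value and
// indexes with 32-bit integers (the packed_accessor32 variant declared above).
__global__ void scale_kernel(at::PackedTensorAccessor32<float, 2> acc, float alpha) {
  int i = blockIdx.x;
  int j = threadIdx.x;
  if (i < acc.size(0) && j < acc.size(1))
    acc[i][j] *= alpha;
}

void scale_cuda(at::Tensor t, float alpha) {
  scale_kernel<<<static_cast<unsigned>(t.size(0)), static_cast<unsigned>(t.size(1))>>>(
      t.packed_accessor32<float, 2>(), alpha);
}
#endif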
- - //example - //Tensor * add(Tensor & b); - void backward(const Tensor & gradient={}, bool keep_graph=false, bool create_graph=false) const; - void set_data(const Tensor & new_data) const; - Tensor data() const; - bool is_leaf() const; - int64_t output_nr() const; - int64_t _version() const; - #ifdef BUILD_NAMEDTENSOR - Tensor & rename_(c10::optional names) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor rename(c10::optional names) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor align_to(DimnameList names) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor align_as(const Tensor & other) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor refine_names(DimnameList names) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor unflatten(Dimname dim, IntArrayRef sizes, DimnameList names) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor unflatten(int64_t dim, IntArrayRef sizes, DimnameList names) const; - #endif - Tensor abs() const; - Tensor & abs_() const; - Tensor acos() const; - Tensor & acos_() const; - Tensor add(const Tensor & other, Scalar alpha=1) const; - Tensor & add_(const Tensor & other, Scalar alpha=1) const; - Tensor add(Scalar other, Scalar alpha=1) const; - Tensor & add_(Scalar other, Scalar alpha=1) const; - Tensor addmv(const Tensor & mat, const Tensor & vec, Scalar beta=1, Scalar alpha=1) const; - Tensor & addmv_(const Tensor & mat, const Tensor & vec, Scalar beta=1, Scalar alpha=1) const; - Tensor addr(const Tensor & vec1, const Tensor & vec2, Scalar beta=1, Scalar alpha=1) const; - Tensor & addr_(const Tensor & vec1, const Tensor & vec2, Scalar beta=1, Scalar alpha=1) const; - Tensor all(int64_t dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor all(Dimname dim, bool keepdim=false) const; - #endif - bool allclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; - Tensor any(int64_t dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor any(Dimname dim, bool keepdim=false) const; - #endif - Tensor argmax(c10::optional dim=c10::nullopt, bool keepdim=false) const; - Tensor argmin(c10::optional dim=c10::nullopt, bool keepdim=false) const; - Tensor as_strided(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset=c10::nullopt) const; - Tensor & as_strided_(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset=c10::nullopt) const; - Tensor asin() const; - Tensor & asin_() const; - Tensor atan() const; - Tensor & atan_() const; - Tensor baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const; - Tensor & baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const; - Tensor bernoulli(Generator * generator=nullptr) const; - Tensor & bernoulli_(const Tensor & p, Generator * generator=nullptr) const; - Tensor & bernoulli_(double p=0.5, Generator * generator=nullptr) const; - Tensor bernoulli(double p, Generator * generator=nullptr) const; - Tensor bincount(const Tensor & weights={}, int64_t minlength=0) const; - Tensor bitwise_not() const; - Tensor & bitwise_not_() const; - Tensor logical_not() const; - Tensor & logical_not_() const; - Tensor logical_xor(const Tensor & other) const; - Tensor & logical_xor_(const Tensor & other) const; - Tensor bmm(const Tensor & mat2) const; - Tensor ceil() const; - Tensor & ceil_() const; - std::vector chunk(int64_t chunks, int64_t dim=0) const; - Tensor clamp(c10::optional min=c10::nullopt, c10::optional max=c10::nullopt) const; - Tensor & clamp_(c10::optional min=c10::nullopt, 
c10::optional max=c10::nullopt) const; - Tensor clamp_max(Scalar max) const; - Tensor & clamp_max_(Scalar max) const; - Tensor clamp_min(Scalar min) const; - Tensor & clamp_min_(Scalar min) const; - Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const; - Tensor & copy_(const Tensor & src, bool non_blocking=false) const; - Tensor cos() const; - Tensor & cos_() const; - Tensor cosh() const; - Tensor & cosh_() const; - Tensor cumsum(int64_t dim, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor cumsum(Dimname dim, c10::optional dtype=c10::nullopt) const; - #endif - Tensor cumprod(int64_t dim, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor cumprod(Dimname dim, c10::optional dtype=c10::nullopt) const; - #endif - Tensor det() const; - Tensor diag_embed(int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) const; - Tensor diagflat(int64_t offset=0) const; - Tensor diagonal(int64_t offset=0, int64_t dim1=0, int64_t dim2=1) const; - Tensor & fill_diagonal_(Scalar fill_value, bool wrap=false) const; - Tensor div(const Tensor & other) const; - Tensor & div_(const Tensor & other) const; - Tensor div(Scalar other) const; - Tensor & div_(Scalar other) const; - Tensor dot(const Tensor & tensor) const; - Tensor new_empty(IntArrayRef size, const TensorOptions & options={}) const; - Tensor new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options={}) const; - Tensor & resize_(IntArrayRef size) const; - Tensor erf() const; - Tensor & erf_() const; - Tensor erfc() const; - Tensor & erfc_() const; - Tensor exp() const; - Tensor & exp_() const; - Tensor expm1() const; - Tensor & expm1_() const; - Tensor expand(IntArrayRef size, bool implicit=false) const; - Tensor expand_as(const Tensor & other) const; - Tensor flatten(int64_t start_dim=0, int64_t end_dim=-1) const; - #ifdef BUILD_NAMEDTENSOR - Tensor flatten(int64_t start_dim, int64_t end_dim, Dimname out_dim) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor flatten(Dimname start_dim, Dimname end_dim, Dimname out_dim) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor flatten(DimnameList dims, Dimname out_dim) const; - #endif - Tensor & fill_(Scalar value) const; - Tensor & fill_(const Tensor & value) const; - Tensor floor() const; - Tensor & floor_() const; - Tensor frac() const; - Tensor & frac_() const; - Tensor ger(const Tensor & vec2) const; - Tensor fft(int64_t signal_ndim, bool normalized=false) const; - Tensor ifft(int64_t signal_ndim, bool normalized=false) const; - Tensor rfft(int64_t signal_ndim, bool normalized=false, bool onesided=true) const; - Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntArrayRef signal_sizes={}) const; - Tensor index(TensorList indices) const; - Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source) const; - Tensor index_copy(int64_t dim, const Tensor & index, const Tensor & source) const; - #ifdef BUILD_NAMEDTENSOR - Tensor & index_copy_(Dimname dim, const Tensor & index, const Tensor & source) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor index_copy(Dimname dim, const Tensor & index, const Tensor & source) const; - #endif - Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false) const; - Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const; - Tensor inverse() const; - Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const; - bool is_distributed() 
const; - bool is_floating_point() const; - bool is_complex() const; - bool is_nonzero() const; - bool is_same_size(const Tensor & other) const; - bool is_signed() const; - std::tuple kthvalue(int64_t k, int64_t dim=-1, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - std::tuple kthvalue(int64_t k, Dimname dim, bool keepdim=false) const; - #endif - Tensor log() const; - Tensor & log_() const; - Tensor log10() const; - Tensor & log10_() const; - Tensor log1p() const; - Tensor & log1p_() const; - Tensor log2() const; - Tensor & log2_() const; - Tensor logdet() const; - Tensor log_softmax(int64_t dim, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor log_softmax(Dimname dim, c10::optional dtype=c10::nullopt) const; - #endif - Tensor logsumexp(IntArrayRef dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor logsumexp(DimnameList dim, bool keepdim=false) const; - #endif - Tensor matmul(const Tensor & other) const; - Tensor matrix_power(int64_t n) const; - std::tuple max(int64_t dim, bool keepdim=false) const; - Tensor max_values(IntArrayRef dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - std::tuple max(Dimname dim, bool keepdim=false) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor max_values(DimnameList dim, bool keepdim=false) const; - #endif - Tensor mean(c10::optional dtype=c10::nullopt) const; - Tensor mean(IntArrayRef dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor mean(DimnameList dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const; - #endif - std::tuple median(int64_t dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - std::tuple median(Dimname dim, bool keepdim=false) const; - #endif - std::tuple min(int64_t dim, bool keepdim=false) const; - Tensor min_values(IntArrayRef dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - std::tuple min(Dimname dim, bool keepdim=false) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor min_values(DimnameList dim, bool keepdim=false) const; - #endif - Tensor mm(const Tensor & mat2) const; - std::tuple mode(int64_t dim=-1, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - std::tuple mode(Dimname dim, bool keepdim=false) const; - #endif - Tensor mul(const Tensor & other) const; - Tensor & mul_(const Tensor & other) const; - Tensor mul(Scalar other) const; - Tensor & mul_(Scalar other) const; - Tensor mv(const Tensor & vec) const; - Tensor mvlgamma(int64_t p) const; - Tensor & mvlgamma_(int64_t p) const; - Tensor narrow_copy(int64_t dim, int64_t start, int64_t length) const; - Tensor narrow(int64_t dim, int64_t start, int64_t length) const; - Tensor permute(IntArrayRef dims) const; - Tensor numpy_T() const; - bool is_pinned() const; - Tensor pin_memory() const; - Tensor pinverse(double rcond=1e-15) const; - Tensor reciprocal() const; - Tensor & reciprocal_() const; - Tensor neg() const; - Tensor & neg_() const; - Tensor repeat(IntArrayRef repeats) const; - Tensor repeat_interleave(const Tensor & repeats, c10::optional dim=c10::nullopt) const; - Tensor repeat_interleave(int64_t repeats, c10::optional dim=c10::nullopt) const; - Tensor reshape(IntArrayRef shape) const; - Tensor reshape_as(const Tensor & other) const; - Tensor round() const; - Tensor & round_() const; - Tensor relu() const; - Tensor & relu_() const; - Tensor prelu(const Tensor & weight) const; - std::tuple prelu_backward(const Tensor & grad_output, const Tensor & weight) const; - Tensor hardshrink(Scalar lambd=0.5) const; - Tensor 
hardshrink_backward(const Tensor & grad_out, Scalar lambd) const; - Tensor rsqrt() const; - Tensor & rsqrt_() const; - #ifdef BUILD_NAMEDTENSOR - Tensor select(Dimname dim, int64_t index) const; - #endif - Tensor select(int64_t dim, int64_t index) const; - Tensor sigmoid() const; - Tensor & sigmoid_() const; - Tensor sin() const; - Tensor & sin_() const; - Tensor sinh() const; - Tensor & sinh_() const; - Tensor detach() const; - Tensor & detach_() const; - int64_t size(int64_t dim) const; - #ifdef BUILD_NAMEDTENSOR - int64_t size(Dimname dim) const; - #endif - Tensor slice(int64_t dim=0, int64_t start=0, int64_t end=9223372036854775807, int64_t step=1) const; - std::tuple slogdet() const; - Tensor smm(const Tensor & mat2) const; - Tensor softmax(int64_t dim, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor softmax(Dimname dim, c10::optional dtype=c10::nullopt) const; - #endif - std::vector split(int64_t split_size, int64_t dim=0) const; - std::vector split_with_sizes(IntArrayRef split_sizes, int64_t dim=0) const; - Tensor squeeze() const; - Tensor squeeze(int64_t dim) const; - #ifdef BUILD_NAMEDTENSOR - Tensor squeeze(Dimname dim) const; - #endif - Tensor & squeeze_() const; - Tensor & squeeze_(int64_t dim) const; - #ifdef BUILD_NAMEDTENSOR - Tensor & squeeze_(Dimname dim) const; - #endif - Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const; - Tensor stft(int64_t n_fft, c10::optional hop_length=c10::nullopt, c10::optional win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const; - int64_t stride(int64_t dim) const; - #ifdef BUILD_NAMEDTENSOR - int64_t stride(Dimname dim) const; - #endif - Tensor sum(c10::optional dtype=c10::nullopt) const; - Tensor sum(IntArrayRef dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor sum(DimnameList dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const; - #endif - Tensor sum_to_size(IntArrayRef size) const; - Tensor sqrt() const; - Tensor & sqrt_() const; - Tensor std(bool unbiased=true) const; - Tensor std(IntArrayRef dim, bool unbiased=true, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor std(DimnameList dim, bool unbiased=true, bool keepdim=false) const; - #endif - Tensor prod(c10::optional dtype=c10::nullopt) const; - Tensor prod(int64_t dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const; - #ifdef BUILD_NAMEDTENSOR - Tensor prod(Dimname dim, bool keepdim=false, c10::optional dtype=c10::nullopt) const; - #endif - Tensor t() const; - Tensor & t_() const; - Tensor tan() const; - Tensor & tan_() const; - Tensor tanh() const; - Tensor & tanh_() const; - Tensor transpose(int64_t dim0, int64_t dim1) const; - #ifdef BUILD_NAMEDTENSOR - Tensor transpose(Dimname dim0, Dimname dim1) const; - #endif - Tensor & transpose_(int64_t dim0, int64_t dim1) const; - Tensor flip(IntArrayRef dims) const; - Tensor roll(IntArrayRef shifts, IntArrayRef dims={}) const; - Tensor rot90(int64_t k=1, IntArrayRef dims={0,1}) const; - Tensor trunc() const; - Tensor & trunc_() const; - Tensor type_as(const Tensor & other) const; - Tensor unsqueeze(int64_t dim) const; - Tensor & unsqueeze_(int64_t dim) const; - Tensor var(bool unbiased=true) const; - Tensor var(IntArrayRef dim, bool unbiased=true, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor var(DimnameList dim, bool unbiased=true, bool keepdim=false) const; - #endif - Tensor view_as(const Tensor & other) const; - 
Tensor where(const Tensor & condition, const Tensor & other) const; - Tensor norm(c10::optional p, ScalarType dtype) const; - Tensor norm(Scalar p=2) const; - Tensor norm(c10::optional p, IntArrayRef dim, bool keepdim, ScalarType dtype) const; - Tensor norm(c10::optional p, IntArrayRef dim, bool keepdim=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor norm(c10::optional p, DimnameList dim, bool keepdim, ScalarType dtype) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor norm(c10::optional p, DimnameList dim, bool keepdim=false) const; - #endif - Tensor clone() const; - Tensor & resize_as_(const Tensor & the_template) const; - Tensor pow(Scalar exponent) const; - Tensor & zero_() const; - Tensor sub(const Tensor & other, Scalar alpha=1) const; - Tensor & sub_(const Tensor & other, Scalar alpha=1) const; - Tensor sub(Scalar other, Scalar alpha=1) const; - Tensor & sub_(Scalar other, Scalar alpha=1) const; - Tensor addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const; - Tensor & addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const; - Tensor & sparse_resize_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const; - Tensor & sparse_resize_and_clear_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const; - Tensor sparse_mask(const Tensor & mask) const; - Tensor to_dense() const; - int64_t sparse_dim() const; - int64_t _dimI() const; - int64_t dense_dim() const; - int64_t _dimV() const; - int64_t _nnz() const; - Tensor coalesce() const; - bool is_coalesced() const; - Tensor _indices() const; - Tensor _values() const; - Tensor & _coalesced_(bool coalesced) const; - Tensor indices() const; - Tensor values() const; - int64_t numel() const; - std::vector unbind(int64_t dim=0) const; - #ifdef BUILD_NAMEDTENSOR - std::vector unbind(Dimname dim) const; - #endif - Tensor to_sparse(int64_t sparse_dim) const; - Tensor to_sparse() const; - Tensor to_mkldnn() const; - Tensor dequantize() const; - double q_scale() const; - int64_t q_zero_point() const; - Tensor q_per_channel_scales() const; - Tensor q_per_channel_zero_points() const; - int64_t q_per_channel_axis() const; - Tensor int_repr() const; - QScheme qscheme() const; - Tensor to(const TensorOptions & options, bool non_blocking=false, bool copy=false) const; - Tensor to(Device device, ScalarType dtype, bool non_blocking=false, bool copy=false) const; - Tensor to(ScalarType dtype, bool non_blocking=false, bool copy=false) const; - Tensor to(const Tensor & other, bool non_blocking=false, bool copy=false) const; - Scalar item() const; - Tensor & set_(Storage source) const; - Tensor & set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride={}) const; - Tensor & set_(const Tensor & source) const; - Tensor & set_() const; - Tensor & set_quantizer_(ConstQuantizerPtr quantizer) const; - bool is_set_to(const Tensor & tensor) const; - Tensor & masked_fill_(const Tensor & mask, Scalar value) const; - Tensor masked_fill(const Tensor & mask, Scalar value) const; - Tensor & masked_fill_(const Tensor & mask, const Tensor & value) const; - Tensor masked_fill(const Tensor & mask, const Tensor & value) const; - Tensor & masked_scatter_(const Tensor & mask, const Tensor & source) const; - Tensor masked_scatter(const Tensor & mask, const Tensor & source) const; - Tensor view(IntArrayRef size) const; - Tensor & put_(const Tensor & index, const Tensor & source, bool accumulate=false) const; - Tensor & index_add_(int64_t dim, const Tensor & index, const Tensor & 
source) const; - Tensor index_add(int64_t dim, const Tensor & index, const Tensor & source) const; - #ifdef BUILD_NAMEDTENSOR - Tensor index_add(Dimname dim, const Tensor & index, const Tensor & source) const; - #endif - Tensor & index_fill_(int64_t dim, const Tensor & index, Scalar value) const; - Tensor index_fill(int64_t dim, const Tensor & index, Scalar value) const; - Tensor & index_fill_(int64_t dim, const Tensor & index, const Tensor & value) const; - Tensor index_fill(int64_t dim, const Tensor & index, const Tensor & value) const; - #ifdef BUILD_NAMEDTENSOR - Tensor & index_fill_(Dimname dim, const Tensor & index, Scalar value) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor & index_fill_(Dimname dim, const Tensor & index, const Tensor & value) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor index_fill(Dimname dim, const Tensor & index, Scalar value) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor index_fill(Dimname dim, const Tensor & index, const Tensor & value) const; - #endif - Tensor & scatter_(int64_t dim, const Tensor & index, const Tensor & src) const; - Tensor scatter(int64_t dim, const Tensor & index, const Tensor & src) const; - Tensor & scatter_(int64_t dim, const Tensor & index, Scalar value) const; - Tensor scatter(int64_t dim, const Tensor & index, Scalar value) const; - #ifdef BUILD_NAMEDTENSOR - Tensor scatter(Dimname dim, const Tensor & index, const Tensor & src) const; - #endif - #ifdef BUILD_NAMEDTENSOR - Tensor scatter(Dimname dim, const Tensor & index, Scalar value) const; - #endif - Tensor & scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) const; - Tensor scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const; - #ifdef BUILD_NAMEDTENSOR - Tensor scatter_add(Dimname dim, const Tensor & index, const Tensor & src) const; - #endif - Tensor & lt_(Scalar other) const; - Tensor & lt_(const Tensor & other) const; - Tensor & gt_(Scalar other) const; - Tensor & gt_(const Tensor & other) const; - Tensor & le_(Scalar other) const; - Tensor & le_(const Tensor & other) const; - Tensor & ge_(Scalar other) const; - Tensor & ge_(const Tensor & other) const; - Tensor & eq_(Scalar other) const; - Tensor & eq_(const Tensor & other) const; - Tensor & ne_(Scalar other) const; - Tensor & ne_(const Tensor & other) const; - Tensor __and__(Scalar other) const; - Tensor __and__(const Tensor & other) const; - Tensor & __iand__(Scalar other) const; - Tensor & __iand__(const Tensor & other) const; - Tensor __or__(Scalar other) const; - Tensor __or__(const Tensor & other) const; - Tensor & __ior__(Scalar other) const; - Tensor & __ior__(const Tensor & other) const; - Tensor __xor__(Scalar other) const; - Tensor __xor__(const Tensor & other) const; - Tensor & __ixor__(Scalar other) const; - Tensor & __ixor__(const Tensor & other) const; - Tensor __lshift__(Scalar other) const; - Tensor __lshift__(const Tensor & other) const; - Tensor & __ilshift__(Scalar other) const; - Tensor & __ilshift__(const Tensor & other) const; - Tensor __rshift__(Scalar other) const; - Tensor __rshift__(const Tensor & other) const; - Tensor & __irshift__(Scalar other) const; - Tensor & __irshift__(const Tensor & other) const; - Tensor & lgamma_() const; - Tensor & atan2_(const Tensor & other) const; - Tensor & tril_(int64_t diagonal=0) const; - Tensor & triu_(int64_t diagonal=0) const; - Tensor & digamma_() const; - Tensor & polygamma_(int64_t n) const; - Tensor & renorm_(Scalar p, int64_t dim, Scalar maxnorm) const; - Tensor & pow_(Scalar exponent) const; - Tensor & 
pow_(const Tensor & exponent) const; - Tensor & lerp_(const Tensor & end, Scalar weight) const; - Tensor & lerp_(const Tensor & end, const Tensor & weight) const; - Tensor & fmod_(Scalar other) const; - Tensor & fmod_(const Tensor & other) const; - Tensor & remainder_(Scalar other) const; - Tensor & remainder_(const Tensor & other) const; - Tensor & addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const; - Tensor addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta=1, Scalar alpha=1) const; - Tensor & addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const; - Tensor & random_(int64_t from, int64_t to, Generator * generator=nullptr) const; - Tensor & random_(int64_t to, Generator * generator=nullptr) const; - Tensor & random_(Generator * generator=nullptr) const; - Tensor & uniform_(double from=0, double to=1, Generator * generator=nullptr) const; - Tensor & normal_(double mean=0, double std=1, Generator * generator=nullptr) const; - Tensor & cauchy_(double median=0, double sigma=1, Generator * generator=nullptr) const; - Tensor & log_normal_(double mean=1, double std=2, Generator * generator=nullptr) const; - Tensor & exponential_(double lambd=1, Generator * generator=nullptr) const; - Tensor & geometric_(double p, Generator * generator=nullptr) const; - Tensor diag(int64_t diagonal=0) const; - Tensor cross(const Tensor & other, c10::optional dim=c10::nullopt) const; - Tensor triu(int64_t diagonal=0) const; - Tensor tril(int64_t diagonal=0) const; - Tensor trace() const; - Tensor ne(Scalar other) const; - Tensor ne(const Tensor & other) const; - Tensor eq(Scalar other) const; - Tensor eq(const Tensor & other) const; - Tensor ge(Scalar other) const; - Tensor ge(const Tensor & other) const; - Tensor le(Scalar other) const; - Tensor le(const Tensor & other) const; - Tensor gt(Scalar other) const; - Tensor gt(const Tensor & other) const; - Tensor lt(Scalar other) const; - Tensor lt(const Tensor & other) const; - Tensor take(const Tensor & index) const; - Tensor index_select(int64_t dim, const Tensor & index) const; - #ifdef BUILD_NAMEDTENSOR - Tensor index_select(Dimname dim, const Tensor & index) const; - #endif - Tensor masked_select(const Tensor & mask) const; - Tensor nonzero() const; - std::vector nonzero_numpy() const; - Tensor gather(int64_t dim, const Tensor & index, bool sparse_grad=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor gather(Dimname dim, const Tensor & index, bool sparse_grad=false) const; - #endif - Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const; - Tensor & addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const; - Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const; - std::tuple lstsq(const Tensor & A) const; - std::tuple triangular_solve(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const; - std::tuple symeig(bool eigenvectors=false, bool upper=true) const; - std::tuple eig(bool eigenvectors=false) const; - std::tuple svd(bool some=true, bool compute_uv=true) const; - Tensor cholesky(bool upper=false) const; - Tensor cholesky_solve(const Tensor & input2, bool upper=false) const; - std::tuple solve(const Tensor & A) const; - Tensor cholesky_inverse(bool upper=false) const; - std::tuple qr(bool some=true) const; - std::tuple geqrf() const; - Tensor orgqr(const Tensor & input2) const; - Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool 
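// Illustrative sketch (not from the original header): the factorization
// methods above (svd, qr, symeig, ...) return std::tuple of Tensors, and the
// in-place RNG fillers (normal_, uniform_, ...) mutate and return *this.
// The function name and shapes are made up.
#include <ATen/ATen.h>
#include <tuple>

void linalg_example() {
  at::Tensor a = at::randn({4, 4});

  at::Tensor u, s, v;
  std::tie(u, s, v) = a.svd();            // std::tuple<Tensor, Tensor, Tensor>

  at::Tensor q, r;
  std::tie(q, r) = a.qr(/*some=*/true);   // std::tuple<Tensor, Tensor>

  at::Tensor noise = at::empty({4, 4}).normal_(/*mean=*/0.0, /*std=*/1.0);
}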
transpose=false) const; - Tensor lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const; - Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const; - Tensor lgamma() const; - Tensor digamma() const; - Tensor polygamma(int64_t n) const; - Tensor erfinv() const; - Tensor & erfinv_() const; - Tensor sign() const; - Tensor & sign_() const; - Tensor dist(const Tensor & other, Scalar p=2) const; - Tensor atan2(const Tensor & other) const; - Tensor lerp(const Tensor & end, Scalar weight) const; - Tensor lerp(const Tensor & end, const Tensor & weight) const; - Tensor histc(int64_t bins=100, Scalar min=0, Scalar max=0) const; - Tensor fmod(Scalar other) const; - Tensor fmod(const Tensor & other) const; - Tensor remainder(Scalar other) const; - Tensor remainder(const Tensor & other) const; - Tensor min(const Tensor & other) const; - Tensor min() const; - Tensor max(const Tensor & other) const; - Tensor max() const; - Tensor median() const; - std::tuple sort(int64_t dim=-1, bool descending=false) const; - #ifdef BUILD_NAMEDTENSOR - std::tuple sort(Dimname dim, bool descending=false) const; - #endif - Tensor argsort(int64_t dim=-1, bool descending=false) const; - #ifdef BUILD_NAMEDTENSOR - Tensor argsort(Dimname dim, bool descending=false) const; - #endif - std::tuple topk(int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) const; - Tensor all() const; - Tensor any() const; - Tensor renorm(Scalar p, int64_t dim, Scalar maxnorm) const; - Tensor unfold(int64_t dimension, int64_t size, int64_t step) const; - bool equal(const Tensor & other) const; - Tensor pow(const Tensor & exponent) const; - Tensor alias() const; - - // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want the - // at::kDouble and its friends to be TypeMeta's, but that hasn't happened yet. - // Before that change, we make this method to maintain BC for C++ usage like - // `x.to(y.dtype)`. - // TODO: remove following two after at::kDouble and its friends are TypeMeta's. - inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const { - return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy); - } - inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const { - return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy); - } - - template - auto m(F func, Args&&... params) const -> decltype(func(*this, std::forward(params)...)) { - return func(*this, std::forward(params)...); - } - -protected: - friend class ::caffe2::Tensor; - - void enforce_invariants(); - c10::intrusive_ptr impl_; -}; - -namespace detail { -// Helper creator for Tensor class which doesn't requires the users to pass -// in an intrusive_ptr instead it just converts the argument passed to -// requested intrusive_ptr type. -template -Tensor make_tensor(Args&&... 
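// Illustrative sketch (not from the original header): the BC comment above is
// about keeping x.to(y.dtype()) compiling even though dtype() returns a
// caffe2::TypeMeta rather than a ScalarType. Function name and tensors are
// made up.
#include <ATen/ATen.h>

void dtype_bc_example() {
  at::Tensor x = at::randn({2, 2});                 // default float
  at::Tensor y = at::zeros({2, 2}, at::kDouble);

  // Uses the BC overload: the TypeMeta is converted to a ScalarType internally.
  at::Tensor xd = x.to(y.dtype());

  // Equivalent spelling that goes through ScalarType explicitly.
  at::Tensor xd2 = x.to(c10::typeMetaToScalarType(y.dtype()));
}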
args) { - return Tensor(c10::make_intrusive(std::forward(args)...)); -} - -} // namespace detail - -static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { - return legacyExtractTypeId(t.type_set()); -} - -} // namespace at diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h deleted file mode 100644 index ae225c014ccdd..0000000000000 --- a/aten/src/ATen/core/TensorMethods.h +++ /dev/null @@ -1,6234 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef USE_STATIC_DISPATCH -#include -#include -#include -#include -#endif - -namespace at { - -struct Quantizer; -// This is temporary typedef to enable Quantizer in aten native function API -// we'll remove them when we are actually exposing Quantizer class -// to frontend -using ConstQuantizerPtr = const c10::intrusive_ptr&; - -inline Tensor Tensor::cpu() const { - return to(options().device(DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); -} - -// TODO: The Python version also accepts arguments -inline Tensor Tensor::cuda() const { - return to(options().device(DeviceType::CUDA), /*non_blocking*/ false, /*copy*/ false); -} - -inline Tensor Tensor::hip() const { - return to(options().device(DeviceType::HIP), /*non_blocking*/ false, /*copy*/ false); -} - -inline Tensor Tensor::toType(ScalarType t) const { - return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false); -} - -// TODO: Deprecate me -inline Tensor Tensor::toBackend(Backend b) const { - return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false); -} - -inline TensorOptions Tensor::options() const { - return TensorOptions().dtype(dtype()) - .device(device()) - .layout(layout()) - .is_variable(is_variable()); -} - -// all static inline to allow for inlining of the non-dynamic part of dispatch -inline void Tensor::backward(const Tensor & gradient, bool keep_graph, bool create_graph) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - TypeDefault::backward(const_cast(*this), gradient, keep_graph, create_graph); -#else - static auto table = globalATenDispatch().getOpTable("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void"); - return table->callUnboxed(const_cast(*this), gradient, keep_graph, create_graph); -#endif -} -inline void Tensor::set_data(const Tensor & new_data) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - TypeDefault::set_data(const_cast(*this), new_data); -#else - static auto table = globalATenDispatch().getOpTable("aten::set_data(Tensor(a!) 
self, Tensor new_data) -> void"); - return table->callUnboxed(const_cast(*this), new_data); -#endif -} -inline Tensor Tensor::data() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::data(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::data", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_leaf() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_leaf(const_cast(*this)); -#else - static auto table = globalATenDispatch().getOpTable("aten::is_leaf(Tensor self) -> bool"); - return table->callUnboxed(const_cast(*this)); -#endif -} -inline int64_t Tensor::output_nr() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::output_nr(const_cast(*this)); -#else - static auto table = globalATenDispatch().getOpTable("aten::output_nr(Tensor self) -> int"); - return table->callUnboxed(const_cast(*this)); -#endif -} -inline int64_t Tensor::_version() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::_version(const_cast(*this)); -#else - static auto table = globalATenDispatch().getOpTable("aten::_version(Tensor self) -> int"); - return table->callUnboxed(const_cast(*this)); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor & Tensor::rename_(c10::optional names) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::rename_(const_cast(*this), names); -#else - static auto table = globalATenDispatch().getOpTable("aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)"); - return table->callUnboxed>(const_cast(*this), names); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::rename(c10::optional names) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::rename(const_cast(*this), names); -#else - static auto table = globalATenDispatch().getOpTable("aten::rename(Tensor(a) self, Dimname[]? 
names) -> Tensor(a)"); - return table->callUnboxed>(const_cast(*this), names); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::align_to(DimnameList names) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::align_to(const_cast(*this), names); -#else - static auto table = globalATenDispatch().getOpTable("aten::align_to(Tensor(a) self, DimnameList names) -> Tensor(a)"); - return table->callUnboxed(const_cast(*this), names); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::align_as(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::align_as(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::align_as", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::refine_names(DimnameList names) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::refine_names(const_cast(*this), names); -#else - static auto table = globalATenDispatch().getOpTable("aten::refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)"); - return table->callUnboxed(const_cast(*this), names); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::unflatten(Dimname dim, IntArrayRef sizes, DimnameList names) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::unflatten(const_cast(*this), dim, sizes, names); -#else - static auto table = globalATenDispatch().getOpTable("aten::unflatten(Tensor self, Dimname dim, int[] sizes, DimnameList names) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, sizes, names); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::unflatten(int64_t dim, IntArrayRef sizes, DimnameList names) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::unflatten(const_cast(*this), dim, sizes, names); -#else - static auto table = globalATenDispatch().getOpTable("aten::unflatten(Tensor self, int dim, int[] sizes, DimnameList names) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, sizes, names); -#endif -} -#endif -inline Tensor Tensor::abs() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::abs(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::abs", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::abs_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::abs_(const_cast(*this)); - break; - default: - AT_ERROR("abs_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::abs_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::acos() const { 
-#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::acos(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::acos", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::acos_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::acos_(const_cast(*this)); - break; - default: - AT_ERROR("acos_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::acos_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::add(const_cast(*this), other, alpha); - break; - case Backend::SparseCPU: - return SparseCPUType::add(const_cast(*this), other, alpha); - break; - default: - AT_ERROR("add not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, alpha); -#endif -} -inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::add_(const_cast(*this), other, alpha); - break; - case Backend::SparseCPU: - return SparseCPUType::add_(const_cast(*this), other, alpha); - break; - default: - AT_ERROR("add_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, alpha); -#endif -} -inline Tensor Tensor::add(Scalar other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::add(const_cast(*this), other, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other, alpha); -#endif -} -inline Tensor & Tensor::add_(Scalar other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::add_(const_cast(*this), other, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other, alpha); 
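// Illustrative sketch (not from the original header): add()/add_() above take
// an alpha multiplier, so x.add(y, alpha) computes x + alpha * y. Function
// name and values are made up.
#include <ATen/ATen.h>

void add_alpha_example() {
  at::Tensor x = at::ones({3});
  at::Tensor y = at::ones({3});

  at::Tensor z = x.add(y, /*alpha=*/2);   // elementwise 1 + 2 * 1 == 3
  x.add_(5);                              // Scalar overload, in place: x += 5
}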
-#endif -} -inline Tensor Tensor::addmv(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::addmv(const_cast(*this), mat, vec, beta, alpha); - break; - default: - AT_ERROR("addmv not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmv", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec)), const_cast(*this), mat, vec, beta, alpha); -#endif -} -inline Tensor & Tensor::addmv_(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::addmv_(const_cast(*this), mat, vec, beta, alpha); - break; - default: - AT_ERROR("addmv_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmv_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec)), const_cast(*this), mat, vec, beta, alpha); -#endif -} -inline Tensor Tensor::addr(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::addr(const_cast(*this), vec1, vec2, beta, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addr", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2)), const_cast(*this), vec1, vec2, beta, alpha); -#endif -} -inline Tensor & Tensor::addr_(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::addr_(const_cast(*this), vec1, vec2, beta, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addr_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2)), const_cast(*this), vec1, vec2, beta, alpha); -#endif -} -inline Tensor Tensor::all(int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::all(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::all", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::all(Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::all(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline 
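// Illustrative sketch (not from the original header): addmv above computes
// beta * self + alpha * (mat @ vec). Function name and shapes are made up.
#include <ATen/ATen.h>

void addmv_example() {
  at::Tensor bias = at::zeros({3});
  at::Tensor mat  = at::randn({3, 4});
  at::Tensor vec  = at::randn({4});

  at::Tensor y  = bias.addmv(mat, vec);                            // beta = alpha = 1
  at::Tensor y2 = bias.addmv(mat, vec, /*beta=*/0.5, /*alpha=*/2);
}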
bool Tensor::allclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::allclose(const_cast(*this), other, rtol, atol, equal_nan); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::allclose", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, rtol, atol, equal_nan); -#endif -} -inline Tensor Tensor::any(int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::any(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::any", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::any(Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::any(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::argmax(c10::optional dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::argmax(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::argmax", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -inline Tensor Tensor::argmin(c10::optional dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::argmin(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::argmin", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::as_strided(const_cast(*this), size, stride, storage_offset); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::as_strided(const_cast(*this), size, stride, storage_offset); - break; - default: - AT_ERROR("as_strided not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::as_strided", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size, stride, storage_offset); -#endif -} -inline Tensor & Tensor::as_strided_(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { -#ifdef USE_STATIC_DISPATCH - 
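// Illustrative sketch (not from the original header): argmax/argmin above take
// an optional dim (reducing over the flattened tensor when omitted), and
// allclose compares elementwise within rtol/atol. Function name and values are
// made up.
#include <ATen/ATen.h>

void reduction_example() {
  at::Tensor t = at::randn({2, 3});

  at::Tensor flat_idx = t.argmax();           // index into the flattened tensor
  at::Tensor row_idx  = t.argmax(/*dim=*/1);  // one index per row

  bool close = t.allclose(t + 1e-9, /*rtol=*/1e-5, /*atol=*/1e-8);
}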
at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::as_strided_(const_cast(*this), size, stride, storage_offset); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::as_strided_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size, stride, storage_offset); -#endif -} -inline Tensor Tensor::asin() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::asin(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::asin", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::asin_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::asin_(const_cast(*this)); - break; - default: - AT_ERROR("asin_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::asin_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::atan() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::atan(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::atan_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::atan_(const_cast(*this)); - break; - default: - AT_ERROR("atan_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::baddbmm(const_cast(*this), batch1, batch2, beta, alpha); - break; - default: - AT_ERROR("baddbmm not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::baddbmm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2)), const_cast(*this), batch1, batch2, beta, alpha); -#endif -} -inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - 
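// Illustrative sketch (not from the original header): baddbmm above computes
// beta * self + alpha * bmm(batch1, batch2) over a batch of matrices. Function
// name and shapes are made up.
#include <ATen/ATen.h>

void baddbmm_example() {
  at::Tensor self   = at::zeros({8, 3, 5});
  at::Tensor batch1 = at::randn({8, 3, 4});
  at::Tensor batch2 = at::randn({8, 4, 5});

  at::Tensor out = self.baddbmm(batch1, batch2, /*beta=*/1, /*alpha=*/1);
  self.baddbmm_(batch1, batch2);   // in-place variant with the same defaults
}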
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::baddbmm_(const_cast(*this), batch1, batch2, beta, alpha); - break; - default: - AT_ERROR("baddbmm_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::baddbmm_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2)), const_cast(*this), batch1, batch2, beta, alpha); -#endif -} -inline Tensor Tensor::bernoulli(Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::bernoulli(const_cast(*this), generator); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bernoulli", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), generator); -#endif -} -inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::bernoulli_(const_cast(*this), p, generator); - break; - default: - AT_ERROR("bernoulli_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bernoulli_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, p)), const_cast(*this), p, generator); -#endif -} -inline Tensor & Tensor::bernoulli_(double p, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::bernoulli_(const_cast(*this), p, generator); - break; - default: - AT_ERROR("bernoulli_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bernoulli_", "float"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p, generator); -#endif -} -inline Tensor Tensor::bernoulli(double p, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::bernoulli(const_cast(*this), p, generator); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bernoulli", "p"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p, generator); -#endif -} -inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::bincount(const_cast(*this), weights, minlength); - break; - default: - AT_ERROR("bincount not implemented for ", at::toString(type_set())); - } -#else - static auto table = globalATenDispatch().getOpTable("aten::bincount(Tensor self, Tensor? 
weights=None, int minlength=0) -> Tensor"); - return table->callUnboxed(const_cast(*this), weights, minlength); -#endif -} -inline Tensor Tensor::bitwise_not() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::bitwise_not(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bitwise_not", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::bitwise_not_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::bitwise_not_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bitwise_not_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::logical_not() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logical_not(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_not", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::logical_not_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logical_not_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_not_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::logical_xor(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logical_xor(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_xor", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::logical_xor_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logical_xor_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_xor_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::bmm(const Tensor & mat2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::bmm(const_cast(*this), mat2); - break; - default: - AT_ERROR("bmm not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bmm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, 
mat2)), const_cast(*this), mat2); -#endif -} -inline Tensor Tensor::ceil() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ceil(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ceil", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::ceil_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ceil_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ceil_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::chunk(const_cast(*this), chunks, dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::chunk", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, int64_t>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), chunks, dim); -#endif -} -inline Tensor Tensor::clamp(c10::optional min, c10::optional max) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::clamp(const_cast(*this), min, max); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed, c10::optional>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), min, max); -#endif -} -inline Tensor & Tensor::clamp_(c10::optional min, c10::optional max) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::clamp_(const_cast(*this), min, max); - break; - default: - AT_ERROR("clamp_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, c10::optional>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), min, max); -#endif -} -inline Tensor Tensor::clamp_max(Scalar max) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::clamp_max(const_cast(*this), max); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_max", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), max); -#endif -} -inline Tensor & Tensor::clamp_max_(Scalar max) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::clamp_max_(const_cast(*this), max); - break; - default: - AT_ERROR("clamp_max_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = 
c10::Dispatcher::singleton().findSchema({"aten::clamp_max_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), max); -#endif -} -inline Tensor Tensor::clamp_min(Scalar min) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::clamp_min(const_cast(*this), min); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_min", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), min); -#endif -} -inline Tensor & Tensor::clamp_min_(Scalar min) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::clamp_min_(const_cast(*this), min); - break; - default: - AT_ERROR("clamp_min_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_min_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), min); -#endif -} -inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::contiguous(const_cast(*this), memory_format); -#else - static auto table = globalATenDispatch().getOpTable("aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor"); - return table->callUnboxed(const_cast(*this), memory_format); -#endif -} -inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::copy_(const_cast(*this), src, non_blocking); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::copy_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, src)), const_cast(*this), src, non_blocking); -#endif -} -inline Tensor Tensor::cos() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cos(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cos", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::cos_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::cos_(const_cast(*this)); - break; - default: - AT_ERROR("cos_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cos_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::cosh() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cosh(const_cast(*this)); -#else - static 
c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cosh", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::cosh_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::cosh_(const_cast(*this)); - break; - default: - AT_ERROR("cosh_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cosh_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cumsum(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::cumsum(Dimname dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cumsum(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#endif -inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cumprod(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::cumprod(Dimname dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cumprod(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#endif -inline Tensor Tensor::det() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::det(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::det", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::diag_embed(const_cast(*this), offset, dim1, dim2); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diag_embed", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), offset, dim1, dim2); -#endif -} -inline Tensor Tensor::diagflat(int64_t offset) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::diagflat(const_cast(*this), offset); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diagflat", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), offset); -#endif -} -inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::diagonal(const_cast(*this), offset, dim1, dim2); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diagonal", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), offset, dim1, dim2); -#endif -} -inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::fill_diagonal_(const_cast(*this), fill_value, wrap); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fill_diagonal_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), fill_value, wrap); -#endif -} -inline Tensor Tensor::div(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::div(const_cast(*this), other); - break; - case Backend::SparseCPU: - return SparseCPUType::div(const_cast(*this), other); - break; - default: - AT_ERROR("div not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::div_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case 
Backend::CPU: - return CPUType::div_(const_cast(*this), other); - break; - case Backend::SparseCPU: - return SparseCPUType::div_(const_cast(*this), other); - break; - default: - AT_ERROR("div_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::div(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::div(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::div_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::div_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::dot(const Tensor & tensor) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::dot(const_cast(*this), tensor); - break; - default: - AT_ERROR("dot not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dot", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor)), const_cast(*this), tensor); -#endif -} -inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::new_empty(const_cast(*this), size, options); -#else - static auto table = globalATenDispatch().getOpTable("aten::new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"); - return table->callUnboxed(const_cast(*this), size, options); -#endif -} -inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::new_full(const_cast(*this), size, fill_value, options); -#else - static auto table = globalATenDispatch().getOpTable("aten::new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor"); - return table->callUnboxed(const_cast(*this), size, fill_value, options); -#endif -} -inline Tensor & Tensor::resize_(IntArrayRef size) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::resize_(const_cast(*this), size); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::resize_(const_cast(*this), size); - break; - default: - AT_ERROR("resize_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::resize_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size); -#endif -} -inline Tensor Tensor::erf() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::erf(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erf", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::erf_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::erf_(const_cast(*this)); - break; - default: - AT_ERROR("erf_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erf_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::erfc() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::erfc(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfc", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::erfc_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::erfc_(const_cast(*this)); - break; - default: - AT_ERROR("erfc_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfc_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::exp() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::exp(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::exp", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::exp_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - 
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::exp_(const_cast(*this)); - break; - default: - AT_ERROR("exp_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::exp_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::expm1() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::expm1(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expm1", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::expm1_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::expm1_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expm1_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::expand(IntArrayRef size, bool implicit) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::expand(const_cast(*this), size, implicit); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expand", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size, implicit); -#endif -} -inline Tensor Tensor::expand_as(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::expand_as(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expand_as", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::flatten(const_cast(*this), start_dim, end_dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::flatten", "using_ints"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), start_dim, end_dim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, Dimname out_dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::flatten.named_out_dim(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor"); - return table->callUnboxed(const_cast(*this), start_dim, end_dim, out_dim); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::flatten(Dimname start_dim, Dimname end_dim, Dimname out_dim) const { -#ifdef 
USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::flatten.using_names(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor"); - return table->callUnboxed(const_cast(*this), start_dim, end_dim, out_dim); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::flatten(DimnameList dims, Dimname out_dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::flatten(const_cast(*this), dims, out_dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::flatten.DimnameList(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor"); - return table->callUnboxed(const_cast(*this), dims, out_dim); -#endif -} -#endif -inline Tensor & Tensor::fill_(Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::fill_(const_cast(*this), value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fill_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), value); -#endif -} -inline Tensor & Tensor::fill_(const Tensor & value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::fill_(const_cast(*this), value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fill_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, value)), const_cast(*this), value); -#endif -} -inline Tensor Tensor::floor() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::floor(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::floor", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::floor_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::floor_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::floor_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::frac() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::frac(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::frac", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::frac_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::frac_(const_cast(*this)); - break; - default: - AT_ERROR("frac_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = 
c10::Dispatcher::singleton().findSchema({"aten::frac_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::ger(const Tensor & vec2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::ger(const_cast(*this), vec2); - break; - default: - AT_ERROR("ger not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ger", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec2)), const_cast(*this), vec2); -#endif -} -inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::fft(const_cast(*this), signal_ndim, normalized); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fft", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), signal_ndim, normalized); -#endif -} -inline Tensor Tensor::ifft(int64_t signal_ndim, bool normalized) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ifft(const_cast(*this), signal_ndim, normalized); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ifft", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), signal_ndim, normalized); -#endif -} -inline Tensor Tensor::rfft(int64_t signal_ndim, bool normalized, bool onesided) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::rfft(const_cast(*this), signal_ndim, normalized, onesided); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rfft", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), signal_ndim, normalized, onesided); -#endif -} -inline Tensor Tensor::irfft(int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::irfft(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::irfft", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); -#endif -} -inline Tensor Tensor::index(TensorList indices) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index(const_cast(*this), indices); -#else - static auto table = globalATenDispatch().getOpTable("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"); - return table->callUnboxed(const_cast(*this), indices); -#endif -} -inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) 
const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_copy_(const_cast(*this), dim, index, source); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_copy_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source)), const_cast(*this), dim, index, source); -#endif -} -inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_copy(const_cast(*this), dim, index, source); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_copy", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source)), const_cast(*this), dim, index, source); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor & Tensor::index_copy_(Dimname dim, const Tensor & index, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_copy_(const_cast(*this), dim, index, source); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)"); - return table->callUnboxed(const_cast(*this), dim, index, source); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::index_copy(Dimname dim, const Tensor & index, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_copy(const_cast(*this), dim, index, source); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, source); -#endif -} -#endif -inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_put_(const_cast(*this), indices, values, accumulate); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_put_(Tensor(a!) 
self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"); - return table->callUnboxed(const_cast(*this), indices, values, accumulate); -#endif -} -inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_put(const_cast(*this), indices, values, accumulate); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), indices, values, accumulate); -#endif -} -inline Tensor Tensor::inverse() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::inverse(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::inverse", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::isclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::isclose(const_cast(*this), other, rtol, atol, equal_nan); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::isclose", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, rtol, atol, equal_nan); -#endif -} -inline bool Tensor::is_distributed() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_distributed(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_distributed", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_floating_point() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_floating_point(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_floating_point", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_complex() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_complex(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_complex", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_nonzero() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_nonzero(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_nonzero", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_same_size(const 
Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_same_size(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_same_size", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline bool Tensor::is_signed() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_signed(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_signed", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::kthvalue(const_cast(*this), k, dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::kthvalue", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, int64_t, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), k, dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline std::tuple Tensor::kthvalue(int64_t k, Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::kthvalue(const_cast(*this), k, dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->callUnboxed, const Tensor &, int64_t, Dimname, bool>(const_cast(*this), k, dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::log() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::log_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::log10() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log10(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log10", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::log10_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return 
CPUType::log10_(const_cast(*this)); - break; - default: - AT_ERROR("log10_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log10_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::log1p() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log1p(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log1p", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::log1p_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::log1p_(const_cast(*this)); - break; - case Backend::SparseCPU: - return SparseCPUType::log1p_(const_cast(*this)); - break; - default: - AT_ERROR("log1p_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log1p_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::log2() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log2(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log2", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::log2_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::log2_(const_cast(*this)); - break; - default: - AT_ERROR("log2_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log2_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::logdet() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logdet(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logdet", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log_softmax(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::log_softmax(Dimname dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::log_softmax(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#endif -inline Tensor Tensor::logsumexp(IntArrayRef dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logsumexp", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::logsumexp(DimnameList dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::matmul(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::matmul(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::matmul", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::matrix_power(int64_t n) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::matrix_power(const_cast(*this), n); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::matrix_power", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), n); -#endif -} -inline std::tuple Tensor::max(int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::max(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -inline Tensor Tensor::max_values(IntArrayRef dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::max_values(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max_values", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR 
-inline std::tuple Tensor::max(Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::max(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->callUnboxed, const Tensor &, Dimname, bool>(const_cast(*this), dim, keepdim); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::max_values(DimnameList dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::max_values(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::max_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::mean(c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mean(const_cast(*this), dtype); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::mean(const_cast(*this), dtype); - break; - default: - AT_ERROR("mean not implemented for ", at::toString(type_set())); - } -#else - static auto table = globalATenDispatch().getOpTable("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dtype); -#endif -} -inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mean(const_cast(*this), dim, keepdim, dtype); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::mean(const_cast(*this), dim, keepdim, dtype); - break; - default: - AT_ERROR("mean not implemented for ", at::toString(type_set())); - } -#else - static auto table = globalATenDispatch().getOpTable("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, keepdim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::mean(DimnameList dim, bool keepdim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mean(const_cast(*this), dim, keepdim, dtype); - break; - default: - AT_ERROR("mean not implemented for ", at::toString(type_set())); - } -#else - static auto table = globalATenDispatch().getOpTable("aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, keepdim, dtype); -#endif -} -#endif -inline std::tuple Tensor::median(int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::median(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::median", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline std::tuple Tensor::median(Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::median(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->callUnboxed, const Tensor &, Dimname, bool>(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline std::tuple Tensor::min(int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::min(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -inline Tensor Tensor::min_values(IntArrayRef dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::min_values(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min_values", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline std::tuple Tensor::min(Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::min(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->callUnboxed, const Tensor &, Dimname, bool>(const_cast(*this), dim, keepdim); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::min_values(DimnameList dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::min_values(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::min_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::mm(const Tensor & mat2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mm(const_cast(*this), mat2); - break; - case Backend::SparseCPU: - return SparseCPUType::mm(const_cast(*this), mat2); - break; - default: - 
AT_ERROR("mm not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat2)), const_cast(*this), mat2); -#endif -} -inline std::tuple Tensor::mode(int64_t dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::mode(const_cast(*this), dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mode", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline std::tuple Tensor::mode(Dimname dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::mode(const_cast(*this), dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->callUnboxed, const Tensor &, Dimname, bool>(const_cast(*this), dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::mul(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mul(const_cast(*this), other); - break; - case Backend::SparseCPU: - return SparseCPUType::mul(const_cast(*this), other); - break; - default: - AT_ERROR("mul not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::mul_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mul_(const_cast(*this), other); - break; - case Backend::SparseCPU: - return SparseCPUType::mul_(const_cast(*this), other); - break; - default: - AT_ERROR("mul_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::mul(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::mul(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::mul_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::mul_(const_cast(*this), other); -#else - static c10::OperatorHandle 
op = c10::Dispatcher::singleton().findSchema({"aten::mul_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::mv(const Tensor & vec) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::mv(const_cast(*this), vec); - break; - default: - AT_ERROR("mv not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mv", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec)), const_cast(*this), vec); -#endif -} -inline Tensor Tensor::mvlgamma(int64_t p) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::mvlgamma(const_cast(*this), p); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mvlgamma", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p); -#endif -} -inline Tensor & Tensor::mvlgamma_(int64_t p) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::mvlgamma_(const_cast(*this), p); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mvlgamma_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p); -#endif -} -inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::narrow_copy(const_cast(*this), dim, start, length); - break; - case Backend::SparseCPU: - return SparseCPUType::narrow_copy(const_cast(*this), dim, start, length); - break; - default: - AT_ERROR("narrow_copy not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::narrow_copy", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, start, length); -#endif -} -inline Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::narrow(const_cast(*this), dim, start, length); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::narrow", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, start, length); -#endif -} -inline Tensor Tensor::permute(IntArrayRef dims) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::permute(const_cast(*this), dims); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::permute", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, 
impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dims); -#endif -} -inline Tensor Tensor::numpy_T() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::numpy_T(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::numpy_T", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_pinned() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::is_pinned(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_pinned", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::pin_memory() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::pin_memory(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pin_memory", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::pinverse(double rcond) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::pinverse(const_cast(*this), rcond); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pinverse", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), rcond); -#endif -} -inline Tensor Tensor::reciprocal() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::reciprocal(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reciprocal", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::reciprocal_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::reciprocal_(const_cast(*this)); - break; - default: - AT_ERROR("reciprocal_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reciprocal_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::neg() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::neg(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::neg", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::neg_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode 
_var_guard(true); - return TypeDefault::neg_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::neg_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::repeat(IntArrayRef repeats) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::repeat(const_cast(*this), repeats); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::repeat", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), repeats); -#endif -} -inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::repeat_interleave(const_cast(*this), repeats, dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::repeat_interleave", "self_Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, repeats)), const_cast(*this), repeats, dim); -#endif -} -inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::repeat_interleave(const_cast(*this), repeats, dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::repeat_interleave", "self_int"}).value(); - return c10::Dispatcher::singleton().callUnboxed>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), repeats, dim); -#endif -} -inline Tensor Tensor::reshape(IntArrayRef shape) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::reshape(const_cast(*this), shape); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reshape", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), shape); -#endif -} -inline Tensor Tensor::reshape_as(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::reshape_as(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reshape_as", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::round() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::round(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::round", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::round_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::round_(const_cast(*this)); -#else - static c10::OperatorHandle op = 
c10::Dispatcher::singleton().findSchema({"aten::round_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::relu() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::relu(const_cast(*this)); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::relu(const_cast(*this)); - break; - default: - AT_ERROR("relu not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::relu", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::relu_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::relu_(const_cast(*this)); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::relu_(const_cast(*this)); - break; - default: - AT_ERROR("relu_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::relu_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::prelu(const Tensor & weight) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::prelu(const_cast(*this), weight); - break; - default: - AT_ERROR("prelu not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::prelu", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, weight)), const_cast(*this), weight); -#endif -} -inline std::tuple Tensor::prelu_backward(const Tensor & grad_output, const Tensor & weight) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::prelu_backward(grad_output, const_cast(*this), weight); - break; - default: - AT_ERROR("prelu_backward not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::prelu_backward", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, const Tensor &, const Tensor &>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(grad_output, *this, weight)), grad_output, const_cast(*this), weight); -#endif -} -inline Tensor Tensor::hardshrink(Scalar lambd) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::hardshrink(const_cast(*this), lambd); - break; - default: - AT_ERROR("hardshrink not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = 
c10::Dispatcher::singleton().findSchema({"aten::hardshrink", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), lambd); -#endif -} -inline Tensor Tensor::hardshrink_backward(const Tensor & grad_out, Scalar lambd) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::hardshrink_backward(grad_out, const_cast(*this), lambd); - break; - default: - AT_ERROR("hardshrink_backward not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::hardshrink_backward", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(grad_out, *this)), grad_out, const_cast(*this), lambd); -#endif -} -inline Tensor Tensor::rsqrt() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::rsqrt(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rsqrt", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::rsqrt_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::rsqrt_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rsqrt_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::select(Dimname dim, int64_t index) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::select(const_cast(*this), dim, index); -#else - static auto table = globalATenDispatch().getOpTable("aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)"); - return table->callUnboxed(const_cast(*this), dim, index); -#endif -} -#endif -inline Tensor Tensor::select(int64_t dim, int64_t index) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::select(const_cast(*this), dim, index); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::select", "int"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, index); -#endif -} -inline Tensor Tensor::sigmoid() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sigmoid(const_cast(*this)); - break; - default: - AT_ERROR("sigmoid not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sigmoid", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::sigmoid_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - 
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sigmoid_(const_cast(*this)); - break; - default: - AT_ERROR("sigmoid_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sigmoid_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::sin() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sin(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sin", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::sin_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sin_(const_cast(*this)); - break; - default: - AT_ERROR("sin_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sin_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::sinh() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sinh(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sinh", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::sinh_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sinh_(const_cast(*this)); - break; - default: - AT_ERROR("sinh_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sinh_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::detach() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::detach(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::detach", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::detach_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::detach_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::detach_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::size(int64_t dim) const { -#ifdef 
USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::size(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::size", "int"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline int64_t Tensor::size(Dimname dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::size(const_cast(*this), dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::size.Dimname(Tensor self, Dimname dim) -> int"); - return table->callUnboxed(const_cast(*this), dim); -#endif -} -#endif -inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t step) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::slice(const_cast(*this), dim, start, end, step); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::slice", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, start, end, step); -#endif -} -inline std::tuple Tensor::slogdet() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::slogdet(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::slogdet", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::smm(const Tensor & mat2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::smm(const_cast(*this), mat2); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::smm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat2)), const_cast(*this), mat2); -#endif -} -inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::softmax(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::softmax(Dimname dim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::softmax(const_cast(*this), dim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, dtype); -#endif -} -#endif -inline std::vector Tensor::split(int64_t split_size, int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::split(const_cast(*this), split_size, dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::split", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, int64_t>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), split_size, dim); -#endif -} -inline std::vector Tensor::split_with_sizes(IntArrayRef split_sizes, int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::split_with_sizes(const_cast(*this), split_sizes, dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::split_with_sizes", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, IntArrayRef, int64_t>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), split_sizes, dim); -#endif -} -inline Tensor Tensor::squeeze() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::squeeze(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::squeeze(int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::squeeze(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::squeeze(Dimname dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::squeeze(const_cast(*this), dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)"); - return table->callUnboxed(const_cast(*this), dim); -#endif -} -#endif -inline Tensor & Tensor::squeeze_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::squeeze_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::squeeze_(int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::squeeze_(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze_", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor & 
Tensor::squeeze_(Dimname dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::squeeze_(const_cast(*this), dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)"); - return table->callUnboxed(const_cast(*this), dim); -#endif -} -#endif -inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sspaddmm(const_cast(*this), mat1, mat2, beta, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sspaddmm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2)), const_cast(*this), mat1, mat2, beta, alpha); -#endif -} -inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10::optional win_length, const Tensor & window, bool normalized, bool onesided) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::stft(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); -#else - static auto table = globalATenDispatch().getOpTable("aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool onesided=True) -> Tensor"); - return table->callUnboxed, c10::optional, const Tensor &, bool, bool>(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); -#endif -} -inline int64_t Tensor::stride(int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::stride(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::stride", "int"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline int64_t Tensor::stride(Dimname dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::stride(const_cast(*this), dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::stride.Dimname(Tensor self, Dimname dim) -> int"); - return table->callUnboxed(const_cast(*this), dim); -#endif -} -#endif -inline Tensor Tensor::sum(c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sum(const_cast(*this), dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dtype); -#endif -} -inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sum(const_cast(*this), dim, keepdim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, keepdim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::sum(DimnameList dim, bool keepdim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sum(const_cast(*this), dim, keepdim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, keepdim, dtype); -#endif -} -#endif -inline Tensor Tensor::sum_to_size(IntArrayRef size) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sum_to_size(const_cast(*this), size); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sum_to_size", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size); -#endif -} -inline Tensor Tensor::sqrt() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sqrt(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sqrt", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::sqrt_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sqrt_(const_cast(*this)); - break; - default: - AT_ERROR("sqrt_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sqrt_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::std(bool unbiased) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::std(const_cast(*this), unbiased); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::std", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), unbiased); -#endif -} -inline Tensor Tensor::std(IntArrayRef dim, bool unbiased, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::std", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, unbiased, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::std(DimnameList dim, bool unbiased, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool 
keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, unbiased, keepdim); -#endif -} -#endif -inline Tensor Tensor::prod(c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::prod(const_cast(*this), dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dtype); -#endif -} -inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, keepdim, dtype); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::prod(Dimname dim, bool keepdim, c10::optional dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->callUnboxed>(const_cast(*this), dim, keepdim, dtype); -#endif -} -#endif -inline Tensor Tensor::t() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::t(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::t", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::t_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::t_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::t_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::tan() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::tan(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tan", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::tan_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::tan_(const_cast(*this)); - break; - default: - AT_ERROR("tan_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tan_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::tanh() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::tanh(const_cast(*this)); -#else 
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tanh", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::tanh_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::tanh_(const_cast(*this)); - break; - default: - AT_ERROR("tanh_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tanh_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::transpose(const_cast(*this), dim0, dim1); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::transpose", "int"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim0, dim1); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::transpose(Dimname dim0, Dimname dim1) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::transpose(const_cast(*this), dim0, dim1); -#else - static auto table = globalATenDispatch().getOpTable("aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)"); - return table->callUnboxed(const_cast(*this), dim0, dim1); -#endif -} -#endif -inline Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::transpose_(const_cast(*this), dim0, dim1); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::transpose_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim0, dim1); -#endif -} -inline Tensor Tensor::flip(IntArrayRef dims) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::flip(const_cast(*this), dims); - break; - default: - AT_ERROR("flip not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::flip", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dims); -#endif -} -inline Tensor Tensor::roll(IntArrayRef shifts, IntArrayRef dims) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::roll(const_cast(*this), shifts, dims); - break; - default: - AT_ERROR("roll not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::roll", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, 
impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), shifts, dims); -#endif -} -inline Tensor Tensor::rot90(int64_t k, IntArrayRef dims) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::rot90(const_cast(*this), k, dims); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rot90", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), k, dims); -#endif -} -inline Tensor Tensor::trunc() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::trunc(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trunc", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::trunc_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::trunc_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trunc_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::type_as(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::type_as(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::type_as", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::unsqueeze(int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::unsqueeze(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unsqueeze", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -inline Tensor & Tensor::unsqueeze_(int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::unsqueeze_(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unsqueeze_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -inline Tensor Tensor::var(bool unbiased) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::var(const_cast(*this), unbiased); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::var", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), unbiased); -#endif -} -inline Tensor Tensor::var(IntArrayRef dim, bool unbiased, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return 
TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::var", "dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, unbiased, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::var(DimnameList dim, bool unbiased, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, unbiased, keepdim); -#endif -} -#endif -inline Tensor Tensor::view_as(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::view_as(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::view_as", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::where(condition, const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::where", "self"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(condition, *this, other)), condition, const_cast(*this), other); -#endif -} -inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::norm(const_cast(*this), p, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor"); - return table->callUnboxed, ScalarType>(const_cast(*this), p, dtype); -#endif -} -inline Tensor Tensor::norm(Scalar p) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::norm(const_cast(*this), p); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::norm", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p); -#endif -} -inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim, ScalarType dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? 
p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); - return table->callUnboxed, IntArrayRef, bool, ScalarType>(const_cast(*this), p, dim, keepdim, dtype); -#endif -} -inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::norm(const_cast(*this), p, dim, keepdim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::norm", "ScalarOpt_dim"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, IntArrayRef, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p, dim, keepdim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdim, ScalarType dtype) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); -#else - static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); - return table->callUnboxed, DimnameList, bool, ScalarType>(const_cast(*this), p, dim, keepdim, dtype); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::norm(const_cast(*this), p, dim, keepdim); -#else - static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->callUnboxed, DimnameList, bool>(const_cast(*this), p, dim, keepdim); -#endif -} -#endif -inline Tensor Tensor::clone() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::clone(const_cast(*this)); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::clone(const_cast(*this)); - break; - case Backend::SparseCPU: - return SparseCPUType::clone(const_cast(*this)); - break; - default: - AT_ERROR("clone not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clone", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::resize_as_(const_cast(*this), the_template); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::resize_as_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, the_template)), const_cast(*this), the_template); -#endif -} -inline Tensor Tensor::pow(Scalar exponent) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::pow(const_cast(*this), exponent); - break; - case Backend::SparseCPU: - return SparseCPUType::pow(const_cast(*this), exponent); - break; - default: - AT_ERROR("pow not 
implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow", "Tensor_Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), exponent); -#endif -} -inline Tensor & Tensor::zero_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::zero_(const_cast(*this)); - break; - case Backend::SparseCPU: - return SparseCPUType::zero_(const_cast(*this)); - break; - default: - AT_ERROR("zero_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::zero_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::sub(const Tensor & other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sub(const_cast(*this), other, alpha); - break; - case Backend::SparseCPU: - return SparseCPUType::sub(const_cast(*this), other, alpha); - break; - default: - AT_ERROR("sub not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, alpha); -#endif -} -inline Tensor & Tensor::sub_(const Tensor & other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sub_(const_cast(*this), other, alpha); - break; - case Backend::SparseCPU: - return SparseCPUType::sub_(const_cast(*this), other, alpha); - break; - default: - AT_ERROR("sub_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, alpha); -#endif -} -inline Tensor Tensor::sub(Scalar other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sub(const_cast(*this), other, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other, alpha); -#endif -} -inline Tensor & Tensor::sub_(Scalar other, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sub_(const_cast(*this), other, alpha); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), 
const_cast(*this), other, alpha); -#endif -} -inline Tensor Tensor::addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); - break; - case Backend::SparseCPU: - return SparseCPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); - break; - default: - AT_ERROR("addmm not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2)), const_cast(*this), mat1, mat2, beta, alpha); -#endif -} -inline Tensor & Tensor::addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); - break; - case Backend::SparseCPU: - return SparseCPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); - break; - default: - AT_ERROR("addmm_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmm_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2)), const_cast(*this), mat1, mat2, beta, alpha); -#endif -} -inline Tensor & Tensor::sparse_resize_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::sparse_resize_(const_cast(*this), size, sparse_dim, dense_dim); - break; - default: - AT_ERROR("sparse_resize_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_resize_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size, sparse_dim, dense_dim); -#endif -} -inline Tensor & Tensor::sparse_resize_and_clear_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::sparse_resize_and_clear_(const_cast(*this), size, sparse_dim, dense_dim); - break; - default: - AT_ERROR("sparse_resize_and_clear_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_resize_and_clear_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size, sparse_dim, dense_dim); -#endif -} -inline Tensor Tensor::sparse_mask(const Tensor & mask) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - 
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::sparse_mask(const_cast(*this), mask); - break; - default: - AT_ERROR("sparse_mask not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_mask", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask)), const_cast(*this), mask); -#endif -} -inline Tensor Tensor::to_dense() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::to_dense(const_cast(*this)); - break; - default: - AT_ERROR("to_dense not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_dense", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::sparse_dim() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::sparse_dim(const_cast(*this)); - break; - default: - AT_ERROR("sparse_dim not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_dim", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::_dimI() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::_dimI(const_cast(*this)); - break; - default: - AT_ERROR("_dimI not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_dimI", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::dense_dim() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::dense_dim(const_cast(*this)); - break; - default: - AT_ERROR("dense_dim not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dense_dim", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::_dimV() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::_dimV(const_cast(*this)); - break; - default: - AT_ERROR("_dimV not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_dimV", ""}).value(); 
- return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::_nnz() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::_nnz(const_cast(*this)); - break; - default: - AT_ERROR("_nnz not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_nnz", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::coalesce() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::coalesce(const_cast(*this)); - break; - default: - AT_ERROR("coalesce not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::coalesce", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline bool Tensor::is_coalesced() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::is_coalesced(const_cast(*this)); - break; - default: - AT_ERROR("is_coalesced not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_coalesced", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::_indices() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::_indices(const_cast(*this)); - break; - default: - AT_ERROR("_indices not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_indices", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::_values() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::_values(const_cast(*this)); - break; - default: - AT_ERROR("_values not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_values", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::_coalesced_(bool coalesced) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - 
case Backend::SparseCPU: - return SparseCPUType::_coalesced_(const_cast(*this), coalesced); - break; - default: - AT_ERROR("_coalesced_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_coalesced_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), coalesced); -#endif -} -inline Tensor Tensor::indices() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::indices(const_cast(*this)); - break; - default: - AT_ERROR("indices not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::indices", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::values() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::SparseCPU: - return SparseCPUType::values(const_cast(*this)); - break; - default: - AT_ERROR("values not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::values", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::numel() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::numel(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::numel", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline std::vector Tensor::unbind(int64_t dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::unbind(const_cast(*this), dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unbind", "int"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline std::vector Tensor::unbind(Dimname dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::unbind(const_cast(*this), dim); -#else - static auto table = globalATenDispatch().getOpTable("aten::unbind.Dimname(Tensor(a) self, Dimname dim) -> Tensor(a)[]"); - return table->callUnboxed, const Tensor &, Dimname>(const_cast(*this), dim); -#endif -} -#endif -inline Tensor Tensor::to_sparse(int64_t sparse_dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::to_sparse(const_cast(*this), sparse_dim); - break; - default: - AT_ERROR("to_sparse not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = 
c10::Dispatcher::singleton().findSchema({"aten::to_sparse", "sparse_dim"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), sparse_dim); -#endif -} -inline Tensor Tensor::to_sparse() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::to_sparse(const_cast(*this)); - break; - default: - AT_ERROR("to_sparse not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_sparse", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::to_mkldnn() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::to_mkldnn(const_cast(*this)); - break; - default: - AT_ERROR("to_mkldnn not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_mkldnn", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::dequantize() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::dequantize(const_cast(*this)); - break; - default: - AT_ERROR("dequantize not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dequantize", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline double Tensor::q_scale() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::q_scale(const_cast(*this)); - break; - default: - AT_ERROR("q_scale not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_scale", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::q_zero_point() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::q_zero_point(const_cast(*this)); - break; - default: - AT_ERROR("q_zero_point not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_zero_point", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::q_per_channel_scales() const { -#ifdef USE_STATIC_DISPATCH - 
at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::q_per_channel_scales(const_cast(*this)); - break; - default: - AT_ERROR("q_per_channel_scales not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_per_channel_scales", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::q_per_channel_zero_points() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::q_per_channel_zero_points(const_cast(*this)); - break; - default: - AT_ERROR("q_per_channel_zero_points not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_per_channel_zero_points", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline int64_t Tensor::q_per_channel_axis() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::q_per_channel_axis(const_cast(*this)); - break; - default: - AT_ERROR("q_per_channel_axis not implemented for ", at::toString(type_set())); - } -#else - static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_axis(Tensor self) -> int"); - return table->callUnboxed(const_cast(*this)); -#endif -} -inline Tensor Tensor::int_repr() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::int_repr(const_cast(*this)); - break; - default: - AT_ERROR("int_repr not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::int_repr", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline QScheme Tensor::qscheme() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::QuantizedCPU: - return QuantizedCPUType::qscheme(const_cast(*this)); - break; - default: - AT_ERROR("qscheme not implemented for ", at::toString(type_set())); - } -#else - static auto table = globalATenDispatch().getOpTable("aten::qscheme(Tensor self) -> QScheme"); - return table->callUnboxed(const_cast(*this)); -#endif -} -inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::to(const_cast(*this), options, non_blocking, copy); -#else - static auto table = globalATenDispatch().getOpTable("aten::to.dtype_layout(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor"); - return 
table->callUnboxed(const_cast(*this), options, non_blocking, copy);
-#endif
-}
-inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, bool copy) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- return TypeDefault::to(const_cast(*this), device, dtype, non_blocking, copy);
-#else
- static auto table = globalATenDispatch().getOpTable("aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor");
- return table->callUnboxed(const_cast(*this), device, dtype, non_blocking, copy);
-#endif
-}
-inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- return TypeDefault::to(const_cast(*this), dtype, non_blocking, copy);
-#else
- static auto table = globalATenDispatch().getOpTable("aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor");
- return table->callUnboxed(const_cast(*this), dtype, non_blocking, copy);
-#endif
-}
-inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- return TypeDefault::to(const_cast(*this), other, non_blocking, copy);
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to", "other"}).value();
- return c10::Dispatcher::singleton().callUnboxed(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, non_blocking, copy);
-#endif
-}
-inline Scalar Tensor::item() const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- return TypeDefault::item(const_cast(*this));
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::item", ""}).value();
- return c10::Dispatcher::singleton().callUnboxed(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this));
-#endif
-}
-inline Tensor & Tensor::set_(Storage source) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::set_(const_cast(*this), source);
- break;
- default:
- AT_ERROR("set_ not implemented for ", at::toString(type_set()));
- }
-#else
- static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)");
- return table->callUnboxed(const_cast(*this), source);
-#endif
-}
-inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::set_(const_cast(*this), source, storage_offset, size, stride);
- break;
- case Backend::QuantizedCPU:
- return QuantizedCPUType::set_(const_cast(*this), source, storage_offset, size, stride);
- break;
- default:
- AT_ERROR("set_ not implemented for ", at::toString(type_set()));
- }
-#else
- static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)");
- return table->callUnboxed(const_cast(*this), source, storage_offset, size, stride);
-#endif
-}
-inline Tensor & Tensor::set_(const Tensor & source) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::set_(const_cast(*this), source);
- break;
- default:
- AT_ERROR("set_ not implemented for ", at::toString(type_set()));
- }
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::set_", "source_Tensor"}).value();
- return c10::Dispatcher::singleton().callUnboxedOnly(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, source)), const_cast(*this), source);
-#endif
-}
-inline Tensor & Tensor::set_() const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::set_(const_cast(*this));
- break;
- default:
- AT_ERROR("set_ not implemented for ", at::toString(type_set()));
- }
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::set_", ""}).value();
- return c10::Dispatcher::singleton().callUnboxedOnly(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this));
-#endif
-}
-inline Tensor & Tensor::set_quantizer_(ConstQuantizerPtr quantizer) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::QuantizedCPU:
- return QuantizedCPUType::set_quantizer_(const_cast(*this), quantizer);
- break;
- default:
- AT_ERROR("set_quantizer_ not implemented for ", at::toString(type_set()));
- }
-#else
- static auto table = globalATenDispatch().getOpTable("aten::set_quantizer_(Tensor(a!) self, ConstQuantizerPtr quantizer) -> Tensor(a!)");
- return table->callUnboxed(const_cast(*this), quantizer);
-#endif
-}
-inline bool Tensor::is_set_to(const Tensor & tensor) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::is_set_to(const_cast(*this), tensor);
- break;
- default:
- AT_ERROR("is_set_to not implemented for ", at::toString(type_set()));
- }
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_set_to", ""}).value();
- return c10::Dispatcher::singleton().callUnboxed(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor)), const_cast(*this), tensor);
-#endif
-}
-inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::masked_fill_(const_cast(*this), mask, value);
- break;
- default:
- AT_ERROR("masked_fill_ not implemented for ", at::toString(type_set()));
- }
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill_", "Scalar"}).value();
- return c10::Dispatcher::singleton().callUnboxedOnly(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask)), const_cast(*this), mask, value);
-#endif
-}
-inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- return TypeDefault::masked_fill(const_cast(*this), mask, value);
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill", "Scalar"}).value();
- return c10::Dispatcher::singleton().callUnboxed(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask)), const_cast(*this), mask, value);
-#endif
-}
-inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case Backend::CPU:
- return CPUType::masked_fill_(const_cast(*this), mask, value);
- break;
- default:
- AT_ERROR("masked_fill_ not implemented for ", at::toString(type_set()));
- }
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill_", "Tensor"}).value();
- return c10::Dispatcher::singleton().callUnboxedOnly(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, value)), const_cast(*this), mask, value);
-#endif
-}
-inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- return TypeDefault::masked_fill(const_cast(*this), mask, value);
-#else
- static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill", "Tensor"}).value();
- return c10::Dispatcher::singleton().callUnboxed(
- op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, value)), const_cast(*this), mask, value);
-#endif
-}
-inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) const {
-#ifdef USE_STATIC_DISPATCH
- at::AutoNonVariableTypeMode _var_guard(true);
- switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
- case
Backend::CPU: - return CPUType::masked_scatter_(const_cast(*this), mask, source); - break; - default: - AT_ERROR("masked_scatter_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_scatter_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, source)), const_cast(*this), mask, source); -#endif -} -inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::masked_scatter(const_cast(*this), mask, source); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_scatter", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, source)), const_cast(*this), mask, source); -#endif -} -inline Tensor Tensor::view(IntArrayRef size) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::view(const_cast(*this), size); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::view(const_cast(*this), size); - break; - default: - AT_ERROR("view not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::view", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), size); -#endif -} -inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool accumulate) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::put_(const_cast(*this), index, source, accumulate); - break; - default: - AT_ERROR("put_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::put_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source)), const_cast(*this), index, source, accumulate); -#endif -} -inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::index_add_(const_cast(*this), dim, index, source); - break; - default: - AT_ERROR("index_add_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_add_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source)), const_cast(*this), dim, index, source); -#endif -} -inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_add(const_cast(*this), dim, index, source); -#else - static 
c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_add", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source)), const_cast(*this), dim, index, source); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::index_add(Dimname dim, const Tensor & index, const Tensor & source) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_add(const_cast(*this), dim, index, source); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, source); -#endif -} -#endif -inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::index_fill_(const_cast(*this), dim, index, value); - break; - default: - AT_ERROR("index_fill_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), dim, index, value); -#endif -} -inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_fill(const_cast(*this), dim, index, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), dim, index, value); -#endif -} -inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::index_fill_(const_cast(*this), dim, index, value); - break; - default: - AT_ERROR("index_fill_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, value)), const_cast(*this), dim, index, value); -#endif -} -inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_fill(const_cast(*this), dim, index, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, value)), const_cast(*this), dim, index, value); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor & Tensor::index_fill_(Dimname dim, const Tensor & index, Scalar value) const { -#ifdef 
USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_fill_(const_cast(*this), dim, index, value); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_fill_.dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)"); - return table->callUnboxed(const_cast(*this), dim, index, value); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor & Tensor::index_fill_(Dimname dim, const Tensor & index, const Tensor & value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_fill_(const_cast(*this), dim, index, value); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_fill_.dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)"); - return table->callUnboxed(const_cast(*this), dim, index, value); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::index_fill(Dimname dim, const Tensor & index, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_fill(const_cast(*this), dim, index, value); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_fill.dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, value); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::index_fill(Dimname dim, const Tensor & index, const Tensor & value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_fill(const_cast(*this), dim, index, value); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_fill.dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, value); -#endif -} -#endif -inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::scatter_(const_cast(*this), dim, index, src); - break; - default: - AT_ERROR("scatter_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_", "src"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src)), const_cast(*this), dim, index, src); -#endif -} -inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::scatter(const_cast(*this), dim, index, src); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter", "src"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src)), const_cast(*this), dim, index, src); -#endif -} -inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::scatter_(const_cast(*this), dim, index, 
value); - break; - default: - AT_ERROR("scatter_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_", "value"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), dim, index, value); -#endif -} -inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::scatter(const_cast(*this), dim, index, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter", "value"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), dim, index, value); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::scatter(Dimname dim, const Tensor & index, const Tensor & src) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::scatter(const_cast(*this), dim, index, src); -#else - static auto table = globalATenDispatch().getOpTable("aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, src); -#endif -} -#endif -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::scatter(Dimname dim, const Tensor & index, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::scatter(const_cast(*this), dim, index, value); -#else - static auto table = globalATenDispatch().getOpTable("aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, value); -#endif -} -#endif -inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::scatter_add_(const_cast(*this), dim, index, src); - break; - default: - AT_ERROR("scatter_add_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_add_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src)), const_cast(*this), dim, index, src); -#endif -} -inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::scatter_add(const_cast(*this), dim, index, src); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_add", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src)), const_cast(*this), dim, index, src); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::scatter_add(Dimname dim, const Tensor & index, const Tensor & src) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::scatter_add(const_cast(*this), dim, index, src); -#else - static auto table = 
globalATenDispatch().getOpTable("aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, src); -#endif -} -#endif -inline Tensor & Tensor::lt_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::lt_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::lt_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::lt_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::gt_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::gt_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::gt_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::gt_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::le_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::le_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::le_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::le_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::ge_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ge_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & 
Tensor::ge_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ge_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::eq_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::eq_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::eq_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::eq_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::ne_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ne_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::ne_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::ne_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__and__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__and__(const_cast(*this), other); - break; - default: - AT_ERROR("__and__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__and__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__and__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__and__(const_cast(*this), other); - break; - default: - AT_ERROR("__and__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__and__", "Tensor"}).value(); - return 
c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__iand__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__iand__(const_cast(*this), other); - break; - default: - AT_ERROR("__iand__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__iand__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__iand__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__iand__(const_cast(*this), other); - break; - default: - AT_ERROR("__iand__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__iand__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__or__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__or__(const_cast(*this), other); - break; - default: - AT_ERROR("__or__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__or__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__or__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__or__(const_cast(*this), other); - break; - default: - AT_ERROR("__or__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__or__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__ior__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__ior__(const_cast(*this), other); - break; - default: - AT_ERROR("__ior__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ior__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__ior__(const Tensor & other) const { -#ifdef 
USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__ior__(const_cast(*this), other); - break; - default: - AT_ERROR("__ior__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ior__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__xor__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__xor__(const_cast(*this), other); - break; - default: - AT_ERROR("__xor__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__xor__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__xor__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__xor__(const_cast(*this), other); - break; - default: - AT_ERROR("__xor__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__xor__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__ixor__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__ixor__(const_cast(*this), other); - break; - default: - AT_ERROR("__ixor__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ixor__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__ixor__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__ixor__(const_cast(*this), other); - break; - default: - AT_ERROR("__ixor__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ixor__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__lshift__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__lshift__(const_cast(*this), other); - break; - 
default: - AT_ERROR("__lshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__lshift__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__lshift__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__lshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__lshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__lshift__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__ilshift__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__ilshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__ilshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ilshift__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__ilshift__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__ilshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__ilshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ilshift__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__rshift__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__rshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__rshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__rshift__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::__rshift__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__rshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__rshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = 
c10::Dispatcher::singleton().findSchema({"aten::__rshift__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__irshift__(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__irshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__irshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__irshift__", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::__irshift__(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::__irshift__(const_cast(*this), other); - break; - default: - AT_ERROR("__irshift__ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__irshift__", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::lgamma_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lgamma_(const_cast(*this)); - break; - default: - AT_ERROR("lgamma_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lgamma_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::atan2_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::atan2_(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan2_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::tril_(int64_t diagonal) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::tril_(const_cast(*this), diagonal); - break; - default: - AT_ERROR("tril_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tril_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), diagonal); -#endif -} -inline Tensor & Tensor::triu_(int64_t diagonal) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - 
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::triu_(const_cast(*this), diagonal); - break; - default: - AT_ERROR("triu_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::triu_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), diagonal); -#endif -} -inline Tensor & Tensor::digamma_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::digamma_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::digamma_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::polygamma_(int64_t n) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::polygamma_(const_cast(*this), n); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::polygamma_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), n); -#endif -} -inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::renorm_(const_cast(*this), p, dim, maxnorm); - break; - default: - AT_ERROR("renorm_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::renorm_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p, dim, maxnorm); -#endif -} -inline Tensor & Tensor::pow_(Scalar exponent) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::pow_(const_cast(*this), exponent); - break; - default: - AT_ERROR("pow_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), exponent); -#endif -} -inline Tensor & Tensor::pow_(const Tensor & exponent) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::pow_(const_cast(*this), exponent); - break; - default: - AT_ERROR("pow_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, exponent)), const_cast(*this), exponent); -#endif -} -inline Tensor & Tensor::lerp_(const Tensor & end, Scalar weight) const { -#ifdef 
USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lerp_(const_cast(*this), end, weight); - break; - default: - AT_ERROR("lerp_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end)), const_cast(*this), end, weight); -#endif -} -inline Tensor & Tensor::lerp_(const Tensor & end, const Tensor & weight) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lerp_(const_cast(*this), end, weight); - break; - default: - AT_ERROR("lerp_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end, weight)), const_cast(*this), end, weight); -#endif -} -inline Tensor & Tensor::fmod_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::fmod_(const_cast(*this), other); - break; - default: - AT_ERROR("fmod_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::fmod_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::fmod_(const_cast(*this), other); - break; - default: - AT_ERROR("fmod_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::remainder_(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::remainder_(const_cast(*this), other); - break; - default: - AT_ERROR("remainder_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder_", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::remainder_(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return 
CPUType::remainder_(const_cast(*this), other); - break; - default: - AT_ERROR("remainder_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder_", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor & Tensor::addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::addbmm_(const_cast(*this), batch1, batch2, beta, alpha); - break; - default: - AT_ERROR("addbmm_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addbmm_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2)), const_cast(*this), batch1, batch2, beta, alpha); -#endif -} -inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::addbmm(const_cast(*this), batch1, batch2, beta, alpha); - break; - default: - AT_ERROR("addbmm not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addbmm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2)), const_cast(*this), batch1, batch2, beta, alpha); -#endif -} -inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::addcdiv_(const_cast(*this), tensor1, tensor2, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcdiv_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2)), const_cast(*this), tensor1, tensor2, value); -#endif -} -inline Tensor & Tensor::random_(int64_t from, int64_t to, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::random_(const_cast(*this), from, to, generator); - break; - default: - AT_ERROR("random_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::random_", "from"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), from, to, generator); -#endif -} -inline Tensor & Tensor::random_(int64_t to, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return 
CPUType::random_(const_cast(*this), to, generator); - break; - default: - AT_ERROR("random_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::random_", "to"}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), to, generator); -#endif -} -inline Tensor & Tensor::random_(Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::random_(const_cast(*this), generator); - break; - default: - AT_ERROR("random_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::random_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), generator); -#endif -} -inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::uniform_(const_cast(*this), from, to, generator); - break; - default: - AT_ERROR("uniform_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::uniform_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), from, to, generator); -#endif -} -inline Tensor & Tensor::normal_(double mean, double std, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::normal_(const_cast(*this), mean, std, generator); - break; - default: - AT_ERROR("normal_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::normal_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), mean, std, generator); -#endif -} -inline Tensor & Tensor::cauchy_(double median, double sigma, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::cauchy_(const_cast(*this), median, sigma, generator); - break; - default: - AT_ERROR("cauchy_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cauchy_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), median, sigma, generator); -#endif -} -inline Tensor & Tensor::log_normal_(double mean, double std, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return 
CPUType::log_normal_(const_cast(*this), mean, std, generator); - break; - default: - AT_ERROR("log_normal_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log_normal_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), mean, std, generator); -#endif -} -inline Tensor & Tensor::exponential_(double lambd, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::exponential_(const_cast(*this), lambd, generator); - break; - default: - AT_ERROR("exponential_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::exponential_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), lambd, generator); -#endif -} -inline Tensor & Tensor::geometric_(double p, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::geometric_(const_cast(*this), p, generator); - break; - default: - AT_ERROR("geometric_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::geometric_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p, generator); -#endif -} -inline Tensor Tensor::diag(int64_t diagonal) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::diag(const_cast(*this), diagonal); - break; - default: - AT_ERROR("diag not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diag", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), diagonal); -#endif -} -inline Tensor Tensor::cross(const Tensor & other, c10::optional dim) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cross(const_cast(*this), other, dim); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cross", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, dim); -#endif -} -inline Tensor Tensor::triu(int64_t diagonal) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::triu(const_cast(*this), diagonal); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::triu", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), diagonal); -#endif -} -inline Tensor Tensor::tril(int64_t diagonal) const { 
-#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::tril(const_cast(*this), diagonal); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tril", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), diagonal); -#endif -} -inline Tensor Tensor::trace() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::trace(const_cast(*this)); - break; - default: - AT_ERROR("trace not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trace", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::ne(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::ne(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::ne(const_cast(*this), other); - break; - default: - AT_ERROR("ne not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::ne(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::ne(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::ne(const_cast(*this), other); - break; - default: - AT_ERROR("ne not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::eq(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::eq(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::eq(const_cast(*this), other); - break; - default: - AT_ERROR("eq not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::eq(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::eq(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return 
QuantizedCPUType::eq(const_cast(*this), other); - break; - default: - AT_ERROR("eq not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::ge(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::ge(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::ge(const_cast(*this), other); - break; - default: - AT_ERROR("ge not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::ge(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::ge(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::ge(const_cast(*this), other); - break; - default: - AT_ERROR("ge not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::le(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::le(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::le(const_cast(*this), other); - break; - default: - AT_ERROR("le not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::le(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::le(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::le(const_cast(*this), other); - break; - default: - AT_ERROR("le not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::gt(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - 
switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::gt(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::gt(const_cast(*this), other); - break; - default: - AT_ERROR("gt not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::gt(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::gt(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::gt(const_cast(*this), other); - break; - default: - AT_ERROR("gt not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::lt(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lt(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::lt(const_cast(*this), other); - break; - default: - AT_ERROR("lt not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::lt(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lt(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::lt(const_cast(*this), other); - break; - default: - AT_ERROR("lt not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::take(const Tensor & index) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::take(const_cast(*this), index); - break; - default: - AT_ERROR("take not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::take", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), index); -#endif -} -inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { -#ifdef 
USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::index_select(const_cast(*this), dim, index); - break; - case Backend::SparseCPU: - return SparseCPUType::index_select(const_cast(*this), dim, index); - break; - default: - AT_ERROR("index_select not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_select", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), dim, index); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::index_select(Dimname dim, const Tensor & index) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::index_select(const_cast(*this), dim, index); -#else - static auto table = globalATenDispatch().getOpTable("aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index); -#endif -} -#endif -inline Tensor Tensor::masked_select(const Tensor & mask) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::masked_select(const_cast(*this), mask); - break; - default: - AT_ERROR("masked_select not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_select", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask)), const_cast(*this), mask); -#endif -} -inline Tensor Tensor::nonzero() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::nonzero(const_cast(*this)); - break; - default: - AT_ERROR("nonzero not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::nonzero", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline std::vector Tensor::nonzero_numpy() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::nonzero_numpy(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::nonzero_numpy", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::gather(const_cast(*this), dim, index, sparse_grad); - break; - default: - AT_ERROR("gather not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gather", ""}).value(); - return 
c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index)), const_cast(*this), dim, index, sparse_grad); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::gather(Dimname dim, const Tensor & index, bool sparse_grad) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::gather(const_cast(*this), dim, index, sparse_grad); -#else - static auto table = globalATenDispatch().getOpTable("aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, index, sparse_grad); -#endif -} -#endif -inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::addcmul(const_cast(*this), tensor1, tensor2, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcmul", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2)), const_cast(*this), tensor1, tensor2, value); -#endif -} -inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::addcmul_(const_cast(*this), tensor1, tensor2, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcmul_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2)), const_cast(*this), tensor1, tensor2, value); -#endif -} -inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::addcdiv(const_cast(*this), tensor1, tensor2, value); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcdiv", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2)), const_cast(*this), tensor1, tensor2, value); -#endif -} -inline std::tuple Tensor::lstsq(const Tensor & A) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lstsq(const_cast(*this), A); - break; - default: - AT_ERROR("lstsq not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lstsq", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, const Tensor &>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, A)), const_cast(*this), A); -#endif -} -inline std::tuple Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::triangular_solve(const_cast(*this), A, upper, transpose, unitriangular); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::triangular_solve", ""}).value(); - return 
c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, const Tensor &, bool, bool, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, A)), const_cast(*this), A, upper, transpose, unitriangular); -#endif -} -inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::symeig(const_cast(*this), eigenvectors, upper); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::symeig", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, bool, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), eigenvectors, upper); -#endif -} -inline std::tuple Tensor::eig(bool eigenvectors) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::eig(const_cast(*this), eigenvectors); - break; - default: - AT_ERROR("eig not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eig", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), eigenvectors); -#endif -} -inline std::tuple Tensor::svd(bool some, bool compute_uv) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::svd(const_cast(*this), some, compute_uv); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::svd", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, bool, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), some, compute_uv); -#endif -} -inline Tensor Tensor::cholesky(bool upper) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cholesky(const_cast(*this), upper); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), upper); -#endif -} -inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::cholesky_solve(const_cast(*this), input2, upper); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky_solve", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, input2)), const_cast(*this), input2, upper); -#endif -} -inline std::tuple Tensor::solve(const Tensor & A) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::solve(const_cast(*this), A); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::solve", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, const Tensor &>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, A)), const_cast(*this), A); -#endif -} -inline Tensor Tensor::cholesky_inverse(bool upper) 
const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::cholesky_inverse(const_cast(*this), upper); - break; - default: - AT_ERROR("cholesky_inverse not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky_inverse", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), upper); -#endif -} -inline std::tuple Tensor::qr(bool some) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::qr(const_cast(*this), some); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::qr", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), some); -#endif -} -inline std::tuple Tensor::geqrf() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::geqrf(const_cast(*this)); - break; - default: - AT_ERROR("geqrf not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::geqrf", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::orgqr(const Tensor & input2) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::orgqr(const_cast(*this), input2); - break; - default: - AT_ERROR("orgqr not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::orgqr", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, input2)), const_cast(*this), input2); -#endif -} -inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::ormqr(const_cast(*this), input2, input3, left, transpose); - break; - default: - AT_ERROR("ormqr not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ormqr", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, input2, input3)), const_cast(*this), input2, input3, left, transpose); -#endif -} -inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::lu_solve(const_cast(*this), LU_data, LU_pivots); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lu_solve", ""}).value(); - return 
c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, LU_data, LU_pivots)), const_cast(*this), LU_data, LU_pivots); -#endif -} -inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::multinomial(const_cast(*this), num_samples, replacement, generator); - break; - default: - AT_ERROR("multinomial not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::multinomial", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), num_samples, replacement, generator); -#endif -} -inline Tensor Tensor::lgamma() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lgamma(const_cast(*this)); - break; - default: - AT_ERROR("lgamma not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lgamma", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::digamma() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::digamma(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::digamma", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::polygamma(int64_t n) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::polygamma(n, const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::polygamma", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), n, const_cast(*this)); -#endif -} -inline Tensor Tensor::erfinv() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::erfinv(const_cast(*this)); - break; - default: - AT_ERROR("erfinv not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfinv", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::erfinv_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::erfinv_(const_cast(*this)); - break; - default: - AT_ERROR("erfinv_ not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfinv_", 
""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::sign() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sign(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sign", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor & Tensor::sign_() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sign_(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sign_", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::dist(const Tensor & other, Scalar p) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::dist(const_cast(*this), other, p); - break; - default: - AT_ERROR("dist not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dist", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other, p); -#endif -} -inline Tensor Tensor::atan2(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::atan2(const_cast(*this), other); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan2", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::lerp(const Tensor & end, Scalar weight) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lerp(const_cast(*this), end, weight); - break; - default: - AT_ERROR("lerp not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end)), const_cast(*this), end, weight); -#endif -} -inline Tensor Tensor::lerp(const Tensor & end, const Tensor & weight) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::lerp(const_cast(*this), end, weight); - break; - default: - AT_ERROR("lerp not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end, weight)), const_cast(*this), end, 
weight); -#endif -} -inline Tensor Tensor::histc(int64_t bins, Scalar min, Scalar max) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::histc(const_cast(*this), bins, min, max); - break; - default: - AT_ERROR("histc not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::histc", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), bins, min, max); -#endif -} -inline Tensor Tensor::fmod(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::fmod(const_cast(*this), other); - break; - default: - AT_ERROR("fmod not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::fmod(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::fmod(const_cast(*this), other); - break; - default: - AT_ERROR("fmod not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::remainder(Scalar other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::remainder(const_cast(*this), other); - break; - default: - AT_ERROR("remainder not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder", "Scalar"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::remainder(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::remainder(const_cast(*this), other); - break; - default: - AT_ERROR("remainder not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder", "Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::min(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - 
return CPUType::min(const_cast(*this), other); - break; - default: - AT_ERROR("min not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min", "other"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::min() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::min(const_cast(*this)); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::min(const_cast(*this)); - break; - default: - AT_ERROR("min not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::max(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::max(const_cast(*this), other); - break; - default: - AT_ERROR("max not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max", "other"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::max() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::max(const_cast(*this)); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::max(const_cast(*this)); - break; - default: - AT_ERROR("max not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::median() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::median(const_cast(*this)); - break; - default: - AT_ERROR("median not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::median", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline std::tuple Tensor::sort(int64_t dim, bool descending) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::sort(const_cast(*this), dim, descending); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::sort(const_cast(*this), dim, descending); - break; - default: - AT_ERROR("sort not implemented for ", 
at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sort", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, descending); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline std::tuple Tensor::sort(Dimname dim, bool descending) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::sort(const_cast(*this), dim, descending); -#else - static auto table = globalATenDispatch().getOpTable("aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)"); - return table->callUnboxed, const Tensor &, Dimname, bool>(const_cast(*this), dim, descending); -#endif -} -#endif -inline Tensor Tensor::argsort(int64_t dim, bool descending) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::argsort(const_cast(*this), dim, descending); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::argsort", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dim, descending); -#endif -} -#ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::argsort(Dimname dim, bool descending) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::argsort(const_cast(*this), dim, descending); -#else - static auto table = globalATenDispatch().getOpTable("aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor"); - return table->callUnboxed(const_cast(*this), dim, descending); -#endif -} -#endif -inline std::tuple Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::topk(const_cast(*this), k, dim, largest, sorted); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::topk(const_cast(*this), k, dim, largest, sorted); - break; - default: - AT_ERROR("topk not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::topk", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly, const Tensor &, int64_t, int64_t, bool, bool>( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), k, dim, largest, sorted); -#endif -} -inline Tensor Tensor::all() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::all(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::all", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} -inline Tensor Tensor::any() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::any(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::any", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), 
const_cast(*this)); -#endif -} -inline Tensor Tensor::renorm(Scalar p, int64_t dim, Scalar maxnorm) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::renorm(const_cast(*this), p, dim, maxnorm); - break; - default: - AT_ERROR("renorm not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::renorm", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), p, dim, maxnorm); -#endif -} -inline Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::unfold(const_cast(*this), dimension, size, step); - break; - default: - AT_ERROR("unfold not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unfold", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this), dimension, size, step); -#endif -} -inline bool Tensor::equal(const Tensor & other) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::equal(const_cast(*this), other); - break; - case Backend::QuantizedCPU: - return QuantizedCPUType::equal(const_cast(*this), other); - break; - default: - AT_ERROR("equal not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::equal", ""}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other)), const_cast(*this), other); -#endif -} -inline Tensor Tensor::pow(const Tensor & exponent) const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { - case Backend::CPU: - return CPUType::pow(const_cast(*this), exponent); - break; - default: - AT_ERROR("pow not implemented for ", at::toString(type_set())); - } -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow", "Tensor_Tensor"}).value(); - return c10::Dispatcher::singleton().callUnboxed( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, exponent)), const_cast(*this), exponent); -#endif -} -inline Tensor Tensor::alias() const { -#ifdef USE_STATIC_DISPATCH - at::AutoNonVariableTypeMode _var_guard(true); - return TypeDefault::alias(const_cast(*this)); -#else - static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::alias", ""}).value(); - return c10::Dispatcher::singleton().callUnboxedOnly( - op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast(*this)); -#endif -} - -inline caffe2::TypeMeta Tensor::dtype() const noexcept { - return impl_->dtype(); -} - -inline Layout Tensor::layout() const noexcept { - return impl_->layout(); -} - -inline Device Tensor::device() const { - return impl_->device(); -} - -inline int64_t 
Tensor::get_device() const { - // NB: this is not a native function to avoid dispatching overhead. - return impl_->get_device(); -} - -inline int64_t get_device(Tensor self) { - return self.get_device(); -} - -inline bool Tensor::is_cuda() const { - // NB: this is not a native function to avoid dispatching overhead. - return impl_->is_cuda(); -} - -#ifdef BUILD_NAMEDTENSOR -inline NamedTensorMeta* Tensor::get_named_tensor_meta() { - return static_cast(impl_->named_tensor_meta()); -} - -inline const NamedTensorMeta* Tensor::get_named_tensor_meta() const { - return static_cast(impl_->named_tensor_meta()); -} - -inline bool Tensor::has_names() const { - return impl::has_names(unsafeGetTensorImpl()); -} -#endif - -inline bool is_cuda(Tensor self) { - return self.is_cuda(); -} - -inline bool Tensor::is_hip() const { - // NB: this is not a native function to avoid dispatching overhead. - return impl_->is_hip(); -} - -inline bool is_hip(Tensor self) { - return self.is_hip(); -} - -inline bool Tensor::is_sparse() const { - // NB: this is not a native function to avoid dispatching overhead. - return impl_->is_sparse(); -} - -inline bool is_sparse(Tensor self) { - return self.is_sparse(); -} - -inline bool Tensor::is_mkldnn() const { - // NB: this is not a native function to avoid dispatching overhead. - return impl_->is_mkldnn(); -} - -inline bool is_mkldnn(Tensor self) { - return self.is_mkldnn(); -} - -inline bool Tensor::is_quantized() const { - // NB: this is not a native function to avoid dispatching overhead. - return impl_->is_quantized(); -} - -inline bool is_quantized(Tensor self) { - return self.is_quantized(); -} - -#define DEFINE_CAST(T, name) \ - template <> \ - inline T* Tensor::data_ptr() const { \ - TORCH_CHECK( \ - scalar_type() == ScalarType::name, \ - "expected scalar type ", \ - #name, \ - " but found ", \ - c10::toString(scalar_type())); \ - return static_cast(this->unsafeGetTensorImpl()->data()); \ - } - -AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST) -AT_FORALL_QINT_TYPES(DEFINE_CAST) -#undef DEFINE_CAST - -#define DEFINE_ITEM(T, name) \ - template <> \ - inline T Tensor::item() const { \ - return item().to##name(); \ - } - -AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ITEM) -#undef DEFINE_ITEM - -} //namespace at diff --git a/aten/src/ATen/core/context_base.cpp b/aten/src/ATen/core/context_base.cpp deleted file mode 100644 index f91764031f84d..0000000000000 --- a/aten/src/ATen/core/context_base.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -#include - -namespace at { - -C10_DEFINE_TYPED_REGISTRY( - ContextRegistry, - at::DeviceType, - at::BaseContext, - std::unique_ptr, - at::Device); - -} // namespace at - -namespace caffe2 { - -// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h - -} // namespace caffe2 diff --git a/aten/src/ATen/core/context_base.h b/aten/src/ATen/core/context_base.h deleted file mode 100644 index 4ea513857dc78..0000000000000 --- a/aten/src/ATen/core/context_base.h +++ /dev/null @@ -1,164 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace caffe2 { -class Event; - -} // namespace caffe2 -namespace at { - -class BaseContext; - -/** - * Virtual interface for the Context class in Caffe2. - * - * A Context defines all the necessities to run an operator on a specific - * device. Specific Context classes needs to implement all the pure virtual - * functions in the BaseContext class. 
- * TODO: add docs after this is finalized. - */ -class CAFFE2_API BaseContext { - public: - virtual ~BaseContext() noexcept {} - - virtual Device device() const = 0; - - /* Sorry for the naming, will get rid of this in future diff */ - virtual DeviceType device_type() const = 0; - - virtual void SwitchToDevice(int /*stream_id*/) = 0; - - inline void SwitchToDevice() { - SwitchToDevice(0); - } - - virtual void WaitEvent(const caffe2::Event& ev) = 0; - - virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr) - const = 0; - - virtual void FinishDeviceComputation() = 0; - - // This used to be arbitrary cross-device copy, but it turns out everyone - // did direct CPU-X copy, so we just make three functions for it (to avoid - // double dispatch). This will get obsoleted by C10. where copies - // will be proper operators (and get to rely on multiple dispatch there.) - virtual void CopyBytesSameDevice( - size_t nbytes, - const void* src, - void* dst) = 0; - - virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0; - - virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0; - - template - inline void CopySameDevice(size_t n, const T* src, T* dst) { - static_assert( - std::is_fundamental::value, - "CopySameDevice requires fundamental types"); - CopyBytesSameDevice( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - template - inline void CopyFromCPU(size_t n, const T* src, T* dst) { - static_assert( - std::is_fundamental::value, - "CopyFromCPU requires fundamental types"); - CopyBytesFromCPU( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - template - inline void CopyToCPU(size_t n, const T* src, T* dst) { - static_assert( - std::is_fundamental::value, "CopyToCPU requires fundamental types"); - CopyBytesToCPU( - n * sizeof(T), static_cast(src), static_cast(dst)); - } - - virtual bool SupportsNonFundamentalTypes() const { - return false; - } - - inline void EnforceMetaCopyOK() { - AT_ASSERTM( - SupportsNonFundamentalTypes(), "Context requires fundamental types"); - } - - void CopyItemsSameDevice( - const caffe2::TypeMeta& meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesSameDevice(n * meta.itemsize(), src, dst); - } - } - - void CopyItemsFromCPU( - const caffe2::TypeMeta& meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesFromCPU(n * meta.itemsize(), src, dst); - } - } - - void CopyItemsToCPU( - const caffe2::TypeMeta& meta, - size_t n, - const void* src, - void* dst) { - if (meta.copy()) { - EnforceMetaCopyOK(); - meta.copy()(src, dst, n); - } else { - CopyBytesToCPU(n * meta.itemsize(), src, dst); - } - } -}; - -// Context constructor registry -C10_DECLARE_TYPED_REGISTRY( - ContextRegistry, - at::DeviceType, - at::BaseContext, - std::unique_ptr, - at::Device); - -#define REGISTER_CONTEXT(type, ...) 
\ - C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__) - -inline std::unique_ptr CreateContext( - const at::Device& device) { - return at::ContextRegistry()->Create(device.type(), device); -} - -} // namespace at - -namespace caffe2 { - -using at::BaseContext; -using at::CreateContext; -} // namespace caffe2 diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index e5837680ff2f4..f065e6221e341 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -117,6 +117,7 @@ namespace c10 { _(prim, CallFunction) \ _(prim, CallMethod) \ _(prim, LoopContinuation) \ + _(prim, annotate) \ _(aten, append) \ _(aten, item) \ _(aten, format) \ diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index dc40d61a5160f..a26371b559d42 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1458,6 +1458,28 @@ struct CAFFE2_API ClassType : public NamedType { TypePtr type, bool is_parameter = false); + // Add attribute \p NAME if it doesn't exist or verify that it has a + // compatible type otherwise. + size_t addOrCheckAttribute( + const std::string& name, + TypePtr ty, + bool is_parameter = false) { + auto slot_idx = findAttributeSlot(name); + if (!slot_idx) { + return addAttribute(name, ty, is_parameter); + } + + TORCH_CHECK( + is_parameter == this->is_parameter(*slot_idx), + "Parameter field mismatch for the field '", + name, + "'"); + TypePtr atype = getAttribute(*slot_idx); + TORCH_CHECK(ty->isSubtypeOf(atype)); + return *slot_idx; + } + + at::ArrayRef attributeNames() const { return attributeNames_; } diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 36b7bd99acffc..7f831e5cdf785 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -61,7 +61,6 @@ namespace at { namespace cuda { // list above. // // HIP doesn't have -// cuOccupancyMaxActiveBlocksPerMultiprocessor // cuGetErrorString (maps to non-functional hipGetErrorString___) #define AT_FORALL_NVRTC(_) \ @@ -72,6 +71,7 @@ namespace at { namespace cuda { _(nvrtcGetPTX) \ _(cuModuleLoadData) \ _(cuModuleGetFunction) \ + _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \ _(nvrtcGetErrorString) \ _(nvrtcGetProgramLogSize) \ _(nvrtcGetProgramLog) \ diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index c3ee9547133e4..e3baa023ab837 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -239,17 +239,10 @@ def TypedDict(name, attrs, total=True): # type: ignore # In order to rely on the linker to strip unused ops, it requires us to dispatch statically # in Functions.h and TensorMethods.h. -# -# NB: The default body also needs to apply a variable guard, as in some -# situations what we think is a default body actually does have an -# explicit derivative, and thereby would have gotten unwrapped by -# the time you get to the implementation. 
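The two CodeTemplates that follow emit the per-method bodies seen in the deleted inline Tensor method definitions earlier in this patch (the checked-in TensorMethods.h), and after this change those bodies no longer open with an at::AutoNonVariableTypeMode guard. For reference, a hand-reconstructed sketch of one of the deleted bodies, Tensor::ne(Scalar); the angle-bracket arguments to const_cast and callUnboxed are restored from the generator's usual conventions rather than copied verbatim, so treat them as an approximation:

    inline Tensor Tensor::ne(Scalar other) const {
    #ifdef USE_STATIC_DISPATCH
      at::AutoNonVariableTypeMode _var_guard(true);
      // Static dispatch: route straight to the backend-specific kernel.
      switch (tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) {
        case Backend::CPU:
          return CPUType::ne(const_cast<Tensor&>(*this), other);
          break;
        case Backend::QuantizedCPU:
          return QuantizedCPUType::ne(const_cast<Tensor&>(*this), other);
          break;
        default:
          AT_ERROR("ne not implemented for ", at::toString(type_set()));
      }
    #else
      // Dynamic dispatch: look the schema up once, then call through c10::Dispatcher.
      static c10::OperatorHandle op =
          c10::Dispatcher::singleton().findSchema({"aten::ne", "Scalar"}).value();
      return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &, Scalar>(
          op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)),
          const_cast<Tensor&>(*this), other);
    #endif
    }
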
STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\ -at::AutoNonVariableTypeMode _var_guard(true); ${return_call} TypeDefault::${native_type_method_dispatch}(${native_arguments}); """) STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\ -at::AutoNonVariableTypeMode _var_guard(true); switch(tensorTypeIdToBackend(impl::dispatchTypeId(${type_set}))) { ${static_dispatch_function_switches} default: @@ -1286,9 +1279,7 @@ def gen_namespace_function(option, multidispatch_tensors): return FunctionCode(definition=fn_definition, declaration=fn_declaration) # Emit #ifdef BUILD_NAMEDTENSOR macros for any code generated here - # that is sent to top_env. This is because some of this code (Type.h, - # TensorBody.h, TensorMethods.h) is checked into the repo and must be - # the same regardless of BUILD_NAMEDTENSOR status. + # that is sent to top_env. is_named_tensor_only = (has_named_tensor_formals(formals) or option['api_name'] == 'align_tensors' or option['api_name'] == 'align_as') diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index cff113ebb0211..f99e5828c03ae 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -45,12 +45,17 @@ action='store_true', help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly') options = parser.parse_args() -gen_to_source = os.environ.get('GEN_TO_SOURCE') # update source directly as part of gen -if not gen_to_source: - core_install_dir = os.path.join(options.install_dir, 'core_tmp') if options.install_dir is not None else None -else: - core_install_dir = os.path.join(options.source_path, 'core') - +# NB: It is mandatory to NOT use os.path.join here, as the install directory +# will eventually be ingested by cmake, which does not respect Windows style +# path slashes. If you switch this to use os.path.join, you'll get an error +# like: +# +# Syntax error in cmake code when parsing string +# +# C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h +# +# Invalid character escape '\c'. +core_install_dir = options.install_dir + '/core' if options.install_dir is not None else None if options.install_dir is not None and not os.path.exists(options.install_dir): os.makedirs(options.install_dir) if core_install_dir is not None and not os.path.exists(core_install_dir): @@ -327,7 +332,7 @@ def iterate_types(): # so that the script runs quickly when we are just querying the # outputs def declare_outputs(): - core_files = ['TensorBody.h', 'TensorMethods.h'] + core_files = ['TensorBody.h', 'TensorMethods.h', 'OpsAlreadyMovedToC10.cpp'] for f in core_files: core_file_manager.will_write(f) files = ['Declarations.yaml', 'TypeDefault.cpp', 'TypeDefault.h', @@ -359,36 +364,6 @@ def filter_by_extension(files, *extensions): return filtered_files -# because EOL may not be LF(\n) on some environment (e.g. Windows), -# normalize EOL from CRLF/CR to LF and compare both files. 
-def cmpfiles_with_eol_normalization(a, b, names): - results = ([], [], []) # match, mismatch, error - for x in names: - try: - with open(os.path.join(a, x)) as f: - ax = f.read().replace('\r\n', '\n').replace('\r', '\n') - with open(os.path.join(b, x)) as f: - bx = f.read().replace('\r\n', '\n').replace('\r', '\n') - if ax == bx: - results[0].append(x) - else: - results[1].append(x) - import difflib - import sys - d = difflib.Differ() - sys.stdout.write('-' * 80 + '\n') - sys.stdout.write('x={}, a={}, b={}\n'.format(x, a, b)) - for i, line in enumerate(list(d.compare(ax.splitlines(), bx.splitlines()))): - if line[:2] != ' ': - sys.stdout.write('{:5d}: {}\n'.format(i, line)) - sys.stdout.write('-' * 80 + '\n') - sys.stdout.write(ax) - sys.stdout.write('-' * 80 + '\n') - except OSError: - results[2].append(x) - return results - - def is_namedtensor_only_decl(decl): if 'Dimname' in decl['schema_string']: return True @@ -448,21 +423,6 @@ def generate_outputs(): file_manager.check_all_files_written() cuda_file_manager.check_all_files_written() - # check that generated files match source files - core_source_path = os.path.join(options.source_path, 'core') - match, mismatch, errors = cmpfiles_with_eol_normalization(core_install_dir, core_source_path, core_files.keys()) - if errors: - raise RuntimeError("Error while trying to compare source and generated files for {}. " - "Source directory: {}. Generated directory: {}." - .format(errors, core_source_path, core_install_dir)) - if mismatch: - file_component = '{}'.format(','.join(mismatch)) - if len(mismatch) > 1: - file_component = '{' + file_component + '}' - update_cmd = "cp {}/{} {}".format(core_install_dir, file_component, core_source_path) - raise RuntimeError("Source files: {} did not match generated files. To update the source files, " - "set environment variable GEN_TO_SOURCE or run \"{}\"".format(mismatch, update_cmd)) - declare_outputs() if options.output_dependencies is not None: file_manager.write_outputs(options.output_dependencies) diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index 7fa2bcbca01e5..8ad73a9be7044 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -44,7 +44,7 @@ namespace c10 { namespace hip { // we switch PyTorch to calling a HIP a HIP. 
// // When you add a new MasqueradingAsCUDA class/function, you need to -// also update the rewrite rules in tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py // // // diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index cb965b6f77959..a8509851974f2 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -711,6 +711,14 @@ Tensor zeros_like(const Tensor& self, const TensorOptions& options) { return native::zeros(self.sizes(), options); } +Tensor new_zeros( + const Tensor& self, + IntArrayRef size, + const TensorOptions& options + ) { + return at::zeros(size, self.options().merge_in(options)); +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ bartlett_window ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tensor bartlett_window(int64_t window_length, const TensorOptions& options) { @@ -867,8 +875,20 @@ Tensor from_file(std::string filename, c10::optional shared, c10::optional // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ clone ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Tensor clone(const Tensor& src) { - auto self = at::empty_like(src); +Tensor clone(const Tensor& src, c10::optional optional_memory_format) { + auto memory_format = + optional_memory_format.value_or(MemoryFormat::Contiguous); + if (memory_format == MemoryFormat::Preserve) { + if (src.is_non_overlapping_and_dense()) { + // Copy all strides + auto self = at::empty_strided(src.sizes(), src.strides(), src.options()); + self.copy_(src); + return self; + } else { + memory_format = src.suggest_memory_format(); + } + } + auto self = at::empty_like(src, src.options(), memory_format); self.copy_(src); return self; } diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 15e67644b3ec2..3b26cd85f26c3 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -77,6 +77,10 @@ Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(r Tensor log(const Tensor& self) { return unary_op_impl(self, at::log_out); } Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); } +Tensor& log10_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log10_stub); } +Tensor log10(const Tensor& self) { return unary_op_impl(self, at::log10_out); } +Tensor& log10_(Tensor& self) { return unary_op_impl_(self, at::log10_out); } + Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); } Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); } Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); } @@ -269,7 +273,6 @@ IMPLEMENT_UNARY_OP_VEC(erfc) IMPLEMENT_UNARY_OP_VEC_CUDA(erfinv) IMPLEMENT_UNARY_OP_VEC(exp) IMPLEMENT_UNARY_OP_VEC(frac) -IMPLEMENT_UNARY_OP_VEC(log10) IMPLEMENT_UNARY_OP_VEC(log1p) IMPLEMENT_UNARY_OP_VEC(log2) IMPLEMENT_UNARY_OP_VEC(reciprocal) diff --git a/aten/src/ATen/native/cpu/IsContiguous.h b/aten/src/ATen/native/cpu/IsContiguous.h index 6392746a8013a..8a24ce0b7f24c 100644 --- a/aten/src/ATen/native/cpu/IsContiguous.h +++ b/aten/src/ATen/native/cpu/IsContiguous.h @@ -5,34 +5,58 @@ namespace at { namespace native { namespace { // n: number of function arguments (arity) // traits: function_traits (see FunctionTraits.h) // s: index of scalar argument or -1 -template +template struct IsContiguous { static bool eval(const int64_t* strides) { using type = typename traits::template arg::type; - 
return strides[n] == (s == n ? 0 : sizeof(type)) && - IsContiguous::eval(strides); + return strides[stride_index] == (s == n ? 0 : sizeof(type)) && + IsContiguous::eval(strides); } }; +// will be called when an output exists template -struct IsContiguous<0, traits, s> { +struct IsContiguous<0, 0, traits, s> { static bool eval(const int64_t* strides) { return strides[0] == sizeof(typename traits::result_type); } }; +// will be called when there is no output +template +struct IsContiguous<0, -1, traits, s> { + static bool eval(const int64_t* strides) { + return true; + } +}; + // output and all inputs are contiguous -template +template ::value>::type* = nullptr> static inline bool is_contiguous(const int64_t* strides) { - return IsContiguous::eval(strides); + return IsContiguous::eval(strides); +} + +template ::value>::type* = nullptr> +static inline bool is_contiguous(const int64_t* strides) { + return IsContiguous::eval(strides); } // input at `s` is scalar (stride 0); output and other inputs are contiguous // NB: output is typically at strides[0] so first input corresponds to s=1 -template +template ::value>::type* = nullptr> +static inline bool is_contiguous_scalar(const int64_t* strides) { + static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); + return IsContiguous::eval(strides); +} + +template ::value>::type* = nullptr> static inline bool is_contiguous_scalar(const int64_t* strides) { static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); - return IsContiguous::eval(strides); + return IsContiguous::eval(strides); } }}} diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index e4bc9bc492b49..5d2ae2020c9c0 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -80,13 +80,40 @@ dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& o return dereference_vec_impl(data, opt_scalar, S, i, Indices{}); } +template ::result_type>::value>::type* = nullptr> +static inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t op) { + using traits = function_traits; + using result_type = typename traits::result_type; + for (; i < n; i++) { + result_type* out_ptr = (result_type*)(data[0] + i * strides[0]); + *out_ptr = c10::guts::apply(op, dereference( + &data[1], + &strides[1], + i)); + } +} + +template ::result_type>::value>::type* = nullptr> +static inline void +execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t op) { + using traits = function_traits; + for (; i < n; i++) { + c10::guts::apply(op, dereference( + &data[0], + &strides[0], + i)); + } +} + // Basic loop operation (one output, N inputs). May be auto-vectorized // by the compiler. Supports inputs and outputs of different types.
template static inline void basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t op) { using traits = function_traits; - using result_type = typename traits::result_type; constexpr int ntensors = traits::arity + 1; // Copying strides to temporary array helps auto vectorization in older GCC @@ -96,13 +123,7 @@ basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_ strides[arg] = strides_[arg]; } - for (; i < n; i++) { - result_type* out_ptr = (result_type*)(data[0] + i * strides[0]); - *out_ptr = c10::guts::apply(op, dereference( - &data[1], - &strides[1], - i)); - } + execute_op(data, strides, i, n, op); } // Explicitly vectorized loop implementation. All inputs and outputs must be @@ -205,7 +226,8 @@ void cpu_kernel_vec(TensorIterator& iter, func_t op, vec_func_t vop) { template void cpu_serial_kernel(TensorIterator& iter, func_t op) { using traits = function_traits; - TORCH_INTERNAL_ASSERT(iter.ntensors() >= traits::arity + 1); + TORCH_INTERNAL_ASSERT((std::is_void::value && + iter.noutputs() == 0 && iter.ntensors() == traits::arity) || (iter.ntensors() >= traits::arity + 1)); iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) { if (is_contiguous(strides)) { @@ -217,6 +239,7 @@ void cpu_serial_kernel(TensorIterator& iter, func_t op) { }); } }, {0, iter.numel()}); + iter.cast_outputs(); } }}} // namespace at::native:: diff --git a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp index e18bfd3e05808..61930831796d0 100644 --- a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp +++ b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp @@ -74,7 +74,6 @@ IMPLEMENT_UNARY_OP_PREQUEL(erf) IMPLEMENT_UNARY_OP_PREQUEL(erfc) IMPLEMENT_UNARY_OP_PREQUEL(exp) IMPLEMENT_UNARY_OP_PREQUEL(frac) -IMPLEMENT_UNARY_OP_PREQUEL(log10) IMPLEMENT_UNARY_OP_PREQUEL(log1p) IMPLEMENT_UNARY_OP_PREQUEL(log2) IMPLEMENT_UNARY_OP_PREQUEL(reciprocal) diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index ad77d95a86b4d..46ffdcaa27b4b 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -123,7 +123,11 @@ void SpatialSoftMax_getLaunchSizes( smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t); int max_active_blocks; #ifdef __HIP_PLATFORM_HCC__ - max_active_blocks = 16; + // XXX HIP function signature is not compatible yet. 
+ uint32_t max_blocks; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks, + k, block_threads, smem_size); + max_active_blocks = max_blocks; #else cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, k, block_threads, smem_size); diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 9c6695d32c3d6..54b922f69064f 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -65,6 +65,14 @@ void log_kernel_cuda(TensorIterator& iter) { }); } +void log10_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "log10_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::log10(a); + }); + }); +} + void neg_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, iter.dtype(), "neg_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { @@ -180,6 +188,7 @@ REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda); REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda); REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda); REGISTER_DISPATCH(log_stub, &log_kernel_cuda); +REGISTER_DISPATCH(log10_stub, &log10_kernel_cuda); REGISTER_DISPATCH(neg_stub, &neg_kernel_cuda); REGISTER_DISPATCH(round_stub, &round_kernel_cuda); REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda); diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp index 1567df04eca67..b45784dc2e96a 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.cpp +++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp @@ -16,7 +16,7 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) { AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support"); } -Tensor mkldnn_clone(const Tensor& self) { +Tensor mkldnn_clone(const Tensor& self, c10::optional optional_memory_format) { AT_ERROR("mkldnn_clone: ATen not compiled with MKLDNN support"); } @@ -54,7 +54,11 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) { return new_with_itensor_mkldnn(std::move(y), self.options()); } -Tensor mkldnn_clone(const Tensor& self) { +Tensor mkldnn_clone(const Tensor& self, c10::optional optional_memory_format) { + TORCH_CHECK( + !optional_memory_format.has_value(), + "unsupported memory format option ", + optional_memory_format.value()); ideep::tensor& src = itensor_from_mkldnn(self); ideep::tensor dst; ideep::direct_copy::compute(src, dst); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5c738c9bc282f..0799f015bbf55 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -66,7 +66,7 @@ supports_named_tensor: True - func: align_to(Tensor(a) self, DimnameList names) -> Tensor(a) - variants: function, method + variants: method supports_named_tensor: True - func: align_as(Tensor self, Tensor other) -> Tensor @@ -1027,6 +1027,9 @@ - func: new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor variants: method +- func: new_zeros(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + variants: method + # other overrides are to provide a more helpful error message that dtype is required - func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor dispatch: @@ -1539,15 +1542,12 @@ use_c10_dispatcher: unboxed_only supports_named_tensor: True variants: function, method - dispatch: - CPU: _log10__cpu - CUDA: _log10__cuda - func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: - CPU: _log10_out_cpu - CUDA: _log10_out_cuda + CPU: log10_out + CUDA: log10_out - func: log1p(Tensor self) -> Tensor use_c10_dispatcher: full @@ -3108,8 +3108,7 @@ - func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) variants: function -- func: clone(Tensor self) -> Tensor - use_c10_dispatcher: full +- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor variants: function, method dispatch: CPU: clone diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index e3deeac81362a..0301868befb24 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -155,13 +155,29 @@ Tensor& set_quantizer_(Tensor& self, ConstQuantizerPtr quantizer) { return self; } -Tensor quantized_clone(const Tensor& self) { +Tensor quantized_clone(const Tensor& self, c10::optional optional_memory_format) { // TODO: add per channel support TORCH_INTERNAL_ASSERT( self.qscheme() == at::kPerTensorAffine, "clone for quantized Tensor only works for PerTensorAffine scheme right now"); + + auto memory_format = + optional_memory_format.value_or(MemoryFormat::Contiguous); + + // TODO: To support all features of MemoryFormat::Preserve we need to add + // _empty_affine_quantized_strided function and use it similarly to + // Tensor clone(const Tensor& src, c10::optional optional_memory_format) + // if (self.is_non_overlapping_and_dense()) -> _empty_affine_quantized_strided + if (memory_format == MemoryFormat::Preserve) { + memory_format = self.suggest_memory_format(); + } + Tensor dst = at::_empty_affine_quantized( - self.sizes(), self.options(), self.q_scale(), self.q_zero_point()); + self.sizes(), + self.options(), + self.q_scale(), + self.q_zero_point(), + memory_format); at::native::copy_(dst, self, false); diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 208f5b10581bf..48a973bc946c2 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -360,7 +360,8 @@ class QConv2dInt8 final : public c10::OperatorKernel { kernel_zp, MemoryFormat::ChannelsLast); auto* qnnp_w_data = qnnp_weight.data_ptr(); - for (int i = 0; i < weight_contig.numel(); ++i) { + auto wt_numel = weight_contig.numel(); + for (int i = 0; i < wt_numel; ++i) { qnnp_w_data[i] = static_cast(w_data[i] + 128); } // Original bias was float, so we requantize it here. 
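The `wt_numel` change in the qconv.cpp hunk above, and the matching changes in qconv_prepack.cpp, qlinear.cpp, and qlinear_prepack.cpp below, all follow one pattern: `weight_contig.numel()` is hoisted into a local before the per-element repacking loop, so the bound is evaluated once instead of on every iteration. A minimal standalone sketch of that pattern is shown here; the `shift_to_uint8` helper and the `std::vector` inputs are hypothetical stand-ins for the real kernels, which operate on `at::Tensor` data pointers and QNNPACK packing structures.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the quantized weight repacking loops: shift signed
// int8 weights into the uint8 domain, as the QNNPACK-based paths do.
static std::vector<std::uint8_t> shift_to_uint8(const std::vector<std::int8_t>& w_data) {
  std::vector<std::uint8_t> qnnp_w_data(w_data.size());
  // Hoist the element count into a local (mirroring `auto wt_numel = weight_contig.numel();`)
  // instead of re-evaluating the bound in every loop iteration.
  const std::int64_t wt_numel = static_cast<std::int64_t>(w_data.size());
  for (std::int64_t i = 0; i < wt_numel; ++i) {
    qnnp_w_data[i] = static_cast<std::uint8_t>(w_data[i] + 128);
  }
  return qnnp_w_data;
}

This only illustrates the loop-bound hoisting applied across those files; it is not the kernels' actual packing logic.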
diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index d90871421c20d..70c634d65d8eb 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -218,7 +218,8 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel { weight.q_scale(), weight_zp); auto* qnnp_w_data = qnnp_weight.data_ptr(); - for (int i = 0; i < weight_contig.numel(); ++i) { + auto wt_numel = weight_contig.numel(); + for (int i = 0; i < wt_numel; ++i) { qnnp_w_data[i] = static_cast(w_data[i] + 128); } // We set the pre-packed conv weights to nullptr below as we call pre-pack diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 55db78e4b3026..08ab39cfb9642 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -249,7 +249,8 @@ class QLinearInt8 final : public torch::OperatorKernel { kernel_scale, kernel_zp); auto* qnnp_w_data = qnnp_weight.data_ptr(); - for (int i = 0; i < weight_contig.numel(); ++i) { + auto wt_numel = weight_contig.numel(); + for (int i = 0; i < wt_numel; ++i) { qnnp_w_data[i] = static_cast(w_data[i] + 128); } // Original bias was float, so we requantize it here. diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index bcd0b690b9f63..ea6db888fd31d 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -163,7 +163,8 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { weight.q_scale(), weight_zp); auto* qnnp_w_data = qnnp_weight.data_ptr(); - for (int i = 0; i < weight_contig.numel(); ++i) { + auto wt_numel = weight_contig.numel(); + for (int i = 0; i < wt_numel; ++i) { qnnp_w_data[i] = static_cast(inp_data[i] + 128); } initQNNPACK(); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 418f403da7e47..7185acbe666ff 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -255,7 +255,11 @@ Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, A // NB: Deleted newWithSizeNd variants -SparseTensor clone_sparse(const SparseTensor& self) { +SparseTensor clone_sparse(const SparseTensor& self, c10::optional optional_memory_format) { + TORCH_CHECK( + !optional_memory_format.has_value(), + "unsupported memory format option ", + optional_memory_format.value()); SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(), self.options()); copy_into_sparse(other, self._indices(), self._values(), true); return other._coalesced_(self.is_coalesced()); diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 69956aaba0729..5a1a7a94d861d 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -297,7 +297,8 @@ Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zer } #endif auto qdata = qtensor.data_ptr(); - for (int i = 0; i < rtensor.numel(); ++i) { + auto numel = rtensor.numel(); + for (int i = 0; i < numel; ++i) { qdata[i] = quantize_val(scale, zero_point, rdata[i]); } return qtensor; @@ -318,7 +319,8 @@ Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t z checkZeroPoint(fn_name, zero_point); 
const auto* qd = qtensor.data_ptr(); float* rd = rtensor.data_ptr(); - for (auto i = 0; i < qtensor.numel(); ++i) { + auto numel = qtensor.numel(); + for (auto i = 0; i < numel; ++i) { rd[i] = dequantize_val(scale, zero_point, qd[i]); } return rtensor; diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index dc7273d126128..f8a0e6f2aa4a6 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -186,10 +186,15 @@ class CAFFE2_API Tensor { int64_t ndimension() const { return dim(); } + bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { return impl_->is_contiguous(memory_format); } + bool is_non_overlapping_and_dense() const { + return impl_->is_non_overlapping_and_dense(); + } + at::MemoryFormat suggest_memory_format() const { if (impl_->is_strides_like_channels_last()) { return at::MemoryFormat::ChannelsLast; diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h index 67640c08d43dd..da3f8fe1a5dbd 100644 --- a/aten/src/ATen/templates/TensorMethods.h +++ b/aten/src/ATen/templates/TensorMethods.h @@ -11,7 +11,6 @@ #include #include #include -#include #ifdef USE_STATIC_DISPATCH #include diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 0455ce9cd8035..80fdf4fc0f7ef 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -29,8 +29,7 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_iterator_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpu_generator_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/variant_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/pow_test.cpp) list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu diff --git a/aten/src/ATen/test/tensor_iterator_test.cpp b/aten/src/ATen/test/tensor_iterator_test.cpp index ef4d17a2f8b3f..6939d32ac7686 100644 --- a/aten/src/ATen/test/tensor_iterator_test.cpp +++ b/aten/src/ATen/test/tensor_iterator_test.cpp @@ -61,17 +61,41 @@ TEST(TensorIteratorTest, SerialLoopUnary_##name) { \ ASSERT_ANY_THROW(out.equal(expected)); \ } +#define NO_OUTPUT_UNARY_TEST_ITER_FOR_TYPE(ctype,name) \ +TEST(TensorIteratorTest, SerialLoopUnaryNoOutput_##name) { \ + auto in = random_tensor_for_type(k##name); \ + auto iter = at::TensorIterator(); \ + iter.add_input(in); \ + iter.build(); \ + int64_t acc = 0; \ + at::native::cpu_serial_kernel(iter, [&](ctype a) -> void { acc++; }); \ + EXPECT_TRUE(acc == in.numel()); \ +} + #define BINARY_TEST_ITER_FOR_TYPE(ctype,name) \ TEST(TensorIteratorTest, SerialLoopBinary_##name) { \ Tensor out; \ auto in1 = random_tensor_for_type(k##name); \ auto in2 = random_tensor_for_type(k##name); \ auto expected = in1.add(in2); \ - auto iter = TensorIterator::binary_op(out, in1, in2); \ + auto iter = TensorIterator::binary_op(out, in1, in2); \ at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> int { return a + b; }); \ ASSERT_ANY_THROW(out.equal(expected)); \ } +#define NO_OUTPUT_BINARY_TEST_ITER_FOR_TYPE(ctype,name) \ +TEST(TensorIteratorTest, SerialLoopBinaryNoOutput_##name) { \ + auto in1 = random_tensor_for_type(k##name); \ + auto in2 = random_tensor_for_type(k##name); \ + auto iter = at::TensorIterator(); \ + iter.add_input(in1); \ + iter.add_input(in2); \ + iter.build(); \ + int64_t acc = 0; \ + at::native::cpu_serial_kernel(iter, [&](ctype a, ctype b) -> void { acc++; }); \ + 
EXPECT_TRUE(acc == in1.numel()); \ +} + #define POINTWISE_TEST_ITER_FOR_TYPE(ctype,name) \ TEST(TensorIteratorTest, SerialLoopPointwise_##name) { \ Tensor out; \ @@ -89,6 +113,21 @@ TEST(TensorIteratorTest, SerialLoopPointwise_##name) { ASSERT_ANY_THROW(out.equal(expected)); \ } +#define NO_OUTPUT_POINTWISE_TEST_ITER_FOR_TYPE(ctype,name) \ +TEST(TensorIteratorTest, SerialLoopPoinwiseNoOutput_##name) { \ + auto in1 = random_tensor_for_type(k##name); \ + auto in2 = random_tensor_for_type(k##name); \ + auto in3 = random_tensor_for_type(k##name); \ + auto iter = at::TensorIterator(); \ + iter.add_input(in1); \ + iter.add_input(in2); \ + iter.add_input(in3); \ + iter.build(); \ + int64_t acc = 0; \ + at::native::cpu_serial_kernel(iter, [&](ctype a, ctype b, ctype c) -> void { acc++; }); \ + EXPECT_TRUE(acc == in1.numel()); \ +} + // The alternative way to calculate a < b is (b - a).clamp(0).toBool() // To prevent an overflow in subtraction (b - a) for unsigned types(unit, bool) // we will convert in to int first @@ -112,6 +151,9 @@ TEST(TensorIteratorTest, ComparisonLoopBinary_##name) { AT_FORALL_SCALAR_TYPES(UNARY_TEST_ITER_FOR_TYPE) AT_FORALL_SCALAR_TYPES(BINARY_TEST_ITER_FOR_TYPE) AT_FORALL_SCALAR_TYPES(POINTWISE_TEST_ITER_FOR_TYPE) +AT_FORALL_SCALAR_TYPES(NO_OUTPUT_UNARY_TEST_ITER_FOR_TYPE) +AT_FORALL_SCALAR_TYPES(NO_OUTPUT_BINARY_TEST_ITER_FOR_TYPE) +AT_FORALL_SCALAR_TYPES(NO_OUTPUT_POINTWISE_TEST_ITER_FOR_TYPE) AT_FORALL_SCALAR_TYPES_AND(Bool, COMPARISON_TEST_ITER_FOR_TYPE) TEST(TensorIteratorTest, SerialLoopSingleThread) { @@ -172,3 +214,4 @@ TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) { iter.compute_common_dtype_only_for_inputs(); ASSERT_ANY_THROW(iter.build()); } + diff --git a/aten/src/ATen/test/test_parallel.cpp b/aten/src/ATen/test/test_parallel.cpp index 7c1072c672252..a6579ae5f0836 100644 --- a/aten/src/ATen/test/test_parallel.cpp +++ b/aten/src/ATen/test/test_parallel.cpp @@ -69,3 +69,12 @@ TEST(TestParallel, IntraOpLaunchFuture) { ASSERT_TRUE(v1 == 1 && v2 == 2); } + +TEST(TestParallel, MultipleSetNumThreadsCalls) { +#if !AT_PARALLEL_NATIVE + set_num_threads(5); + ASSERT_TRUE(get_num_threads() == 5); + set_num_threads(10); + ASSERT_TRUE(get_num_threads() == 10); +#endif +} diff --git a/aten/src/ATen/test/variant_test.cpp b/aten/src/ATen/test/variant_test.cpp deleted file mode 100644 index da741e15bb08a..0000000000000 --- a/aten/src/ATen/test/variant_test.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include - -#include - -namespace testns { - -namespace enumtype { - // NOTE: We need to provide the default constructor for each struct, - // otherwise Clang 3.8 would complain: - // ``` - // error: default initialization of an object of const type 'const enumtype::Enum1' - // without a user-provided default constructor - // ``` - struct Enum1 { Enum1() {} }; - struct Enum2 { Enum2() {} }; - struct Enum3 { Enum3() {} }; -} // namespace enumtype - -const enumtype::Enum1 kEnum1; -const enumtype::Enum2 kEnum2; -const enumtype::Enum3 kEnum3; - -} // namespace testns - -std::string func(c10::variant v) { - if (c10::get_if(&v)) { - return "Enum1"; - } else if (c10::get_if(&v)) { - return "Enum2"; - } else if (c10::get_if(&v)) { - return "Enum3"; - } else { - return "Unsupported enum"; - } -} - -TEST(VariantTest, Basic) { - ASSERT_EQ(func(testns::kEnum1), "Enum1"); - ASSERT_EQ(func(testns::kEnum2), "Enum2"); - ASSERT_EQ(func(testns::kEnum3), "Enum3"); -} diff --git a/aten/src/TH/THGeneral.cpp b/aten/src/TH/THGeneral.cpp index f921673549f88..060b1e6cebe55 100644 --- 
a/aten/src/TH/THGeneral.cpp +++ b/aten/src/TH/THGeneral.cpp @@ -188,11 +188,6 @@ void THFree(void *ptr) c10::free_cpu(ptr); } -double THLog10(const double x) -{ - return log10(x); -} - double THLog1p(const double x) { #if (defined(_MSC_VER) || defined(__MINGW32__)) diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 600ae40057a7b..97b88cf3cdd91 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -151,7 +151,6 @@ TH_API void THTensor_(abs)(THTensor *r_, THTensor *t); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) TH_API void THTensor_(sigmoid)(THTensor *r_, THTensor *t); -TH_API void THTensor_(log10)(THTensor *r_, THTensor *t); TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t); TH_API void THTensor_(log2)(THTensor *r_, THTensor *t); TH_API void THTensor_(exp)(THTensor *r_, THTensor *t); diff --git a/aten/src/TH/generic/THVector.h b/aten/src/TH/generic/THVector.h index c29d4503de167..d6fffd74d6bdc 100644 --- a/aten/src/TH/generic/THVector.h +++ b/aten/src/TH/generic/THVector.h @@ -31,7 +31,6 @@ TH_API void THVector_(abs)(scalar_t *y, const scalar_t *x, const ptrdiff_t n); /* floating point only now */ #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) -TH_API void THVector_(log10)(scalar_t *y, const scalar_t *x, const ptrdiff_t n); TH_API void THVector_(log1p)(scalar_t *y, const scalar_t *x, const ptrdiff_t n); TH_API void THVector_(log2)(scalar_t *y, const scalar_t *x, const ptrdiff_t n); TH_API void THVector_(sigmoid)(scalar_t *y, const scalar_t *x, const ptrdiff_t n); diff --git a/aten/src/TH/generic/THVectorDefault.cpp b/aten/src/TH/generic/THVectorDefault.cpp index 6814cdc362c6b..e847cda996af4 100644 --- a/aten/src/TH/generic/THVectorDefault.cpp +++ b/aten/src/TH/generic/THVectorDefault.cpp @@ -243,7 +243,6 @@ VECTOR_IMPLEMENT_FUNCTION(abs,) #define TH_MATH_NAME(fn) fn #endif -VECTOR_IMPLEMENT_FUNCTION(log10,TH_MATH_NAME(log10)) VECTOR_IMPLEMENT_FUNCTION(log1p,TH_MATH_NAME(log1p)) VECTOR_IMPLEMENT_FUNCTION(log2,TH_MATH_NAME(log2)) VECTOR_IMPLEMENT_FUNCTION(sigmoid_DEFAULT,TH_MATH_NAME(TH_sigmoid)) diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in index 6a086db395086..fac4e04b3f32d 100644 --- a/aten/src/THC/THCGeneral.h.in +++ b/aten/src/THC/THCGeneral.h.in @@ -3,7 +3,6 @@ #include #include -#undef log10 #undef log1p #undef log2 diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh index 7e6e932dfa503..b80335958ca4b 100644 --- a/aten/src/THC/THCNumerics.cuh +++ b/aten/src/THC/THCNumerics.cuh @@ -204,7 +204,6 @@ struct THCNumerics { static inline __host__ __device__ at::Half exp(at::Half a) { return std::exp(a); } static inline __host__ __device__ at::Half exp10(at::Half a) { return ::exp10(a); } - static inline __host__ __device__ at::Half log10(at::Half a) { return ::log10(a); } static inline __host__ __device__ at::Half log1p(at::Half a) { return ::log1p(a); } static inline __host__ __device__ at::Half log2(at::Half a) { return ::log2(a); } static inline __host__ __device__ at::Half cos(at::Half a) { return ::cos(a); } @@ -279,7 +278,6 @@ struct THCNumerics { static inline __host__ __device__ float exp (float a) { return expf(a); } static inline __host__ __device__ float exp10(float a) { return exp10f(a); } - static inline __host__ __device__ float log10(float a) { return log10f(a); } static inline __host__ __device__ float log1p(float a) { return log1pf(a); } static inline __host__ __device__ float log2 (float a) { return log2f(a); } 
static inline __host__ __device__ float cos (float a) { return cosf(a); } @@ -329,7 +327,6 @@ struct THCNumerics { static inline __host__ __device__ double exp (double a) { return ::exp(a); } static inline __host__ __device__ double exp10(double a) { return ::exp10(a); } - static inline __host__ __device__ double log10(double a) { return ::log10(a); } static inline __host__ __device__ double log1p(double a) { return ::log1p(a); } static inline __host__ __device__ double log2 (double a) { return ::log2(a); } static inline __host__ __device__ double cos (double a) { return ::cos(a); } diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu index b3e6a16c1c3f4..8f84d4ee1bf39 100644 --- a/aten/src/THC/generic/THCTensorMathPointwise.cu +++ b/aten/src/THC/generic/THCTensorMathPointwise.cu @@ -198,7 +198,6 @@ static void propagate_names_if_named_tensor_enabled(THCTensor* result, THCTensor #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log10, THCNumerics::log10, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log1p, THCNumerics::log1p, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( log2, THCNumerics::log2, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( exp, THCNumerics::exp, Real) diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h index 94e0479c1aa02..3f688ef18b794 100644 --- a/aten/src/THC/generic/THCTensorMathPointwise.h +++ b/aten/src/THC/generic/THCTensorMathPointwise.h @@ -16,7 +16,6 @@ THC_API void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor * #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THC_API void THCTensor_(sigmoid)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(log10)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log1p)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log2)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(exp)(THCState *state, THCTensor *self, THCTensor *src); diff --git a/aten/src/THCUNN/THCHalfAutoNumerics.cuh b/aten/src/THCUNN/THCHalfAutoNumerics.cuh index 7b55f18b77c52..9d971315e49dc 100644 --- a/aten/src/THCUNN/THCHalfAutoNumerics.cuh +++ b/aten/src/THCUNN/THCHalfAutoNumerics.cuh @@ -39,10 +39,6 @@ inline __host__ __device__ THHalf exp(THHalf a) { return THCNumerics::exp(a); } -inline __host__ __device__ THHalf log10(THHalf a) { - return THCNumerics::log10(a); -} - inline __host__ __device__ THHalf log1p(THHalf a) { return THCNumerics::log1p(a); } diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index dd829a8563766..ac90791ccc452 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -130,6 +130,35 @@ bool TensorImpl::compute_strides_like_channels_last() const { return false; } +bool TensorImpl::compute_non_overlapping_and_dense() const { + if (dim() == 1) { + return size(0) < 2 || stride(0) == 1; + } + SmallVector perm; + perm.resize(dim()); + for (int64_t i = 0; i < dim(); i ++) { + perm[i] = i; + } + // Sort by strides, leaving 0 and 1 sized dims at the end of the array + std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) { + if (sizes_[a] < 2) { + return false; + } + return strides_[a] < strides_[b]; + }); + auto require_stride = 1; + for (int64_t i = 0; i < dim(); i ++) { + if (sizes_[perm[i]] < 2) { + return true; + } + if (strides_[perm[i]] != require_stride) { + return false; + } + 
require_stride *= sizes_[perm[i]]; + } + return true; +} + void TensorImpl::release_resources() { autograd_meta_.reset(); if (storage_) { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 366d2b7cbe6a0..4e8dfe8bfc7bd 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1390,6 +1390,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { is_contiguous_ = false; is_channels_last_contiguous_ = false; is_channels_last_ = false; + is_non_overlapping_and_dense_ = false; switch (memory_format) { case MemoryFormat::Contiguous: { strides_.resize(sizes_.size(), 0); @@ -1401,6 +1402,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } is_contiguous_ = true; + is_non_overlapping_and_dense_ = true; return; } case MemoryFormat::ChannelsLast: { @@ -1410,6 +1412,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { set_sizes_and_strides(sizes(), get_channels_last_strides(sizes())); is_channels_last_contiguous_ = true; is_channels_last_ = true; + is_non_overlapping_and_dense_ = true; return; } case MemoryFormat::Preserve: @@ -1421,6 +1424,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_channels_last_; } + bool is_non_overlapping_and_dense() const { + return is_non_overlapping_and_dense_; + } + private: // The Caffe2 Resize() method supports being called both as Resize({2,2}) as @@ -1501,6 +1508,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool compute_strides_like_channels_last() const; + bool compute_non_overlapping_and_dense() const; + protected: /** * Recompute the cached numel of a tensor. Call this if you modify sizes. @@ -1517,6 +1526,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { is_contiguous_ = compute_contiguous(); is_channels_last_contiguous_ = compute_channels_last_contiguous(); is_channels_last_ = is_channels_last_contiguous_ || compute_strides_like_channels_last(); + is_non_overlapping_and_dense_ = is_contiguous_ || is_channels_last_contiguous_ || compute_non_overlapping_and_dense(); } /** @@ -1546,6 +1556,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId); } dest_impl->is_contiguous_ = src_impl->is_contiguous_; + dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_; + dest_impl->is_channels_last_ = src_impl->is_channels_last_; + dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_; dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_; dest_impl->reserved_ = src_impl->reserved_; dest_impl->set_version_counter(version_counter); @@ -1641,6 +1654,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // contiguous memory block. bool is_channels_last_contiguous_ = false; + // A dense tensor stores its values in a single contiguous block of memory. + // A non-overlapping tensor is one in which no two elements share the same + // memory location. + bool is_non_overlapping_and_dense_ = false; + bool is_wrapped_number_ = false; // NOTE [ Metadata Change for a Detached Tensor ] diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 3514fac0eeae1..e992a4e1bfdae 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -17,7 +17,7 @@ configure_file( # transitively passed on to all libraries dependent on PyTorch.
# Note: if you add a new source file/header, you will need to update -# tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py for new files +# torch/utils/hipify/cuda_to_hip_mappings.py for new files # and headers you add set(C10_CUDA_SRCS CUDAStream.cpp diff --git a/c10/cuda/README.md b/c10/cuda/README.md index 1cafbc78fad8d..c65ba8e3b155e 100644 --- a/c10/cuda/README.md +++ b/c10/cuda/README.md @@ -30,7 +30,7 @@ void my_func(); ``` Thus, if you add new functionality to c10, you must also update `C10_MAPPINGS` -`tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py` to transpile +`torch/utils/hipify/cuda_to_hip_mappings.py` to transpile occurrences of `cuda::my_func` to `hip::my_func`. (At the moment, we do NOT have a catch all `cuda::` to `hip::` namespace conversion, as not all `cuda` namespaces are converted to `hip::`, even though diff --git a/c10/util/flat_hash_map.h b/c10/util/flat_hash_map.h index 500336b5ed87d..c513a61ab4899 100644 --- a/c10/util/flat_hash_map.h +++ b/c10/util/flat_hash_map.h @@ -21,6 +21,7 @@ #include #include #include +#include #ifndef _MSC_VER #pragma GCC diagnostic push diff --git a/c10/util/variant.h b/c10/util/variant.h deleted file mode 100644 index b183641fe3415..0000000000000 --- a/c10/util/variant.h +++ /dev/null @@ -1,2822 +0,0 @@ -// MPark.Variant -// -// Copyright Michael Park, 2015-2017 -// -// Distributed under the Boost Software License, Version 1.0. -// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) -// -// From https://github.com/mpark/variant -// -// C10 -// - Move to `c10` namespace. -// - Rename namespace `detail` to `detail_`, to not conflict with existing -// c10 implementations in `detail` namespace. -// - `struct in_place_t` is renamed to `struct variant_in_place_t`, to not -// conflict with `struct in_place_t` in c10/util/Optional.h. 
- -#ifndef C10_UTIL_VARIANT_H_ -#define C10_UTIL_VARIANT_H_ - -/* - variant synopsis - -namespace std { - - // 20.7.2, class template variant - template - class variant { - public: - - // 20.7.2.1, constructors - constexpr variant() noexcept(see below); - variant(const variant&); - variant(variant&&) noexcept(see below); - - template constexpr variant(T&&) noexcept(see below); - - template - constexpr explicit variant(in_place_type_t, Args&&...); - - template - constexpr explicit variant( - in_place_type_t, initializer_list, Args&&...); - - template - constexpr explicit variant(in_place_index_t, Args&&...); - - template - constexpr explicit variant( - in_place_index_t, initializer_list, Args&&...); - - // 20.7.2.2, destructor - ~variant(); - - // 20.7.2.3, assignment - variant& operator=(const variant&); - variant& operator=(variant&&) noexcept(see below); - - template variant& operator=(T&&) noexcept(see below); - - // 20.7.2.4, modifiers - template - T& emplace(Args&&...); - - template - T& emplace(initializer_list, Args&&...); - - template - variant_alternative& emplace(Args&&...); - - template - variant_alternative& emplace(initializer_list, Args&&...); - - // 20.7.2.5, value status - constexpr bool valueless_by_exception() const noexcept; - constexpr size_t index() const noexcept; - - // 20.7.2.6, swap - void swap(variant&) noexcept(see below); - }; - - // 20.7.3, variant helper classes - template struct variant_size; // undefined - - template - constexpr size_t variant_size_v = variant_size::value; - - template struct variant_size; - template struct variant_size; - template struct variant_size; - - template - struct variant_size>; - - template struct variant_alternative; // undefined - - template - using variant_alternative_t = typename variant_alternative::type; - - template struct variant_alternative; - template struct variant_alternative; - template struct variant_alternative; - - template - struct variant_alternative>; - - constexpr size_t variant_npos = -1; - - // 20.7.4, value access - template - constexpr bool holds_alternative(const variant&) noexcept; - - template - constexpr variant_alternative_t>& - get(variant&); - - template - constexpr variant_alternative_t>&& - get(variant&&); - - template - constexpr variant_alternative_t> const& - get(const variant&); - - template - constexpr variant_alternative_t> const&& - get(const variant&&); - - template - constexpr T& get(variant&); - - template - constexpr T&& get(variant&&); - - template - constexpr const T& get(const variant&); - - template - constexpr const T&& get(const variant&&); - - template - constexpr add_pointer_t>> - get_if(variant*) noexcept; - - template - constexpr add_pointer_t>> - get_if(const variant*) noexcept; - - template - constexpr add_pointer_t - get_if(variant*) noexcept; - - template - constexpr add_pointer_t - get_if(const variant*) noexcept; - - // 20.7.5, relational operators - template - constexpr bool operator==(const variant&, const variant&); - - template - constexpr bool operator!=(const variant&, const variant&); - - template - constexpr bool operator<(const variant&, const variant&); - - template - constexpr bool operator>(const variant&, const variant&); - - template - constexpr bool operator<=(const variant&, const variant&); - - template - constexpr bool operator>=(const variant&, const variant&); - - // 20.7.6, visitation - template - constexpr see below visit(Visitor&&, Variants&&...); - - // 20.7.7, class monostate - struct monostate; - - // 20.7.8, monostate relational operators - 
constexpr bool operator<(monostate, monostate) noexcept; - constexpr bool operator>(monostate, monostate) noexcept; - constexpr bool operator<=(monostate, monostate) noexcept; - constexpr bool operator>=(monostate, monostate) noexcept; - constexpr bool operator==(monostate, monostate) noexcept; - constexpr bool operator!=(monostate, monostate) noexcept; - - // 20.7.9, specialized algorithms - template - void swap(variant&, variant&) noexcept(see below); - - // 20.7.10, class bad_variant_access - class bad_variant_access; - - // 20.7.11, hash support - template struct hash; - template struct hash>; - template <> struct hash; - -} // namespace std - -*/ - -#include -#include -#include -#include -#include -#include -#include - -// MPark.Variant -// -// Copyright Michael Park, 2015-2017 -// -// Distributed under the Boost Software License, Version 1.0. -// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) - -#ifndef MPARK_CONFIG_HPP -#define MPARK_CONFIG_HPP - -// MSVC 2015 Update 3. -#if __cplusplus < 201103L && (!defined(_MSC_VER) || _MSC_FULL_VER < 190024210) -#error "MPark.Variant requires C++11 support." -#endif - -#ifndef __has_attribute -#define __has_attribute(x) 0 -#endif - -#ifndef __has_builtin -#define __has_builtin(x) 0 -#endif - -#ifndef __has_include -#define __has_include(x) 0 -#endif - -#ifndef __has_feature -#define __has_feature(x) 0 -#endif - -#if __has_attribute(always_inline) || defined(__GNUC__) -#define MPARK_ALWAYS_INLINE __attribute__((__always_inline__)) inline -#elif defined(_MSC_VER) -#define MPARK_ALWAYS_INLINE __forceinline -#else -#define MPARK_ALWAYS_INLINE inline -#endif - -#if __has_builtin(__builtin_addressof) || \ - (defined(__GNUC__) && __GNUC__ >= 7) || defined(_MSC_VER) -#define MPARK_BUILTIN_ADDRESSOF -#endif - -#if __has_builtin(__builtin_unreachable) || defined(__GNUC__) -#define MPARK_BUILTIN_UNREACHABLE __builtin_unreachable() -#elif defined(_MSC_VER) -#define MPARK_BUILTIN_UNREACHABLE __assume(false) -#else -#define MPARK_BUILTIN_UNREACHABLE -#endif - -#if __has_builtin(__type_pack_element) -#define MPARK_TYPE_PACK_ELEMENT -#endif - -#if defined(__cpp_constexpr) && __cpp_constexpr >= 200704 && \ - !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 9) -#define MPARK_CPP11_CONSTEXPR -#endif - -#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304 -#define MPARK_CPP14_CONSTEXPR -#endif - -#if __has_feature(cxx_exceptions) || defined(__cpp_exceptions) || \ - (defined(_MSC_VER) && defined(_CPPUNWIND)) -#define MPARK_EXCEPTIONS -#endif - -#if defined(__cpp_generic_lambdas) || defined(_MSC_VER) -#define MPARK_GENERIC_LAMBDAS -#endif - -#if defined(__cpp_lib_integer_sequence) -#define MPARK_INTEGER_SEQUENCE -#endif - -#if defined(__cpp_return_type_deduction) || defined(_MSC_VER) -#define MPARK_RETURN_TYPE_DEDUCTION -#endif - -#if defined(__cpp_lib_transparent_operators) || defined(_MSC_VER) -#define MPARK_TRANSPARENT_OPERATORS -#endif - -#if defined(__cpp_variable_templates) || defined(_MSC_VER) -#define MPARK_VARIABLE_TEMPLATES -#endif - -#if !defined(__GLIBCXX__) || __has_include() // >= libstdc++-5 -#define MPARK_TRIVIALITY_TYPE_TRAITS -#define MPARK_INCOMPLETE_TYPE_TRAITS -#endif - -#endif // MPARK_CONFIG_HPP - -// MPark.Variant -// -// Copyright Michael Park, 2015-2017 -// -// Distributed under the Boost Software License, Version 1.0. 
-// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) - -#ifndef MPARK_IN_PLACE_HPP -#define MPARK_IN_PLACE_HPP - -#include - - -namespace c10 { - - struct variant_in_place_t { explicit variant_in_place_t() = default; }; - - template - struct in_place_index_t { explicit in_place_index_t() = default; }; - - template - struct in_place_type_t { explicit in_place_type_t() = default; }; - -#ifdef MPARK_VARIABLE_TEMPLATES - constexpr variant_in_place_t in_place{}; - - template constexpr in_place_index_t in_place_index{}; - - template constexpr in_place_type_t in_place_type{}; -#endif - -} // namespace c10 - -#endif // MPARK_IN_PLACE_HPP - -// MPark.Variant -// -// Copyright Michael Park, 2015-2017 -// -// Distributed under the Boost Software License, Version 1.0. -// (See accompanying file LICENSE.md or copy at http://boost.org/LICENSE_1_0.txt) - -#ifndef MPARK_LIB_HPP -#define MPARK_LIB_HPP - -#include -#include -#include -#include - - -#define MPARK_RETURN(...) \ - noexcept(noexcept(__VA_ARGS__)) -> decltype(__VA_ARGS__) { return __VA_ARGS__; } - -namespace c10 { - namespace lib { - template - struct identity { using type = T; }; - - inline namespace cpp14 { - template - struct array { - constexpr const T &operator[](std::size_t index) const { - return data[index]; - } - - T data[N == 0 ? 1 : N]; - }; - - template - using add_pointer_t = typename std::add_pointer::type; - - template - using common_type_t = typename std::common_type::type; - - template - using decay_t = typename std::decay::type; - - template - using enable_if_t = typename std::enable_if::type; - - template - using remove_const_t = typename std::remove_const::type; - - template - using remove_reference_t = typename std::remove_reference::type; - - template - inline constexpr T &&forward(remove_reference_t &t) noexcept { - return static_cast(t); - } - - template - inline constexpr T &&forward(remove_reference_t &&t) noexcept { - static_assert(!std::is_lvalue_reference::value, - "can not forward an rvalue as an lvalue"); - return static_cast(t); - } - - template - inline constexpr remove_reference_t &&move(T &&t) noexcept { - return static_cast &&>(t); - } - -#ifdef MPARK_INTEGER_SEQUENCE - using std::integer_sequence; - using std::index_sequence; - using std::make_index_sequence; - using std::index_sequence_for; -#else - template - struct integer_sequence { - using value_type = T; - static constexpr std::size_t size() noexcept { return sizeof...(Is); } - }; - - template - using index_sequence = integer_sequence; - - template - struct make_index_sequence_concat; - - template - struct make_index_sequence_concat, - index_sequence> - : identity> {}; - - template - struct make_index_sequence_impl; - - template - using make_index_sequence = typename make_index_sequence_impl::type; - - template - struct make_index_sequence_impl - : make_index_sequence_concat, - make_index_sequence> {}; - - template <> - struct make_index_sequence_impl<0> : identity> {}; - - template <> - struct make_index_sequence_impl<1> : identity> {}; - - template - using index_sequence_for = make_index_sequence; -#endif - - // -#ifdef MPARK_TRANSPARENT_OPERATORS - using equal_to = std::equal_to<>; -#else - struct equal_to { - template - inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const - MPARK_RETURN(lib::forward(lhs) == lib::forward(rhs)) - }; -#endif - -#ifdef MPARK_TRANSPARENT_OPERATORS - using not_equal_to = std::not_equal_to<>; -#else - struct not_equal_to { - template - inline constexpr auto operator()(Lhs &&lhs, Rhs 
&&rhs) const - MPARK_RETURN(lib::forward(lhs) != lib::forward(rhs)) - }; -#endif - -#ifdef MPARK_TRANSPARENT_OPERATORS - using less = std::less<>; -#else - struct less { - template - inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const - MPARK_RETURN(lib::forward(lhs) < lib::forward(rhs)) - }; -#endif - -#ifdef MPARK_TRANSPARENT_OPERATORS - using greater = std::greater<>; -#else - struct greater { - template - inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const - MPARK_RETURN(lib::forward(lhs) > lib::forward(rhs)) - }; -#endif - -#ifdef MPARK_TRANSPARENT_OPERATORS - using less_equal = std::less_equal<>; -#else - struct less_equal { - template - inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const - MPARK_RETURN(lib::forward(lhs) <= lib::forward(rhs)) - }; -#endif - -#ifdef MPARK_TRANSPARENT_OPERATORS - using greater_equal = std::greater_equal<>; -#else - struct greater_equal { - template - inline constexpr auto operator()(Lhs &&lhs, Rhs &&rhs) const - MPARK_RETURN(lib::forward(lhs) >= lib::forward(rhs)) - }; -#endif - } // namespace cpp14 - - inline namespace cpp17 { - - // - template - using bool_constant = std::integral_constant; - - template - struct voider : identity {}; - - template - using void_t = typename voider::type; - - namespace detail_ { - namespace swappable { - - using std::swap; - - template - struct is_swappable { - private: - template (), - std::declval()))> - inline static std::true_type test(int); - - template - inline static std::false_type test(...); - - public: - static constexpr bool value = decltype(test(0))::value; - }; - - template - struct is_nothrow_swappable { - static constexpr bool value = - noexcept(swap(std::declval(), std::declval())); - }; - - template - struct is_nothrow_swappable : std::false_type {}; - - } // namespace swappable - } // namespace detail_ - - using detail_::swappable::is_swappable; - - template - using is_nothrow_swappable = - detail_::swappable::is_nothrow_swappable::value, T>; - - // - namespace detail_ { - - template - struct is_reference_wrapper : std::false_type {}; - - template - struct is_reference_wrapper> - : std::true_type {}; - - template - struct Invoke; - - template <> - struct Invoke { - template - inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) - MPARK_RETURN((lib::forward(arg).*pmf)(lib::forward(args)...)) - }; - - template <> - struct Invoke { - template - inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) - MPARK_RETURN((lib::forward(arg).get().*pmf)(lib::forward(args)...)) - }; - - template <> - struct Invoke { - template - inline static constexpr auto invoke(R T::*pmf, Arg &&arg, Args &&... args) - MPARK_RETURN(((*lib::forward(arg)).*pmf)(lib::forward(args)...)) - }; - - template <> - struct Invoke { - template - inline static constexpr auto invoke(R T::*pmo, Arg &&arg) - MPARK_RETURN(lib::forward(arg).*pmo) - }; - - template <> - struct Invoke { - template - inline static constexpr auto invoke(R T::*pmo, Arg &&arg) - MPARK_RETURN(lib::forward(arg).get().*pmo) - }; - - template <> - struct Invoke { - template - inline static constexpr auto invoke(R T::*pmo, Arg &&arg) - MPARK_RETURN((*lib::forward(arg)).*pmo) - }; - - template - inline constexpr auto invoke(R T::*f, Arg &&arg, Args &&... args) - MPARK_RETURN( - Invoke::value, - (std::is_base_of>::value - ? 0 - : is_reference_wrapper>::value - ? 
1 - : 2)>::invoke(f, - lib::forward(arg), - lib::forward(args)...)) - -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4100) -#endif - template - inline constexpr auto invoke(F &&f, Args &&... args) - MPARK_RETURN(lib::forward(f)(lib::forward(args)...)) -#ifdef _MSC_VER -#pragma warning(pop) -#endif - } // namespace detail_ - - template - inline constexpr auto invoke(F &&f, Args &&... args) - MPARK_RETURN(detail_::invoke(lib::forward(f), - lib::forward(args)...)) - - namespace detail_ { - - template - struct invoke_result {}; - - template - struct invoke_result(), std::declval()...))>, - F, - Args...> - : identity(), std::declval()...))> {}; - - } // namespace detail_ - - template - using invoke_result = detail_::invoke_result; - - template - using invoke_result_t = typename invoke_result::type; - - namespace detail_ { - - template - struct is_invocable : std::false_type {}; - - template - struct is_invocable>, F, Args...> - : std::true_type {}; - - template - struct is_invocable_r : std::false_type {}; - - template - struct is_invocable_r>, - R, - F, - Args...> - : std::is_convertible, R> {}; - - } // namespace detail_ - - template - using is_invocable = detail_::is_invocable; - - template - using is_invocable_r = detail_::is_invocable_r; - - namespace detail_ { - - template - struct is_nothrow_invocable { - static constexpr bool value = - noexcept(lib::invoke(std::declval(), std::declval()...)); - }; - - template - struct is_nothrow_invocable : std::false_type {}; - - template - struct is_nothrow_invocable_r { - private: - inline static R impl() { - return lib::invoke(std::declval(), std::declval()...); - } - - public: - static constexpr bool value = noexcept(impl()); - }; - - template - struct is_nothrow_invocable_r : std::false_type {}; - - } // namespace detail_ - - template - using is_nothrow_invocable = detail_:: - is_nothrow_invocable::value, F, Args...>; - - template - using is_nothrow_invocable_r = - detail_::is_nothrow_invocable_r::value, - R, - F, - Args...>; - - // -#ifdef MPARK_BUILTIN_ADDRESSOF - template - inline constexpr T *addressof(T &arg) noexcept { - return __builtin_addressof(arg); - } -#else - namespace detail_ { - - namespace has_addressof_impl { - - struct fail; - - template - inline fail operator&(T &&); - - template - inline static constexpr bool impl() { - return (std::is_class::value || std::is_union::value) && - !std::is_same()), fail>::value; - } - - } // namespace has_addressof_impl - - template - using has_addressof = bool_constant()>; - - template - inline constexpr T *addressof(T &arg, std::true_type) noexcept { - return std::addressof(arg); - } - - template - inline constexpr T *addressof(T &arg, std::false_type) noexcept { - return &arg; - } - - } // namespace detail_ - - template - inline constexpr T *addressof(T &arg) noexcept { - return detail_::addressof(arg, detail_::has_addressof{}); - } -#endif - - template - inline constexpr T *addressof(const T &&) = delete; - - } // namespace cpp17 - - template - struct remove_all_extents : identity {}; - - template - struct remove_all_extents> : remove_all_extents {}; - - template - using remove_all_extents_t = typename remove_all_extents::type; - - template - using size_constant = std::integral_constant; - - template - struct indexed_type : size_constant { using type = T; }; - - template - using all = std::is_same, - integer_sequence>; - -#ifdef MPARK_TYPE_PACK_ELEMENT - template - using type_pack_element_t = __type_pack_element; -#else - template - struct type_pack_element_impl { - 
private: - template - struct set; - - template - struct set> : indexed_type... {}; - - template - inline static std::enable_if impl(indexed_type); - - inline static std::enable_if impl(...); - - public: - using type = decltype(impl(set>{})); - }; - - template - using type_pack_element = typename type_pack_element_impl::type; - - template - using type_pack_element_t = typename type_pack_element::type; -#endif - -#ifdef MPARK_TRIVIALITY_TYPE_TRAITS - using std::is_trivially_copy_constructible; - using std::is_trivially_move_constructible; - using std::is_trivially_copy_assignable; - using std::is_trivially_move_assignable; -#else - template - struct is_trivially_copy_constructible - : bool_constant< - std::is_copy_constructible::value && __has_trivial_copy(T)> {}; - - template - struct is_trivially_move_constructible : bool_constant<__is_trivial(T)> {}; - - template - struct is_trivially_copy_assignable - : bool_constant< - std::is_copy_assignable::value && __has_trivial_assign(T)> {}; - - template - struct is_trivially_move_assignable : bool_constant<__is_trivial(T)> {}; -#endif - - template - struct dependent_type : T {}; - - template - struct push_back; - - template - using push_back_t = typename push_back::type; - - template - struct push_back, J> { - using type = index_sequence; - }; - - } // namespace lib -} // namespace c10 - -#undef MPARK_RETURN - -#endif // MPARK_LIB_HPP - - -namespace c10 { - -#ifdef MPARK_RETURN_TYPE_DEDUCTION - -#define AUTO auto -#define AUTO_RETURN(...) { return __VA_ARGS__; } - -#define AUTO_REFREF auto && -#define AUTO_REFREF_RETURN(...) { return __VA_ARGS__; } - -#define DECLTYPE_AUTO decltype(auto) -#define DECLTYPE_AUTO_RETURN(...) { return __VA_ARGS__; } - -#else - -#define AUTO auto -#define AUTO_RETURN(...) \ - -> lib::decay_t { return __VA_ARGS__; } - -#define AUTO_REFREF auto -#define AUTO_REFREF_RETURN(...) \ - -> decltype((__VA_ARGS__)) { \ - static_assert(std::is_reference::value, ""); \ - return __VA_ARGS__; \ - } - -#define DECLTYPE_AUTO auto -#define DECLTYPE_AUTO_RETURN(...) 
\ - -> decltype(__VA_ARGS__) { return __VA_ARGS__; } - -#endif - - class bad_variant_access : public std::exception { - public: - virtual const char *what() const noexcept override { return "bad_variant_access"; } - }; - - [[noreturn]] inline void throw_bad_variant_access() { -#ifdef MPARK_EXCEPTIONS - throw bad_variant_access{}; -#else - std::terminate(); - MPARK_BUILTIN_UNREACHABLE; -#endif - } - - template - class variant; - - template - struct variant_size; - -#ifdef MPARK_VARIABLE_TEMPLATES - template - constexpr std::size_t variant_size_v = variant_size::value; -#endif - - template - struct variant_size : variant_size {}; - - template - struct variant_size : variant_size {}; - - template - struct variant_size : variant_size {}; - - template - struct variant_size> : lib::size_constant {}; - - template - struct variant_alternative; - - template - using variant_alternative_t = typename variant_alternative::type; - - template - struct variant_alternative - : std::add_const> {}; - - template - struct variant_alternative - : std::add_volatile> {}; - - template - struct variant_alternative - : std::add_cv> {}; - - template - struct variant_alternative> { - static_assert(I < sizeof...(Ts), - "index out of bounds in `std::variant_alternative<>`"); - using type = lib::type_pack_element_t; - }; - - constexpr std::size_t variant_npos = static_cast(-1); - - namespace detail_ { - - constexpr std::size_t not_found = static_cast(-1); - constexpr std::size_t ambiguous = static_cast(-2); - -#ifdef MPARK_CPP14_CONSTEXPR - template - inline constexpr std::size_t find_index() { - constexpr lib::array matches = { - {std::is_same::value...} - }; - std::size_t result = not_found; - for (std::size_t i = 0; i < sizeof...(Ts); ++i) { - if (matches[i]) { - if (result != not_found) { - return ambiguous; - } - result = i; - } - } - return result; - } -#else - inline constexpr std::size_t find_index_impl(std::size_t result, - std::size_t) { - return result; - } - - template - inline constexpr std::size_t find_index_impl(std::size_t result, - std::size_t idx, - bool b, - Bs... bs) { - return b ? (result != not_found ? ambiguous - : find_index_impl(idx, idx + 1, bs...)) - : find_index_impl(result, idx + 1, bs...); - } - - template - inline constexpr std::size_t find_index() { - return find_index_impl(not_found, 0, std::is_same::value...); - } -#endif - - template - using find_index_sfinae_impl = - lib::enable_if_t>; - - template - using find_index_sfinae = find_index_sfinae_impl()>; - - template - struct find_index_checked_impl : lib::size_constant { - static_assert(I != not_found, "the specified type is not found."); - static_assert(I != ambiguous, "the specified type is ambiguous."); - }; - - template - using find_index_checked = find_index_checked_impl()>; - - struct valueless_t {}; - - enum class Trait { TriviallyAvailable, Available, Unavailable }; - - template class IsTriviallyAvailable, - template class IsAvailable> - inline constexpr Trait trait() { - return IsTriviallyAvailable::value - ? Trait::TriviallyAvailable - : IsAvailable::value ? Trait::Available - : Trait::Unavailable; - } - -#ifdef MPARK_CPP14_CONSTEXPR - template - inline constexpr Trait common_trait(Traits... 
traits_) { - Trait result = Trait::TriviallyAvailable; - lib::array traits = {{traits_...}}; - for (std::size_t i = 0; i < sizeof...(Traits); ++i) { - Trait t = traits[i]; - if (static_cast(t) > static_cast(result)) { - result = t; - } - } - return result; - } -#else - inline constexpr Trait common_trait_impl(Trait result) { return result; } - - template - inline constexpr Trait common_trait_impl(Trait result, - Trait t, - Traits... ts) { - return static_cast(t) > static_cast(result) - ? common_trait_impl(t, ts...) - : common_trait_impl(result, ts...); - } - - template - inline constexpr Trait common_trait(Traits... ts) { - return common_trait_impl(Trait::TriviallyAvailable, ts...); - } -#endif - - template - struct traits { - static constexpr Trait copy_constructible_trait = - common_trait(trait()...); - - static constexpr Trait move_constructible_trait = - common_trait(trait()...); - - static constexpr Trait copy_assignable_trait = - common_trait(copy_constructible_trait, - trait()...); - - static constexpr Trait move_assignable_trait = - common_trait(move_constructible_trait, - trait()...); - - static constexpr Trait destructible_trait = - common_trait(trait()...); - }; - - namespace access { - - struct recursive_union { -#ifdef MPARK_RETURN_TYPE_DEDUCTION - template - inline static constexpr auto &&get_alt(V &&v, in_place_index_t<0>) { - return lib::forward(v).head_; - } - - template - inline static constexpr auto &&get_alt(V &&v, in_place_index_t) { - return get_alt(lib::forward(v).tail_, in_place_index_t{}); - } -#else - template - struct get_alt_impl { - template - inline constexpr AUTO_REFREF operator()(V &&v) const - AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v).tail_)) - }; - - template - struct get_alt_impl<0, Dummy> { - template - inline constexpr AUTO_REFREF operator()(V &&v) const - AUTO_REFREF_RETURN(lib::forward(v).head_) - }; - - template - inline static constexpr AUTO_REFREF get_alt(V &&v, in_place_index_t) - AUTO_REFREF_RETURN(get_alt_impl{}(lib::forward(v))) -#endif - }; - - struct base { - template - inline static constexpr AUTO_REFREF get_alt(V &&v) -#ifdef _MSC_VER - AUTO_REFREF_RETURN(recursive_union::get_alt( - lib::forward(v).data_, in_place_index_t{})) -#else - AUTO_REFREF_RETURN(recursive_union::get_alt( - data(lib::forward(v)), in_place_index_t{})) -#endif - }; - - struct variant { - template - inline static constexpr AUTO_REFREF get_alt(V &&v) - AUTO_REFREF_RETURN(base::get_alt(lib::forward(v).impl_)) - }; - - } // namespace access - - namespace visitation { - -#if defined(MPARK_CPP14_CONSTEXPR) && !defined(_MSC_VER) -#define MPARK_VARIANT_SWITCH_VISIT -#endif - - struct base { - template - using dispatch_result_t = decltype( - lib::invoke(std::declval(), - access::base::get_alt<0>(std::declval())...)); - - template - struct expected { - template - inline static constexpr bool but_got() { - return std::is_same::value; - } - }; - - template - struct visit_return_type_check { - static_assert( - expected::template but_got(), - "`visit` requires the visitor to have a single return type"); - - template - inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, - Alts &&... alts) - DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), - lib::forward(alts)...)) - }; - -#ifdef MPARK_VARIANT_SWITCH_VISIT - template - struct dispatcher; - - template - struct dispatcher { - template - MPARK_ALWAYS_INLINE static constexpr R dispatch( - F &&, typename ITs::type &&..., Vs &&...) 
{ - MPARK_BUILTIN_UNREACHABLE; - } - - template - MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&, Vs &&...) { - MPARK_BUILTIN_UNREACHABLE; - } - - template - MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t, - F &&, - Vs &&...) { - MPARK_BUILTIN_UNREACHABLE; - } - }; - - template - struct dispatcher { - template - MPARK_ALWAYS_INLINE static constexpr R dispatch( - F &&f, typename ITs::type &&... visited_vs) { - using Expected = R; - using Actual = decltype(lib::invoke( - lib::forward(f), - access::base::get_alt( - lib::forward(visited_vs))...)); - return visit_return_type_check::invoke( - lib::forward(f), - access::base::get_alt( - lib::forward(visited_vs))...); - } - - template - MPARK_ALWAYS_INLINE static constexpr R dispatch( - F &&f, typename ITs::type &&... visited_vs, V &&v, Vs &&... vs) { -#define MPARK_DISPATCH(I) \ - dispatcher<(I < lib::decay_t::size()), \ - R, \ - ITs..., \ - lib::indexed_type>:: \ - template dispatch<0>(lib::forward(f), \ - lib::forward(visited_vs)..., \ - lib::forward(v), \ - lib::forward(vs)...) - -#define MPARK_DEFAULT(I) \ - dispatcher<(I < lib::decay_t::size()), R, ITs...>::template dispatch( \ - lib::forward(f), \ - lib::forward(visited_vs)..., \ - lib::forward(v), \ - lib::forward(vs)...) - - switch (v.index()) { - case B + 0: return MPARK_DISPATCH(B + 0); - case B + 1: return MPARK_DISPATCH(B + 1); - case B + 2: return MPARK_DISPATCH(B + 2); - case B + 3: return MPARK_DISPATCH(B + 3); - case B + 4: return MPARK_DISPATCH(B + 4); - case B + 5: return MPARK_DISPATCH(B + 5); - case B + 6: return MPARK_DISPATCH(B + 6); - case B + 7: return MPARK_DISPATCH(B + 7); - case B + 8: return MPARK_DISPATCH(B + 8); - case B + 9: return MPARK_DISPATCH(B + 9); - case B + 10: return MPARK_DISPATCH(B + 10); - case B + 11: return MPARK_DISPATCH(B + 11); - case B + 12: return MPARK_DISPATCH(B + 12); - case B + 13: return MPARK_DISPATCH(B + 13); - case B + 14: return MPARK_DISPATCH(B + 14); - case B + 15: return MPARK_DISPATCH(B + 15); - case B + 16: return MPARK_DISPATCH(B + 16); - case B + 17: return MPARK_DISPATCH(B + 17); - case B + 18: return MPARK_DISPATCH(B + 18); - case B + 19: return MPARK_DISPATCH(B + 19); - case B + 20: return MPARK_DISPATCH(B + 20); - case B + 21: return MPARK_DISPATCH(B + 21); - case B + 22: return MPARK_DISPATCH(B + 22); - case B + 23: return MPARK_DISPATCH(B + 23); - case B + 24: return MPARK_DISPATCH(B + 24); - case B + 25: return MPARK_DISPATCH(B + 25); - case B + 26: return MPARK_DISPATCH(B + 26); - case B + 27: return MPARK_DISPATCH(B + 27); - case B + 28: return MPARK_DISPATCH(B + 28); - case B + 29: return MPARK_DISPATCH(B + 29); - case B + 30: return MPARK_DISPATCH(B + 30); - case B + 31: return MPARK_DISPATCH(B + 31); - default: return MPARK_DEFAULT(B + 32); - } - -#undef MPARK_DEFAULT -#undef MPARK_DISPATCH - } - - template - MPARK_ALWAYS_INLINE static constexpr R dispatch_case(F &&f, - Vs &&... vs) { - using Expected = R; - using Actual = decltype( - lib::invoke(lib::forward(f), - access::base::get_alt(lib::forward(vs))...)); - return visit_return_type_check::invoke( - lib::forward(f), - access::base::get_alt(lib::forward(vs))...); - } - - template - MPARK_ALWAYS_INLINE static constexpr R dispatch_at(std::size_t index, - F &&f, - V &&v, - Vs &&... 
vs) { - static_assert(lib::all<(lib::decay_t::size() == - lib::decay_t::size())...>::value, - "all of the variants must be the same size."); -#define MPARK_DISPATCH_AT(I) \ - dispatcher<(I < lib::decay_t::size()), R>::template dispatch_case( \ - lib::forward(f), lib::forward(v), lib::forward(vs)...) - -#define MPARK_DEFAULT(I) \ - dispatcher<(I < lib::decay_t::size()), R>::template dispatch_at( \ - index, lib::forward(f), lib::forward(v), lib::forward(vs)...) - - switch (index) { - case B + 0: return MPARK_DISPATCH_AT(B + 0); - case B + 1: return MPARK_DISPATCH_AT(B + 1); - case B + 2: return MPARK_DISPATCH_AT(B + 2); - case B + 3: return MPARK_DISPATCH_AT(B + 3); - case B + 4: return MPARK_DISPATCH_AT(B + 4); - case B + 5: return MPARK_DISPATCH_AT(B + 5); - case B + 6: return MPARK_DISPATCH_AT(B + 6); - case B + 7: return MPARK_DISPATCH_AT(B + 7); - case B + 8: return MPARK_DISPATCH_AT(B + 8); - case B + 9: return MPARK_DISPATCH_AT(B + 9); - case B + 10: return MPARK_DISPATCH_AT(B + 10); - case B + 11: return MPARK_DISPATCH_AT(B + 11); - case B + 12: return MPARK_DISPATCH_AT(B + 12); - case B + 13: return MPARK_DISPATCH_AT(B + 13); - case B + 14: return MPARK_DISPATCH_AT(B + 14); - case B + 15: return MPARK_DISPATCH_AT(B + 15); - case B + 16: return MPARK_DISPATCH_AT(B + 16); - case B + 17: return MPARK_DISPATCH_AT(B + 17); - case B + 18: return MPARK_DISPATCH_AT(B + 18); - case B + 19: return MPARK_DISPATCH_AT(B + 19); - case B + 20: return MPARK_DISPATCH_AT(B + 20); - case B + 21: return MPARK_DISPATCH_AT(B + 21); - case B + 22: return MPARK_DISPATCH_AT(B + 22); - case B + 23: return MPARK_DISPATCH_AT(B + 23); - case B + 24: return MPARK_DISPATCH_AT(B + 24); - case B + 25: return MPARK_DISPATCH_AT(B + 25); - case B + 26: return MPARK_DISPATCH_AT(B + 26); - case B + 27: return MPARK_DISPATCH_AT(B + 27); - case B + 28: return MPARK_DISPATCH_AT(B + 28); - case B + 29: return MPARK_DISPATCH_AT(B + 29); - case B + 30: return MPARK_DISPATCH_AT(B + 30); - case B + 31: return MPARK_DISPATCH_AT(B + 31); - default: return MPARK_DEFAULT(B + 32); - } - -#undef MPARK_DEFAULT -#undef MPARK_DISPATCH_AT - } - }; -#else - template - inline static constexpr const T &at(const T &elem) noexcept { - return elem; - } - - template - inline static constexpr const lib::remove_all_extents_t &at( - const lib::array &elems, std::size_t i, Is... is) noexcept { - return at(elems[i], is...); - } - - template - inline static constexpr lib::array, sizeof...(Fs) + 1> - make_farray(F &&f, Fs &&... fs) { - return {{lib::forward(f), lib::forward(fs)...}}; - } - - template - struct make_fmatrix_impl { - - template - inline static constexpr dispatch_result_t dispatch( - F &&f, Vs &&... vs) { - using Expected = dispatch_result_t; - using Actual = decltype(lib::invoke( - lib::forward(f), - access::base::get_alt(lib::forward(vs))...)); - return visit_return_type_check::invoke( - lib::forward(f), - access::base::get_alt(lib::forward(vs))...); - } - -#ifdef MPARK_RETURN_TYPE_DEDUCTION - template - inline static constexpr auto impl(lib::index_sequence) { - return &dispatch; - } - - template - inline static constexpr auto impl(Is, - lib::index_sequence, - Ls... 
ls) { - return make_farray(impl(lib::push_back_t{}, ls...)...); - } -#else - template - struct impl; - - template - struct impl> { - inline constexpr AUTO operator()() const - AUTO_RETURN(&dispatch) - }; - - template - struct impl, Ls...> { - inline constexpr AUTO operator()() const - AUTO_RETURN( - make_farray(impl, Ls...>{}()...)) - }; -#endif - }; - -#ifdef MPARK_RETURN_TYPE_DEDUCTION - template - inline static constexpr auto make_fmatrix() { - return make_fmatrix_impl::impl( - lib::index_sequence<>{}, - lib::make_index_sequence::size()>{}...); - } -#else - template - inline static constexpr AUTO make_fmatrix() - AUTO_RETURN( - typename make_fmatrix_impl::template impl< - lib::index_sequence<>, - lib::make_index_sequence::size()>...>{}()) -#endif - - template - struct make_fdiagonal_impl { - template - inline static constexpr dispatch_result_t dispatch( - F &&f, Vs &&... vs) { - using Expected = dispatch_result_t; - using Actual = decltype( - lib::invoke(lib::forward(f), - access::base::get_alt(lib::forward(vs))...)); - return visit_return_type_check::invoke( - lib::forward(f), - access::base::get_alt(lib::forward(vs))...); - } - - template - inline static constexpr AUTO impl(lib::index_sequence) - AUTO_RETURN(make_farray(&dispatch...)) - }; - - template - inline static constexpr auto make_fdiagonal() - -> decltype(make_fdiagonal_impl::impl( - lib::make_index_sequence::size()>{})) { - static_assert(lib::all<(lib::decay_t::size() == - lib::decay_t::size())...>::value, - "all of the variants must be the same size."); - return make_fdiagonal_impl::impl( - lib::make_index_sequence::size()>{}); - } -#endif - }; - -#if !defined(MPARK_VARIANT_SWITCH_VISIT) && \ - (!defined(_MSC_VER) || _MSC_VER >= 1910) - template - using fmatrix_t = decltype(base::make_fmatrix()); - - template - struct fmatrix { - static constexpr fmatrix_t value = - base::make_fmatrix(); - }; - - template - constexpr fmatrix_t fmatrix::value; - - template - using fdiagonal_t = decltype(base::make_fdiagonal()); - - template - struct fdiagonal { - static constexpr fdiagonal_t value = - base::make_fdiagonal(); - }; - - template - constexpr fdiagonal_t fdiagonal::value; -#endif - - struct alt { - template - inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, - Vs &&... vs) -#ifdef MPARK_VARIANT_SWITCH_VISIT - DECLTYPE_AUTO_RETURN( - base::dispatcher< - true, - base::dispatch_result_t(vs)))...>>:: - template dispatch<0>(lib::forward(visitor), - as_base(lib::forward(vs))...)) -#elif !defined(_MSC_VER) || _MSC_VER >= 1910 - DECLTYPE_AUTO_RETURN(base::at( - fmatrix(vs)))...>::value, - vs.index()...)(lib::forward(visitor), - as_base(lib::forward(vs))...)) -#else - DECLTYPE_AUTO_RETURN(base::at( - base::make_fmatrix(vs)))...>(), - vs.index()...)(lib::forward(visitor), - as_base(lib::forward(vs))...)) -#endif - - template - inline static constexpr DECLTYPE_AUTO visit_alt_at(std::size_t index, - Visitor &&visitor, - Vs &&... 
vs) -#ifdef MPARK_VARIANT_SWITCH_VISIT - DECLTYPE_AUTO_RETURN( - base::dispatcher< - true, - base::dispatch_result_t(vs)))...>>:: - template dispatch_at<0>(index, - lib::forward(visitor), - as_base(lib::forward(vs))...)) -#elif !defined(_MSC_VER) || _MSC_VER >= 1910 - DECLTYPE_AUTO_RETURN(base::at( - fdiagonal(vs)))...>::value, - index)(lib::forward(visitor), - as_base(lib::forward(vs))...)) -#else - DECLTYPE_AUTO_RETURN(base::at( - base::make_fdiagonal(vs)))...>(), - index)(lib::forward(visitor), - as_base(lib::forward(vs))...)) -#endif - }; - - struct variant { - private: - template - struct visitor { - template - inline static constexpr bool does_not_handle() { - return lib::is_invocable::value; - } - }; - - template - struct visit_exhaustiveness_check { - static_assert(visitor::template does_not_handle(), - "`visit` requires the visitor to be exhaustive."); - - inline static constexpr DECLTYPE_AUTO invoke(Visitor &&visitor, - Values &&... values) - DECLTYPE_AUTO_RETURN(lib::invoke(lib::forward(visitor), - lib::forward(values)...)) - }; - - template - struct value_visitor { - Visitor &&visitor_; - - template - inline constexpr DECLTYPE_AUTO operator()(Alts &&... alts) const - DECLTYPE_AUTO_RETURN( - visit_exhaustiveness_check< - Visitor, - decltype((lib::forward(alts).value))...>:: - invoke(lib::forward(visitor_), - lib::forward(alts).value...)) - }; - - template - inline static constexpr AUTO make_value_visitor(Visitor &&visitor) - AUTO_RETURN(value_visitor{lib::forward(visitor)}) - - public: - template - inline static constexpr DECLTYPE_AUTO visit_alt(Visitor &&visitor, - Vs &&... vs) - DECLTYPE_AUTO_RETURN(alt::visit_alt(lib::forward(visitor), - lib::forward(vs).impl_...)) - - template - inline static constexpr DECLTYPE_AUTO visit_alt_at(std::size_t index, - Visitor &&visitor, - Vs &&... vs) - DECLTYPE_AUTO_RETURN( - alt::visit_alt_at(index, - lib::forward(visitor), - lib::forward(vs).impl_...)) - - template - inline static constexpr DECLTYPE_AUTO visit_value(Visitor &&visitor, - Vs &&... vs) - DECLTYPE_AUTO_RETURN( - visit_alt(make_value_visitor(lib::forward(visitor)), - lib::forward(vs)...)) - - template - inline static constexpr DECLTYPE_AUTO visit_value_at(std::size_t index, - Visitor &&visitor, - Vs &&... vs) - DECLTYPE_AUTO_RETURN( - visit_alt_at(index, - make_value_visitor(lib::forward(visitor)), - lib::forward(vs)...)) - }; - - } // namespace visitation - - template - struct alt { - using value_type = T; - -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4244) -#endif - template - inline explicit constexpr alt(variant_in_place_t, Args &&... args) - : value(lib::forward(args)...) {} -#ifdef _MSC_VER -#pragma warning(pop) -#endif - - T value; - }; - - template - union recursive_union; - - template - union recursive_union {}; - -#define MPARK_VARIANT_RECURSIVE_UNION(destructible_trait, destructor) \ - template \ - union recursive_union { \ - public: \ - inline explicit constexpr recursive_union(valueless_t) noexcept \ - : dummy_{} {} \ - \ - template \ - inline explicit constexpr recursive_union(in_place_index_t<0>, \ - Args &&... args) \ - : head_(variant_in_place_t{}, lib::forward(args)...) {} \ - \ - template \ - inline explicit constexpr recursive_union(in_place_index_t, \ - Args &&... args) \ - : tail_(in_place_index_t{}, lib::forward(args)...) 
{} \ - \ - recursive_union(const recursive_union &) = default; \ - recursive_union(recursive_union &&) = default; \ - \ - destructor \ - \ - recursive_union &operator=(const recursive_union &) = default; \ - recursive_union &operator=(recursive_union &&) = default; \ - \ - private: \ - char dummy_; \ - alt head_; \ - recursive_union tail_; \ - \ - friend struct access::recursive_union; \ - } - - MPARK_VARIANT_RECURSIVE_UNION(Trait::TriviallyAvailable, - ~recursive_union() = default;); - MPARK_VARIANT_RECURSIVE_UNION(Trait::Available, - ~recursive_union() {}); - MPARK_VARIANT_RECURSIVE_UNION(Trait::Unavailable, - ~recursive_union() = delete;); - -#undef MPARK_VARIANT_RECURSIVE_UNION - - using index_t = unsigned int; - - template - class base { - public: - inline explicit constexpr base(valueless_t tag) noexcept - : data_(tag), index_(static_cast(-1)) {} - - template - inline explicit constexpr base(in_place_index_t, Args &&... args) - : data_(in_place_index_t{}, lib::forward(args)...), - index_(I) {} - - inline constexpr bool valueless_by_exception() const noexcept { - return index_ == static_cast(-1); - } - - inline constexpr std::size_t index() const noexcept { - return valueless_by_exception() ? variant_npos : index_; - } - - protected: - using data_t = recursive_union; - - friend inline constexpr base &as_base(base &b) { return b; } - friend inline constexpr const base &as_base(const base &b) { return b; } - friend inline constexpr base &&as_base(base &&b) { return lib::move(b); } - friend inline constexpr const base &&as_base(const base &&b) { return lib::move(b); } - - friend inline constexpr data_t &data(base &b) { return b.data_; } - friend inline constexpr const data_t &data(const base &b) { return b.data_; } - friend inline constexpr data_t &&data(base &&b) { return lib::move(b).data_; } - friend inline constexpr const data_t &&data(const base &&b) { return lib::move(b).data_; } - - inline static constexpr std::size_t size() { return sizeof...(Ts); } - - data_t data_; - index_t index_; - - friend struct access::base; - friend struct visitation::base; - }; - - struct dtor { -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4100) -#endif - template - inline void operator()(Alt &alt) const noexcept { alt.~Alt(); } -#ifdef _MSC_VER -#pragma warning(pop) -#endif - }; - -#if !defined(_MSC_VER) || _MSC_VER >= 1910 -#define MPARK_INHERITING_CTOR(type, base) using base::base; -#else -#define MPARK_INHERITING_CTOR(type, base) \ - template \ - inline explicit constexpr type(Args &&... args) \ - : base(lib::forward(args)...) 
{} -#endif - - template - class destructor; - -#define MPARK_VARIANT_DESTRUCTOR(destructible_trait, definition, destroy) \ - template \ - class destructor, destructible_trait> \ - : public base { \ - using super = base; \ - \ - public: \ - MPARK_INHERITING_CTOR(destructor, super) \ - using super::operator=; \ - \ - destructor(const destructor &) = default; \ - destructor(destructor &&) = default; \ - definition \ - destructor &operator=(const destructor &) = default; \ - destructor &operator=(destructor &&) = default; \ - \ - protected: \ - destroy \ - } - - MPARK_VARIANT_DESTRUCTOR( - Trait::TriviallyAvailable, - ~destructor() = default;, - inline void destroy() noexcept { - this->index_ = static_cast(-1); - }); - - MPARK_VARIANT_DESTRUCTOR( - Trait::Available, - ~destructor() { destroy(); }, - inline void destroy() noexcept { - if (!this->valueless_by_exception()) { - visitation::alt::visit_alt(dtor{}, *this); - } - this->index_ = static_cast(-1); - }); - - MPARK_VARIANT_DESTRUCTOR( - Trait::Unavailable, - ~destructor() = delete;, - inline void destroy() noexcept = delete;); - -#undef MPARK_VARIANT_DESTRUCTOR - - template - class constructor : public destructor { - using super = destructor; - - public: - MPARK_INHERITING_CTOR(constructor, super) - using super::operator=; - - protected: -#ifndef MPARK_GENERIC_LAMBDAS - struct ctor { - template - inline void operator()(LhsAlt &lhs_alt, RhsAlt &&rhs_alt) const { - constructor::construct_alt(lhs_alt, - lib::forward(rhs_alt).value); - } - }; -#endif - - template - inline static T &construct_alt(alt &a, Args &&... args) { - auto *result = ::new (static_cast(lib::addressof(a))) - alt(variant_in_place_t{}, lib::forward(args)...); - return result->value; - } - - template - inline static void generic_construct(constructor &lhs, Rhs &&rhs) { - lhs.destroy(); - if (!rhs.valueless_by_exception()) { - visitation::alt::visit_alt_at( - rhs.index(), -#ifdef MPARK_GENERIC_LAMBDAS - [](auto &lhs_alt, auto &&rhs_alt) { - constructor::construct_alt( - lhs_alt, lib::forward(rhs_alt).value); - } -#else - ctor{} -#endif - , - lhs, - lib::forward(rhs)); - lhs.index_ = rhs.index_; - } - } - }; - - template - class move_constructor; - -#define MPARK_VARIANT_MOVE_CONSTRUCTOR(move_constructible_trait, definition) \ - template \ - class move_constructor, move_constructible_trait> \ - : public constructor> { \ - using super = constructor>; \ - \ - public: \ - MPARK_INHERITING_CTOR(move_constructor, super) \ - using super::operator=; \ - \ - move_constructor(const move_constructor &) = default; \ - definition \ - ~move_constructor() = default; \ - move_constructor &operator=(const move_constructor &) = default; \ - move_constructor &operator=(move_constructor &&) = default; \ - } - - MPARK_VARIANT_MOVE_CONSTRUCTOR( - Trait::TriviallyAvailable, - move_constructor(move_constructor &&that) = default;); - - MPARK_VARIANT_MOVE_CONSTRUCTOR( - Trait::Available, - move_constructor(move_constructor &&that) noexcept( - lib::all::value...>::value) - : move_constructor(valueless_t{}) { - this->generic_construct(*this, lib::move(that)); - }); - - MPARK_VARIANT_MOVE_CONSTRUCTOR( - Trait::Unavailable, - move_constructor(move_constructor &&) = delete;); - -#undef MPARK_VARIANT_MOVE_CONSTRUCTOR - - template - class copy_constructor; - -#define MPARK_VARIANT_COPY_CONSTRUCTOR(copy_constructible_trait, definition) \ - template \ - class copy_constructor, copy_constructible_trait> \ - : public move_constructor> { \ - using super = move_constructor>; \ - \ - public: \ - 
MPARK_INHERITING_CTOR(copy_constructor, super) \ - using super::operator=; \ - \ - definition \ - copy_constructor(copy_constructor &&) = default; \ - ~copy_constructor() = default; \ - copy_constructor &operator=(const copy_constructor &) = default; \ - copy_constructor &operator=(copy_constructor &&) = default; \ - } - - MPARK_VARIANT_COPY_CONSTRUCTOR( - Trait::TriviallyAvailable, - copy_constructor(const copy_constructor &that) = default;); - - MPARK_VARIANT_COPY_CONSTRUCTOR( - Trait::Available, - copy_constructor(const copy_constructor &that) - : copy_constructor(valueless_t{}) { - this->generic_construct(*this, that); - }); - - MPARK_VARIANT_COPY_CONSTRUCTOR( - Trait::Unavailable, - copy_constructor(const copy_constructor &) = delete;); - -#undef MPARK_VARIANT_COPY_CONSTRUCTOR - - template - class assignment : public copy_constructor { - using super = copy_constructor; - - public: - MPARK_INHERITING_CTOR(assignment, super) - using super::operator=; - - template - inline /* auto & */ auto emplace(Args &&... args) - -> decltype(this->construct_alt(access::base::get_alt(*this), - lib::forward(args)...)) { - this->destroy(); - auto &result = this->construct_alt(access::base::get_alt(*this), - lib::forward(args)...); - this->index_ = I; - return result; - } - - protected: -#ifndef MPARK_GENERIC_LAMBDAS - template - struct assigner { - template - inline void operator()(ThisAlt &this_alt, ThatAlt &&that_alt) const { - self->assign_alt(this_alt, lib::forward(that_alt).value); - } - assignment *self; - }; -#endif - - template - inline void assign_alt(alt &a, Arg &&arg) { - if (this->index() == I) { -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4244) -#endif - a.value = lib::forward(arg); -#ifdef _MSC_VER -#pragma warning(pop) -#endif - } else { - struct { - void operator()(std::true_type) const { - this_->emplace(lib::forward(arg_)); - } - void operator()(std::false_type) const { - this_->emplace(T(lib::forward(arg_))); - } - assignment *this_; - Arg &&arg_; - } impl{this, lib::forward(arg)}; - impl(lib::bool_constant< - std::is_nothrow_constructible::value || - !std::is_nothrow_move_constructible::value>{}); - } - } - - template - inline void generic_assign(That &&that) { - if (this->valueless_by_exception() && that.valueless_by_exception()) { - // do nothing. 
- } else if (that.valueless_by_exception()) { - this->destroy(); - } else { - visitation::alt::visit_alt_at( - that.index(), -#ifdef MPARK_GENERIC_LAMBDAS - [this](auto &this_alt, auto &&that_alt) { - this->assign_alt( - this_alt, lib::forward(that_alt).value); - } -#else - assigner{this} -#endif - , - *this, - lib::forward(that)); - } - } - }; - - template - class move_assignment; - -#define MPARK_VARIANT_MOVE_ASSIGNMENT(move_assignable_trait, definition) \ - template \ - class move_assignment, move_assignable_trait> \ - : public assignment> { \ - using super = assignment>; \ - \ - public: \ - MPARK_INHERITING_CTOR(move_assignment, super) \ - using super::operator=; \ - \ - move_assignment(const move_assignment &) = default; \ - move_assignment(move_assignment &&) = default; \ - ~move_assignment() = default; \ - move_assignment &operator=(const move_assignment &) = default; \ - definition \ - } - - MPARK_VARIANT_MOVE_ASSIGNMENT( - Trait::TriviallyAvailable, - move_assignment &operator=(move_assignment &&that) = default;); - - MPARK_VARIANT_MOVE_ASSIGNMENT( - Trait::Available, - move_assignment & - operator=(move_assignment &&that) noexcept( - lib::all<(std::is_nothrow_move_constructible::value && - std::is_nothrow_move_assignable::value)...>::value) { - this->generic_assign(lib::move(that)); - return *this; - }); - - MPARK_VARIANT_MOVE_ASSIGNMENT( - Trait::Unavailable, - move_assignment &operator=(move_assignment &&) = delete;); - -#undef MPARK_VARIANT_MOVE_ASSIGNMENT - - template - class copy_assignment; - -#define MPARK_VARIANT_COPY_ASSIGNMENT(copy_assignable_trait, definition) \ - template \ - class copy_assignment, copy_assignable_trait> \ - : public move_assignment> { \ - using super = move_assignment>; \ - \ - public: \ - MPARK_INHERITING_CTOR(copy_assignment, super) \ - using super::operator=; \ - \ - copy_assignment(const copy_assignment &) = default; \ - copy_assignment(copy_assignment &&) = default; \ - ~copy_assignment() = default; \ - definition \ - copy_assignment &operator=(copy_assignment &&) = default; \ - } - - MPARK_VARIANT_COPY_ASSIGNMENT( - Trait::TriviallyAvailable, - copy_assignment &operator=(const copy_assignment &that) = default;); - - MPARK_VARIANT_COPY_ASSIGNMENT( - Trait::Available, - copy_assignment &operator=(const copy_assignment &that) { - this->generic_assign(that); - return *this; - }); - - MPARK_VARIANT_COPY_ASSIGNMENT( - Trait::Unavailable, - copy_assignment &operator=(const copy_assignment &) = delete;); - -#undef MPARK_VARIANT_COPY_ASSIGNMENT - - template - class impl : public copy_assignment> { - using super = copy_assignment>; - - public: - MPARK_INHERITING_CTOR(impl, super) - using super::operator=; - - template - inline void assign(Arg &&arg) { - this->assign_alt(access::base::get_alt(*this), - lib::forward(arg)); - } - - inline void swap(impl &that) { - if (this->valueless_by_exception() && that.valueless_by_exception()) { - // do nothing. 
- } else if (this->index() == that.index()) { - visitation::alt::visit_alt_at(this->index(), -#ifdef MPARK_GENERIC_LAMBDAS - [](auto &this_alt, auto &that_alt) { - using std::swap; - swap(this_alt.value, - that_alt.value); - } -#else - swapper{} -#endif - , - *this, - that); - } else { - impl *lhs = this; - impl *rhs = lib::addressof(that); - if (lhs->move_nothrow() && !rhs->move_nothrow()) { - std::swap(lhs, rhs); - } - impl tmp(lib::move(*rhs)); -#ifdef MPARK_EXCEPTIONS - // EXTENSION: When the move construction of `lhs` into `rhs` throws - // and `tmp` is nothrow move constructible then we move `tmp` back - // into `rhs` and provide the strong exception safety guarantee. - try { - this->generic_construct(*rhs, lib::move(*lhs)); - } catch (...) { - if (tmp.move_nothrow()) { - this->generic_construct(*rhs, lib::move(tmp)); - } - throw; - } -#else - this->generic_construct(*rhs, lib::move(*lhs)); -#endif - this->generic_construct(*lhs, lib::move(tmp)); - } - } - - private: -#ifndef MPARK_GENERIC_LAMBDAS - struct swapper { - template - inline void operator()(ThisAlt &this_alt, ThatAlt &that_alt) const { - using std::swap; - swap(this_alt.value, that_alt.value); - } - }; -#endif - - inline constexpr bool move_nothrow() const { - return this->valueless_by_exception() || - lib::array{ - {std::is_nothrow_move_constructible::value...} - }[this->index()]; - } - }; - -#undef MPARK_INHERITING_CTOR - - template - struct overload_leaf { - using F = lib::size_constant (*)(T); - operator F() const { return nullptr; } - }; - - template - struct overload_impl { - private: - template - struct impl; - - template - struct impl> : overload_leaf... {}; - - public: - using type = impl>; - }; - - template - using overload = typename overload_impl::type; - - template - using best_match = lib::invoke_result_t, T &&>; - - template - struct is_in_place_index : std::false_type {}; - - template - struct is_in_place_index> : std::true_type {}; - - template - struct is_in_place_type : std::false_type {}; - - template - struct is_in_place_type> : std::true_type {}; - - } // detail_ - - template - class variant { - static_assert(0 < sizeof...(Ts), - "variant must consist of at least one alternative."); - - static_assert(lib::all::value...>::value, - "variant can not have an array type as an alternative."); - - static_assert(lib::all::value...>::value, - "variant can not have a reference type as an alternative."); - - static_assert(lib::all::value...>::value, - "variant can not have a void type as an alternative."); - - public: - template < - typename Front = lib::type_pack_element_t<0, Ts...>, - lib::enable_if_t::value, int> = 0> - inline constexpr variant() noexcept( - std::is_nothrow_default_constructible::value) - : impl_(in_place_index_t<0>{}) {} - - variant(const variant &) = default; - variant(variant &&) = default; - - template < - typename Arg, - typename Decayed = lib::decay_t, - lib::enable_if_t::value, int> = 0, - lib::enable_if_t::value, int> = 0, - lib::enable_if_t::value, int> = 0, - std::size_t I = detail_::best_match::value, - typename T = lib::type_pack_element_t, - lib::enable_if_t::value, int> = 0> - inline constexpr variant(Arg &&arg) noexcept( - std::is_nothrow_constructible::value) - : impl_(in_place_index_t{}, lib::forward(arg)) {} - - template < - std::size_t I, - typename... Args, - typename T = lib::type_pack_element_t, - lib::enable_if_t::value, int> = 0> - inline explicit constexpr variant( - in_place_index_t, - Args &&... 
args) noexcept(std::is_nothrow_constructible::value) - : impl_(in_place_index_t{}, lib::forward(args)...) {} - - template < - std::size_t I, - typename Up, - typename... Args, - typename T = lib::type_pack_element_t, - lib::enable_if_t &, - Args...>::value, - int> = 0> - inline explicit constexpr variant( - in_place_index_t, - std::initializer_list il, - Args &&... args) noexcept(std:: - is_nothrow_constructible< - T, - std::initializer_list &, - Args...>::value) - : impl_(in_place_index_t{}, il, lib::forward(args)...) {} - - template < - typename T, - typename... Args, - std::size_t I = detail_::find_index_sfinae::value, - lib::enable_if_t::value, int> = 0> - inline explicit constexpr variant( - in_place_type_t, - Args &&... args) noexcept(std::is_nothrow_constructible::value) - : impl_(in_place_index_t{}, lib::forward(args)...) {} - - template < - typename T, - typename Up, - typename... Args, - std::size_t I = detail_::find_index_sfinae::value, - lib::enable_if_t &, - Args...>::value, - int> = 0> - inline explicit constexpr variant( - in_place_type_t, - std::initializer_list il, - Args &&... args) noexcept(std:: - is_nothrow_constructible< - T, - std::initializer_list &, - Args...>::value) - : impl_(in_place_index_t{}, il, lib::forward(args)...) {} - - ~variant() = default; - - variant &operator=(const variant &) = default; - variant &operator=(variant &&) = default; - - template , variant>::value, - int> = 0, - std::size_t I = detail_::best_match::value, - typename T = lib::type_pack_element_t, - lib::enable_if_t<(std::is_assignable::value && - std::is_constructible::value), - int> = 0> - inline variant &operator=(Arg &&arg) noexcept( - (std::is_nothrow_assignable::value && - std::is_nothrow_constructible::value)) { - impl_.template assign(lib::forward(arg)); - return *this; - } - - template < - std::size_t I, - typename... Args, - typename T = lib::type_pack_element_t, - lib::enable_if_t::value, int> = 0> - inline T &emplace(Args &&... args) { - return impl_.template emplace(lib::forward(args)...); - } - - template < - std::size_t I, - typename Up, - typename... Args, - typename T = lib::type_pack_element_t, - lib::enable_if_t &, - Args...>::value, - int> = 0> - inline T &emplace(std::initializer_list il, Args &&... args) { - return impl_.template emplace(il, lib::forward(args)...); - } - - template < - typename T, - typename... Args, - std::size_t I = detail_::find_index_sfinae::value, - lib::enable_if_t::value, int> = 0> - inline T &emplace(Args &&... args) { - return impl_.template emplace(lib::forward(args)...); - } - - template < - typename T, - typename Up, - typename... Args, - std::size_t I = detail_::find_index_sfinae::value, - lib::enable_if_t &, - Args...>::value, - int> = 0> - inline T &emplace(std::initializer_list il, Args &&... 
args) { - return impl_.template emplace(il, lib::forward(args)...); - } - - inline constexpr bool valueless_by_exception() const noexcept { - return impl_.valueless_by_exception(); - } - - inline constexpr std::size_t index() const noexcept { - return impl_.index(); - } - - template , - Dummy>::value && - lib::dependent_type, - Dummy>::value)...>::value, - int> = 0> - inline void swap(variant &that) noexcept( - lib::all<(std::is_nothrow_move_constructible::value && - lib::is_nothrow_swappable::value)...>::value) { - impl_.swap(that.impl_); - } - - private: - detail_::impl impl_; - - friend struct detail_::access::variant; - friend struct detail_::visitation::variant; - }; - - template - inline constexpr bool holds_alternative(const variant &v) noexcept { - return v.index() == I; - } - - template - inline constexpr bool holds_alternative(const variant &v) noexcept { - return holds_alternative::value>(v); - } - - namespace detail_ { - template - struct generic_get_impl { - constexpr generic_get_impl(int) noexcept {} - - constexpr AUTO_REFREF operator()(V &&v) const - AUTO_REFREF_RETURN( - access::variant::get_alt(lib::forward(v)).value) - }; - - template - inline constexpr AUTO_REFREF generic_get(V &&v) - AUTO_REFREF_RETURN(generic_get_impl( - holds_alternative(v) ? 0 : (throw_bad_variant_access(), 0))( - lib::forward(v))) - } // namespace detail_ - - template - inline constexpr variant_alternative_t> &get( - variant &v) { - return detail_::generic_get(v); - } - - template - inline constexpr variant_alternative_t> &&get( - variant &&v) { - return detail_::generic_get(lib::move(v)); - } - - template - inline constexpr const variant_alternative_t> &get( - const variant &v) { - return detail_::generic_get(v); - } - - template - inline constexpr const variant_alternative_t> &&get( - const variant &&v) { - return detail_::generic_get(lib::move(v)); - } - - template - inline constexpr T &get(variant &v) { - return get::value>(v); - } - - template - inline constexpr T &&get(variant &&v) { - return get::value>(lib::move(v)); - } - - template - inline constexpr const T &get(const variant &v) { - return get::value>(v); - } - - template - inline constexpr const T &&get(const variant &&v) { - return get::value>(lib::move(v)); - } - - namespace detail_ { - - template - inline constexpr /* auto * */ AUTO generic_get_if(V *v) noexcept - AUTO_RETURN(v && holds_alternative(*v) - ? 
lib::addressof(access::variant::get_alt(*v).value) - : nullptr) - - } // namespace detail_ - - template - inline constexpr lib::add_pointer_t>> - get_if(variant *v) noexcept { - return detail_::generic_get_if(v); - } - - template - inline constexpr lib::add_pointer_t< - const variant_alternative_t>> - get_if(const variant *v) noexcept { - return detail_::generic_get_if(v); - } - - template - inline constexpr lib::add_pointer_t - get_if(variant *v) noexcept { - return get_if::value>(v); - } - - template - inline constexpr lib::add_pointer_t - get_if(const variant *v) noexcept { - return get_if::value>(v); - } - - namespace detail_ { - template - struct convert_to_bool { - template - inline constexpr bool operator()(Lhs &&lhs, Rhs &&rhs) const { - static_assert(std::is_convertible, - bool>::value, - "relational operators must return a type" - " implicitly convertible to bool"); - return lib::invoke( - RelOp{}, lib::forward(lhs), lib::forward(rhs)); - } - }; - } // namespace detail_ - - template - inline constexpr bool operator==(const variant &lhs, - const variant &rhs) { - using detail_::visitation::variant; - using equal_to = detail_::convert_to_bool; -#ifdef MPARK_CPP14_CONSTEXPR - if (lhs.index() != rhs.index()) return false; - if (lhs.valueless_by_exception()) return true; - return variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs); -#else - return lhs.index() == rhs.index() && - (lhs.valueless_by_exception() || - variant::visit_value_at(lhs.index(), equal_to{}, lhs, rhs)); -#endif - } - - template - inline constexpr bool operator!=(const variant &lhs, - const variant &rhs) { - using detail_::visitation::variant; - using not_equal_to = detail_::convert_to_bool; -#ifdef MPARK_CPP14_CONSTEXPR - if (lhs.index() != rhs.index()) return true; - if (lhs.valueless_by_exception()) return false; - return variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs); -#else - return lhs.index() != rhs.index() || - (!lhs.valueless_by_exception() && - variant::visit_value_at(lhs.index(), not_equal_to{}, lhs, rhs)); -#endif - } - - template - inline constexpr bool operator<(const variant &lhs, - const variant &rhs) { - using detail_::visitation::variant; - using less = detail_::convert_to_bool; -#ifdef MPARK_CPP14_CONSTEXPR - if (rhs.valueless_by_exception()) return false; - if (lhs.valueless_by_exception()) return true; - if (lhs.index() < rhs.index()) return true; - if (lhs.index() > rhs.index()) return false; - return variant::visit_value_at(lhs.index(), less{}, lhs, rhs); -#else - return !rhs.valueless_by_exception() && - (lhs.valueless_by_exception() || lhs.index() < rhs.index() || - (lhs.index() == rhs.index() && - variant::visit_value_at(lhs.index(), less{}, lhs, rhs))); -#endif - } - - template - inline constexpr bool operator>(const variant &lhs, - const variant &rhs) { - using detail_::visitation::variant; - using greater = detail_::convert_to_bool; -#ifdef MPARK_CPP14_CONSTEXPR - if (lhs.valueless_by_exception()) return false; - if (rhs.valueless_by_exception()) return true; - if (lhs.index() > rhs.index()) return true; - if (lhs.index() < rhs.index()) return false; - return variant::visit_value_at(lhs.index(), greater{}, lhs, rhs); -#else - return !lhs.valueless_by_exception() && - (rhs.valueless_by_exception() || lhs.index() > rhs.index() || - (lhs.index() == rhs.index() && - variant::visit_value_at(lhs.index(), greater{}, lhs, rhs))); -#endif - } - - template - inline constexpr bool operator<=(const variant &lhs, - const variant &rhs) { - using detail_::visitation::variant; - 
using less_equal = detail_::convert_to_bool; -#ifdef MPARK_CPP14_CONSTEXPR - if (lhs.valueless_by_exception()) return true; - if (rhs.valueless_by_exception()) return false; - if (lhs.index() < rhs.index()) return true; - if (lhs.index() > rhs.index()) return false; - return variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs); -#else - return lhs.valueless_by_exception() || - (!rhs.valueless_by_exception() && - (lhs.index() < rhs.index() || - (lhs.index() == rhs.index() && - variant::visit_value_at(lhs.index(), less_equal{}, lhs, rhs)))); -#endif - } - - template - inline constexpr bool operator>=(const variant &lhs, - const variant &rhs) { - using detail_::visitation::variant; - using greater_equal = detail_::convert_to_bool; -#ifdef MPARK_CPP14_CONSTEXPR - if (rhs.valueless_by_exception()) return true; - if (lhs.valueless_by_exception()) return false; - if (lhs.index() > rhs.index()) return true; - if (lhs.index() < rhs.index()) return false; - return variant::visit_value_at(lhs.index(), greater_equal{}, lhs, rhs); -#else - return rhs.valueless_by_exception() || - (!lhs.valueless_by_exception() && - (lhs.index() > rhs.index() || - (lhs.index() == rhs.index() && - variant::visit_value_at( - lhs.index(), greater_equal{}, lhs, rhs)))); -#endif - } - - struct monostate {}; - - inline constexpr bool operator<(monostate, monostate) noexcept { - return false; - } - - inline constexpr bool operator>(monostate, monostate) noexcept { - return false; - } - - inline constexpr bool operator<=(monostate, monostate) noexcept { - return true; - } - - inline constexpr bool operator>=(monostate, monostate) noexcept { - return true; - } - - inline constexpr bool operator==(monostate, monostate) noexcept { - return true; - } - - inline constexpr bool operator!=(monostate, monostate) noexcept { - return false; - } - -#ifdef MPARK_CPP14_CONSTEXPR - namespace detail_ { - - inline constexpr bool all(std::initializer_list bs) { - for (bool b : bs) { - if (!b) { - return false; - } - } - return true; - } - - } // namespace detail_ - - template - inline constexpr decltype(auto) visit(Visitor &&visitor, Vs &&... vs) { - return (detail_::all({!vs.valueless_by_exception()...}) - ? (void)0 - : throw_bad_variant_access()), - detail_::visitation::variant::visit_value( - lib::forward(visitor), lib::forward(vs)...); - } -#else - namespace detail_ { - - template - inline constexpr bool all_impl(const lib::array &bs, - std::size_t idx) { - return idx >= N || (bs[idx] && all_impl(bs, idx + 1)); - } - - template - inline constexpr bool all(const lib::array &bs) { - return all_impl(bs, 0); - } - - } // namespace detail_ - - template - inline constexpr DECLTYPE_AUTO visit(Visitor &&visitor, Vs &&... vs) - DECLTYPE_AUTO_RETURN( - (detail_::all( - lib::array{{!vs.valueless_by_exception()...}}) - ? 
(void)0 - : throw_bad_variant_access()), - detail_::visitation::variant::visit_value(lib::forward(visitor), - lib::forward(vs)...)) -#endif - - template - inline auto swap(variant &lhs, - variant &rhs) noexcept(noexcept(lhs.swap(rhs))) - -> decltype(lhs.swap(rhs)) { - lhs.swap(rhs); - } - - namespace detail_ { - - template - using enabled_type = T; - - namespace hash { - - template - constexpr bool meets_requirements() noexcept { - return std::is_copy_constructible::value && - std::is_move_constructible::value && - lib::is_invocable_r::value; - } - - template - constexpr bool is_enabled() noexcept { - using H = std::hash; - return meets_requirements() && - std::is_default_constructible::value && - std::is_copy_assignable::value && - std::is_move_assignable::value; - } - - } // namespace hash - - } // namespace detail_ - -#undef AUTO -#undef AUTO_RETURN - -#undef AUTO_REFREF -#undef AUTO_REFREF_RETURN - -#undef DECLTYPE_AUTO -#undef DECLTYPE_AUTO_RETURN - -} // namespace c10 - -namespace std { - - template - struct hash, - c10::lib::enable_if_t>()...>::value>>> { - using argument_type = c10::variant; - using result_type = std::size_t; - - inline result_type operator()(const argument_type &v) const { - using c10::detail_::visitation::variant; - std::size_t result = - v.valueless_by_exception() - ? 299792458 // Random value chosen by the universe upon creation - : variant::visit_alt( -#ifdef MPARK_GENERIC_LAMBDAS - [](const auto &alt) { - using alt_type = c10::lib::decay_t; - using value_type = c10::lib::remove_const_t< - typename alt_type::value_type>; - return hash{}(alt.value); - } -#else - hasher{} -#endif - , - v); - return hash_combine(result, hash{}(v.index())); - } - - private: -#ifndef MPARK_GENERIC_LAMBDAS - struct hasher { - template - inline std::size_t operator()(const Alt &alt) const { - using alt_type = c10::lib::decay_t; - using value_type = - c10::lib::remove_const_t; - return hash{}(alt.value); - } - }; -#endif - - static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { - return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); - } - }; - - template <> - struct hash { - using argument_type = c10::monostate; - using result_type = std::size_t; - - inline result_type operator()(const argument_type &) const noexcept { - return 66740831; // return a fundamentally attractive random value. - } - }; - -} // namespace std - -#endif // C10_UTIL_VARIANT_H_ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e16306765f467..37949fdd44652 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -44,14 +44,6 @@ if (INTERN_BUILD_ATEN_OPS) list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS}) list(APPEND Caffe2_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS}) list(APPEND Caffe2_DEPENDENCY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE}) -else() - # Only add "ATen Core", a minimal, easy-to-compile fragment of ATen. - # This codepath should only be exercised by the Android build. 
- add_subdirectory(../aten/src/ATen/core ATen_core) - list(APPEND Caffe2_CPU_SRCS ${ATen_CORE_SRCS}) - list(APPEND Caffe2_CPU_INCLUDE ${ATen_CORE_INCLUDE}) - list(APPEND Caffe2_CPU_TEST_SRCS ${ATen_CORE_TEST_SRCS}) - # See cmake/Codegen.cmake for header installation endif() # ---[ Caffe2 build @@ -478,21 +470,37 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ) endif() + if (USE_NCCL) + list(APPEND TORCH_SRCS + ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp) + endif() + if (NOT INTERN_BUILD_MOBILE) list(APPEND TORCH_SRCS ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp + ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp + ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_context.cpp + ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/recvrpc_backward.cpp ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/sendrpc_backward.cpp ${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_remote_call.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_call.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_resp.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_with_autograd.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_proto.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp - ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_rref_proto.cpp - ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_ret.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_resp.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/utils.cpp ${TORCH_SRC_DIR}/csrc/jit/export.cpp ${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp + ${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp ) endif() @@ -609,6 +617,11 @@ ELSE() add_library(torch ${Caffe2_CPU_SRCS}) ENDIF() +if (USE_NCCL) + target_link_libraries(torch PRIVATE __caffe2_nccl) + target_compile_definitions(torch PRIVATE USE_NCCL) +endif() + # ========================================================== # formerly-libtorch flags diff --git a/caffe2/c2_aten_srcs.bzl b/caffe2/c2_aten_srcs.bzl index 8e9dd16de21b9..53cb3553ae631 100644 --- a/caffe2/c2_aten_srcs.bzl +++ b/caffe2/c2_aten_srcs.bzl @@ -2,13 +2,11 @@ ATEN_CORE_HEADER_FILES = [ # "aten/src/" prefix is added later "ATen/core/ATenGeneral.h", "ATen/core/blob.h", - "ATen/core/context_base.h", "ATen/core/DimVector.h", "ATen/core/grad_mode.h", "ATen/core/UndefinedTensorImpl.h", ] ATEN_CORE_SRC_FILES = [ - "aten/src/ATen/core/context_base.cpp", "aten/src/ATen/core/grad_mode.cpp", ] diff --git a/caffe2/core/context_base.cc b/caffe2/core/context_base.cc index 99996d9e165b9..35a792ef86908 100644 --- a/caffe2/core/context_base.cc +++ b/caffe2/core/context_base.cc @@ -1,5 +1,20 @@ -#include "context_base.h" +#include + +#include + +namespace at { + +C10_DEFINE_TYPED_REGISTRY( + ContextRegistry, + at::DeviceType, + at::BaseContext, + std::unique_ptr, + at::Device); + +} // namespace at namespace caffe2 { +// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h + } // namespace caffe2 diff --git a/caffe2/core/context_base.h b/caffe2/core/context_base.h index 3a6dfad5b95cc..3ba0d522b5f84 100644 --- a/caffe2/core/context_base.h +++ b/caffe2/core/context_base.h @@ -1,7 +1,168 @@ #pragma 
once
-#include
-// For CaffeMap
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/proto/caffe2_pb.h"
+
+namespace caffe2 {
+class Event;
+
+} // namespace caffe2
+namespace at {
+
+class BaseContext;
+
+/**
+ * Virtual interface for the Context class in Caffe2.
+ *
+ * A Context defines all the necessities to run an operator on a specific
+ * device. Specific Context classes need to implement all the pure virtual
+ * functions in the BaseContext class.
+ * TODO: add docs after this is finalized.
+ */
+class CAFFE2_API BaseContext {
+ public:
+  virtual ~BaseContext() noexcept {}
+
+  virtual Device device() const = 0;
+
+  /* Sorry for the naming, will get rid of this in future diff */
+  virtual DeviceType device_type() const = 0;
+
+  virtual void SwitchToDevice(int /*stream_id*/) = 0;
+
+  inline void SwitchToDevice() {
+    SwitchToDevice(0);
+  }
+
+  virtual void WaitEvent(const caffe2::Event& ev) = 0;
+
+  virtual void Record(caffe2::Event* ev, const char* err_msg = nullptr)
+      const = 0;
+
+  virtual void FinishDeviceComputation() = 0;
+
+  // This used to be arbitrary cross-device copy, but it turns out everyone
+  // did direct CPU-X copy, so we just make three functions for it (to avoid
+  // double dispatch). This will get obsoleted by C10, where copies
+  // will be proper operators (and get to rely on multiple dispatch there).
+  virtual void CopyBytesSameDevice(
+      size_t nbytes,
+      const void* src,
+      void* dst) = 0;
+
+  virtual void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) = 0;
+
+  virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0;
+
+  template <typename T>
+  inline void CopySameDevice(size_t n, const T* src, T* dst) {
+    static_assert(
+        std::is_fundamental<T>::value,
+        "CopySameDevice requires fundamental types");
+    CopyBytesSameDevice(
+        n * sizeof(T), static_cast<const void*>(src), static_cast<void*>(dst));
+  }
+
+  template <typename T>
+  inline void CopyFromCPU(size_t n, const T* src, T* dst) {
+    static_assert(
+        std::is_fundamental<T>::value,
+        "CopyFromCPU requires fundamental types");
+    CopyBytesFromCPU(
+        n * sizeof(T), static_cast<const void*>(src), static_cast<void*>(dst));
+  }
+
+  template <typename T>
+  inline void CopyToCPU(size_t n, const T* src, T* dst) {
+    static_assert(
+        std::is_fundamental<T>::value, "CopyToCPU requires fundamental types");
+    CopyBytesToCPU(
+        n * sizeof(T), static_cast<const void*>(src), static_cast<void*>(dst));
+  }
+
+  virtual bool SupportsNonFundamentalTypes() const {
+    return false;
+  }
+
+  inline void EnforceMetaCopyOK() {
+    AT_ASSERTM(
+        SupportsNonFundamentalTypes(), "Context requires fundamental types");
+  }
+
+  void CopyItemsSameDevice(
+      const caffe2::TypeMeta& meta,
+      size_t n,
+      const void* src,
+      void* dst) {
+    if (meta.copy()) {
+      EnforceMetaCopyOK();
+      meta.copy()(src, dst, n);
+    } else {
+      CopyBytesSameDevice(n * meta.itemsize(), src, dst);
+    }
+  }
+
+  void CopyItemsFromCPU(
+      const caffe2::TypeMeta& meta,
+      size_t n,
+      const void* src,
+      void* dst) {
+    if (meta.copy()) {
+      EnforceMetaCopyOK();
+      meta.copy()(src, dst, n);
+    } else {
+      CopyBytesFromCPU(n * meta.itemsize(), src, dst);
+    }
+  }
+
+  void CopyItemsToCPU(
+      const caffe2::TypeMeta& meta,
+      size_t n,
+      const void* src,
+      void* dst) {
+    if (meta.copy()) {
+      EnforceMetaCopyOK();
+      meta.copy()(src, dst, n);
+    } else {
+      CopyBytesToCPU(n * meta.itemsize(), src, dst);
+    }
+  }
+};
+
+// Context constructor registry
+C10_DECLARE_TYPED_REGISTRY(
+    ContextRegistry,
+    at::DeviceType,
+    at::BaseContext,
+    std::unique_ptr<at::BaseContext>,
+    at::Device);
+
+#define REGISTER_CONTEXT(type, ...) \
+  C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)
+
+inline std::unique_ptr<at::BaseContext> CreateContext(
+    const at::Device& device) {
+  return at::ContextRegistry()->Create(device.type(), device);
+}
+
+} // namespace at
+
+namespace caffe2 {
+
+using at::BaseContext;
+using at::CreateContext;
+} // namespace caffe2
diff --git a/caffe2/core/export_c10_op_to_caffe2.h b/caffe2/core/export_c10_op_to_caffe2.h
index b9935a77a3d6b..80c1ff8a7b07d 100644
--- a/caffe2/core/export_c10_op_to_caffe2.h
+++ b/caffe2/core/export_c10_op_to_caffe2.h
@@ -1,13 +1,16 @@
 #pragma once
 
+#include
+#include
+#include "caffe2/core/operator.h"
+
 // TODO Also register c10 operators on mobile
-#if !defined(CAFFE2_IS_XPLAT_BUILD)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
 #include
 #include
 #include
 #include
 #include
-#include "caffe2/core/operator.h"
 #include "caffe2/core/export_caffe2_op_to_c10.h"
 
 namespace caffe2 {
diff --git a/caffe2/core/export_caffe2_op_to_c10.h b/caffe2/core/export_caffe2_op_to_c10.h
index 84a5f70fb4ef2..11db8bb51218f 100644
--- a/caffe2/core/export_caffe2_op_to_c10.h
+++ b/caffe2/core/export_caffe2_op_to_c10.h
@@ -1,6 +1,8 @@
 #pragma once
 
-#if !defined(CAFFE2_IS_XPLAT_BUILD)
+#include
+
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
 #include
 #include
 #include
@@ -93,7 +95,7 @@ void call_caffe2_op_from_c10(
 }
 
 inline FunctionSchema make_function_schema_for_c10(const char* schema_str) {
-#if defined(CAFFE2_IS_XPLAT_BUILD)
+#if defined(CAFFE2_IS_XPLAT_BUILD) || defined(C10_MOBILE)
   throw std::logic_error("We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
 #else
   c10::FunctionSchema parsed_schema = torch::jit::parseSchema(schema_str);
diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h
index 01eca67ddc77e..a2304d89fada2 100644
--- a/caffe2/core/net_async_base.h
+++ b/caffe2/core/net_async_base.h
@@ -1,6 +1,7 @@
 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
 #define CAFFE2_CORE_NET_ASYNC_BASE_H_
 
+#include
 #include "c10/core/thread_pool.h"
 #include "c10/util/Registry.h"
 #include "caffe2/core/common.h"
@@ -13,7 +14,6 @@
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/proto/prof_dag.pb.h"
 #include "caffe2/utils/proto_utils.h"
-#include
 
 C10_DECLARE_int(caffe2_streams_per_gpu);
 C10_DECLARE_int(caffe2_net_async_max_gpus);
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index 225b01d1034d9..cb98352a938da 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -14,7 +14,7 @@
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/utils/proto_utils.h"
 #include "caffe2/utils/string_utils.h"
-#if !defined(CAFFE2_IS_XPLAT_BUILD)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
 #include
 #endif
 
@@ -58,7 +58,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
       device_option_(
          operator_def.has_device_option() ? operator_def.device_option()
                                           : DeviceOption()),
-#if !defined(CAFFE2_IS_XPLAT_BUILD)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
       newstyle_outputs_(),
 #endif
       input_size_(operator_def.input_size()),
@@ -86,7 +86,7 @@ OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
   type_ = operator_def.type();
 }
 
-#if !defined(CAFFE2_IS_XPLAT_BUILD)
+#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
 namespace {
 int C10_UNUSED // Suppress unused function warning on mobile.
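// Illustrative sketch, not part of this patch: how the ContextRegistry and
// CreateContext machinery added in context_base.h above is meant to be
// consumed. `ToyCPUContext` is a hypothetical class used only for
// illustration; the real registrations are the CPU/CUDA contexts in
// caffe2/core/context.h and context_gpu.h.

#include <cstring>
#include "caffe2/core/context_base.h"

class ToyCPUContext final : public at::BaseContext {
 public:
  // The registry passes the requested at::Device to the constructor.
  explicit ToyCPUContext(at::Device /*device*/) {}
  at::Device device() const override { return at::Device(at::DeviceType::CPU); }
  at::DeviceType device_type() const override { return at::DeviceType::CPU; }
  void SwitchToDevice(int /*stream_id*/) override {}
  void WaitEvent(const caffe2::Event& /*ev*/) override {}
  void Record(caffe2::Event* /*ev*/, const char* /*err_msg*/) const override {}
  void FinishDeviceComputation() override {}
  // On a single-device toy context all three copy directions are plain memcpy.
  void CopyBytesSameDevice(size_t n, const void* src, void* dst) override { std::memcpy(dst, src, n); }
  void CopyBytesFromCPU(size_t n, const void* src, void* dst) override { std::memcpy(dst, src, n); }
  void CopyBytesToCPU(size_t n, const void* src, void* dst) override { std::memcpy(dst, src, n); }
};

// Hypothetical registration and lookup through the registry declared above:
//   REGISTER_CONTEXT(at::DeviceType::CPU, ToyCPUContext);
//   std::unique_ptr<at::BaseContext> ctx =
//       at::CreateContext(at::Device(at::DeviceType::CPU));
//   ctx->SwitchToDevice(0);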
@@ -793,7 +793,7 @@ std::function GetOperatorLogger() { c10::optional OperatorBase::argumentIndexWithName( const std::string& name) const { -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) return getFunctionSchema().argumentIndexWithName(name); #else CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 4296209f512c8..a60149c40d9a7 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -28,7 +28,7 @@ #include "caffe2/proto/caffe2_pb.h" #include "caffe2/utils/proto_utils.h" -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) #include #include #endif @@ -59,7 +59,7 @@ class CAFFE2_API OperatorBase : public Observable { * Alternatively, inputs can be one tensor list ivalue followed by non-tensors * to represent operators with a variable number of inputs. */ -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) explicit OperatorBase( const c10::FunctionSchema& schema, std::vector inputs, @@ -72,7 +72,7 @@ class CAFFE2_API OperatorBase : public Observable { * New operators should be instantiated with FunctionSchema */ bool isLegacyOperator() const { -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) return !fn_schema_; #else return true; @@ -81,7 +81,7 @@ class CAFFE2_API OperatorBase : public Observable { const c10::FunctionSchema& getFunctionSchema() const { CAFFE_ENFORCE(!isLegacyOperator()); -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) return *fn_schema_.get(); #else CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); @@ -107,7 +107,7 @@ class CAFFE2_API OperatorBase : public Observable { return ArgumentHelper::GetSingleArgument( *operator_def_, name, default_value); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) auto index = argumentIndexWithName(name); CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); const auto& value = newstyle_inputs_[index.value()]; @@ -123,7 +123,7 @@ class CAFFE2_API OperatorBase : public Observable { return ArgumentHelper::HasSingleArgumentOfType( *operator_def_, name); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) template inline vector GetVectorFromIValueList(const c10::IValue& value) const { return c10::impl::toVector(value.template to>()); @@ -180,7 +180,7 @@ class CAFFE2_API OperatorBase : public Observable { throw enf; } } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) DCHECK_LT(0, newstyle_inputs_.size()); IValue ival; if (newstyle_inputs_[0].isTensorList()) { @@ -230,7 +230,7 @@ class CAFFE2_API OperatorBase : public Observable { // When you get a Tensor here it is not fully initialized return BlobGetMutableTensor(outputs_.at(idx), type); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) at::Tensor output = newstyle_outputs_[idx]; Tensor tensor = caffe2::Tensor(output); if (!tensor.defined() || tensor.GetDeviceType() != type) { @@ -260,7 +260,7 @@ class CAFFE2_API OperatorBase : public Observable { void SetOutputTensor(int idx, Tensor tensor) { if (!isLegacyOperator()) { -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) 
newstyle_outputs_[idx] = at::Tensor(tensor); // also update the tensor in the hack @@ -289,7 +289,7 @@ class CAFFE2_API OperatorBase : public Observable { "device must be provided in options."); return BlobGetMutableTensor(outputs_.at(idx), dims, options); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) at::Tensor output = newstyle_outputs_[idx]; Tensor tensor = GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options); @@ -413,7 +413,7 @@ class CAFFE2_API OperatorBase : public Observable { if (isLegacyOperator()) { return outputs_.size(); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) return newstyle_outputs_.size(); #else CAFFE_THROW("Non-legacy operators are not legal in xplat/caffe2"); @@ -628,7 +628,7 @@ class CAFFE2_API OperatorBase : public Observable { return helper_; } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) c10::List move_newstyle_outputs() && { return std::move(newstyle_outputs_); } @@ -646,7 +646,7 @@ class CAFFE2_API OperatorBase : public Observable { vector inputs_; vector outputs_; // Preferrably use c10::optional, but nvcc doesn't work -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) std::unique_ptr fn_schema_; vector newstyle_inputs_; c10::List newstyle_outputs_; @@ -710,7 +710,7 @@ inline NetDef OperatorBase::GetSingleArgument( return NetDef(); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) template <> inline vector OperatorBase::GetVectorFromIValueList( const c10::IValue& value) const { @@ -794,7 +794,7 @@ inline vector OperatorBase::GetRepeatedArgument( return ArgumentHelper::GetRepeatedArgument( *operator_def_, name, default_value); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) auto index = argumentIndexWithName(name); CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); const auto& value = newstyle_inputs_[index.value()]; @@ -815,7 +815,7 @@ inline vector OperatorBase::GetRepeatedArgument( return ArgumentHelper::GetRepeatedArgument( *operator_def_, name, default_value); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) auto index = argumentIndexWithName(name); CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name); const auto& value = newstyle_inputs_[index.value()]; @@ -843,7 +843,7 @@ class Operator : public OperatorBase { // constructors will run on that device. 
context_.SwitchToDevice(); } -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) explicit Operator( const c10::FunctionSchema& fn_schema, std::vector inputs, diff --git a/caffe2/core/tensor.cc b/caffe2/core/tensor.cc index afdc44424ee31..9fe81027ad93e 100644 --- a/caffe2/core/tensor.cc +++ b/caffe2/core/tensor.cc @@ -202,9 +202,11 @@ void Tensor::enforce_invariants() { throw std::runtime_error("TensorImpl with nullptr is not supported"); } // TODO: only check `!impl_->requires_grad()` after Variable and Tensor are merged +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) CAFFE_ENFORCE( !impl_->is_variable() || !(impl_->requires_grad() && at::GradMode::is_enabled()), "Caffe2 tensor wrapper doesn't support autograd variables that require grad"); +#endif CAFFE_ENFORCE_EQ( impl_->layout(), at::kStrided, diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 41b2b421c4051..1f1ab8d06e56d 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -1,18 +1,18 @@ #ifndef CAFFE2_CORE_TENSOR_H_ #define CAFFE2_CORE_TENSOR_H_ +#include #include "caffe2/core/storage.h" #include "caffe2/core/tensor_impl.h" #include #include -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) #include "ATen/core/Tensor.h" +#include #endif #include -#include - namespace caffe2 { using at::UndefinedTensorImpl; @@ -119,7 +119,7 @@ class CAFFE2_API Tensor final { * The tensor will share the same instance (data, strides, sizes, etc) but * a different subset of APIs would be available */ -#if !defined(CAFFE2_IS_XPLAT_BUILD) +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) explicit Tensor(at::Tensor tensor) : impl_(std::move(tensor.impl_)) { enforce_invariants(); @@ -196,7 +196,9 @@ class CAFFE2_API Tensor final { */ void CopyFrom(const Tensor& src, bool async = false) { // TODO: only check `!impl_->requires_grad()` after Variable and Tensor are merged +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) AT_ASSERT(!impl_->is_variable() || !(impl_->requires_grad() && at::GradMode::is_enabled())); +#endif AT_ASSERTM( src.impl_->is_contiguous(), "Right now only copy of contiguous source Tensor is supported."); diff --git a/caffe2/core/tensor_impl.h b/caffe2/core/tensor_impl.h index 64f6db326e315..6c5316da83d89 100644 --- a/caffe2/core/tensor_impl.h +++ b/caffe2/core/tensor_impl.h @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include namespace caffe2 { using at::ToVectorint64_t; diff --git a/caffe2/core/test_utils.h b/caffe2/core/test_utils.h index d60cdaeecd5f7..68214b235b4c6 100644 --- a/caffe2/core/test_utils.h +++ b/caffe2/core/test_utils.h @@ -18,7 +18,9 @@ namespace caffe2 { namespace testing { // Asserts that the values of two tensors are the same. -CAFFE2_API void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2); +CAFFE2_API void assertTensorEquals( + const TensorCPU& tensor1, + const TensorCPU& tensor2); // Asserts that two float values are close within epsilon. 
CAFFE2_API void assertNear(float value1, float value2, float epsilon); @@ -74,6 +76,22 @@ CAFFE2_API caffe2::OperatorDef* createOperator( const std::vector& outputs, caffe2::NetDef* net); +// Fill a buffer with randomly generated numbers given range [min, max) +// T can only be float, double or long double +template +void randomFill( + RealType* data, + size_t size, + const double min = 0.0, + const double max = 1.0) { + std::mt19937 gen(42); + std::uniform_real_distribution dis( + static_cast(min), static_cast(max)); + for (size_t i = 0; i < size; i++) { + data[i] = dis(gen); + } +} + // Fill data from a vector to a tensor. template void fillTensor( diff --git a/caffe2/operators/CMakeLists.txt b/caffe2/operators/CMakeLists.txt index ecdb951fcdaf7..0b3284f42fd0f 100644 --- a/caffe2/operators/CMakeLists.txt +++ b/caffe2/operators/CMakeLists.txt @@ -42,23 +42,25 @@ exclude(tmp "${tmp}" ${tmp_cudnn}) set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp}) # Add all files in experimental -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/flatten_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/averaged_loss_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/mul_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/relu_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/expand_dims_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/filler_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/sparse_lengths_sum_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/sigmoid_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/cast_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/stop_gradient_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/batch_gather_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/concat_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/batch_matmul_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/fc_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/enforce_finite_cpu.cc) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/add_cpu.cc) +if (NOT INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/flatten_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/averaged_loss_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/mul_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/relu_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/expand_dims_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/filler_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} 
${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/sparse_lengths_sum_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/sigmoid_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/cast_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/stop_gradient_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/batch_gather_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/concat_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/batch_matmul_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/fc_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/enforce_finite_cpu.cc) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/add_cpu.cc) +endif() set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${tmp}) # exclude test files and gpu files diff --git a/caffe2/operators/reduction_ops.h b/caffe2/operators/reduction_ops.h index b6ca336f2a56f..d4fbef0359b45 100644 --- a/caffe2/operators/reduction_ops.h +++ b/caffe2/operators/reduction_ops.h @@ -19,11 +19,13 @@ class SumElementsOp : public Operator { average_(this->template GetSingleArgument("average", false)) {} explicit SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average) : Operator(operator_def, ws), average_(average) {} +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) explicit SumElementsOp(const c10::FunctionSchema& schema, std::vector inputs, std::vector outputs) : Operator(schema, std::move(inputs), std::move(outputs)), average_(this->template GetSingleArgument("average", false)) {} explicit SumElementsOp(const c10::FunctionSchema& schema, std::vector inputs, std::vector outputs, bool average) : Operator(schema, std::move(inputs), std::move(outputs)), average_(average) {} +#endif ~SumElementsOp() {} bool RunOnDevice() override { @@ -85,11 +87,13 @@ class SumElementsGradientOp : public Operator { average_(this->template GetSingleArgument("average", false)) {} explicit SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws, bool average) : Operator(operator_def, ws), average_(average) {} +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) explicit SumElementsGradientOp(const c10::FunctionSchema& schema, std::vector inputs, std::vector outputs) : Operator(schema, std::move(inputs), std::move(outputs)), average_(this->template GetSingleArgument("average", false)) {} explicit SumElementsGradientOp(const c10::FunctionSchema& schema, std::vector inputs, std::vector outputs, bool average) : Operator(schema, std::move(inputs), std::move(outputs)), average_(average) {} +#endif ~SumElementsGradientOp() {} bool RunOnDevice() override; diff --git a/caffe2/operators/resize_3d_op.cc b/caffe2/operators/resize_3d_op.cc new file mode 100644 index 0000000000000..4898711108a35 --- /dev/null +++ b/caffe2/operators/resize_3d_op.cc @@ -0,0 +1,224 @@ +#include "caffe2/operators/resize_3d_op.h" + +#include "caffe2/utils/math.h" + +#ifdef CAFFE2_USE_MKLDNN +#include "caffe2/ideep/operators/operator_fallback_ideep.h" +#include "caffe2/ideep/utils/ideep_operator.h" +#endif + +namespace caffe2 { + +void resizeNearest3DNCHW2x( + 
int batch_size, + int num_channels, + int temporal_scale, + int input_frames, + int input_height, + int input_width, + const float* input, + float* output) { + const int output_frames = input_frames * temporal_scale; + const int output_height = input_height * 2; + const int output_width = input_width * 2; + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < num_channels; ++c) { + for (int f = 0; f < output_frames; ++f ) { + const int in_f = f / temporal_scale; + for (int y = 0; y < output_height; ++y) { + const int in_y = y / 2; + + for (int x = 0; x < input_width; ++x) { + const float v = + input[((in_f * input_height) + in_y) * input_width + x]; + const int oidx = y * output_width + x * 2; + output[oidx + 0] = v; + output[oidx + 1] = v; + } + } + output += output_height * output_width; + } + input += input_frames * input_height * input_width; + } + } +} + +template <> +bool ResizeNearest3DOp::RunOnDeviceWithOrderNCHW() { + const auto& X = Input(0); + const auto XDims = X.sizes(); + CAFFE_ENFORCE_EQ(5, XDims.size()); + + const int batch_size = X.dim32(0), num_channels = X.dim32(1), + input_frames = X.dim32(2), input_height = X.dim32(3), + input_width = X.dim32(4); + + CAFFE_ENFORCE_EQ(InputSize(), 1); + + int output_frames = input_frames * temporal_scale_; + int output_height = input_height * height_scale_; + int output_width = input_width * width_scale_; + auto* Y = Output( + 0, + {batch_size, num_channels, output_frames, output_height, output_width}, + at::dtype()); + + const float* Xdata = X.data(); + float* Ydata = Y->template mutable_data(); + + // Specialized implementation for fast 2x upsampling + if (width_scale_ == 2.0 && height_scale_ == 2.0) { + CAFFE_ENFORCE(temporal_scale_ == 1 || temporal_scale_ == 2, + "temporal_scale must be either 1 or 2"); + + resizeNearest3DNCHW2x( + batch_size, num_channels, temporal_scale_, input_frames, input_height, + input_width, Xdata, Ydata); + return true; + } + + CAFFE_THROW("Not implemented when height- and width scale are not 2"); +} + +template <> +bool ResizeNearest3DOp::RunOnDevice() { + switch (order_) { + case StorageOrder::NHWC: + CAFFE_THROW("Not implemented for storage order: ", order_); + case StorageOrder::NCHW: + return RunOnDeviceWithOrderNCHW(); + default: + CAFFE_THROW("Unknown Storage order: ", order_); + } +} + +template <> +bool ResizeNearest3DGradientOp::RunOnDeviceWithOrderNCHW() { + const auto& dY = Input(0); + const auto& X = Input(1); + + const auto inputDims = dY.sizes(); + CAFFE_ENFORCE_EQ(5, inputDims.size()); + const int batch_size = dY.dim32(0), num_channels = dY.dim32(1), + input_frames = dY.dim32(2), input_height = dY.dim32(3), + input_width = dY.dim32(4); + + const int output_frames = X.dim32(2); + const int output_height = X.dim32(3); + const int output_width = X.dim32(4); + + CAFFE_ENFORCE_EQ(InputSize(), 2); + + auto* dX = Output( + 0, + {batch_size, num_channels, output_frames, output_height, output_width}, + at::dtype()); + math::Set( + dX->numel(), 0.0f, dX->template mutable_data(), &context_); + + const float* dYdata = dY.data(); + float* dXdata = dX->template mutable_data(); + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < num_channels; ++c) { + for (int f = 0; f < input_frames; ++f) { + const int out_f = + std::min((int)(f / temporal_scale_), output_frames - 1); + for (int y = 0; y < input_height; ++y) { + const int out_y = + std::min((int)(y / height_scale_), (output_height - 1)); + for (int x = 0; x < input_width; ++x) { + const int out_x = + std::min((int)(x / width_scale_), 
(output_width - 1)); + dXdata[(out_f * output_height + out_y) * output_width + out_x] += + dYdata[(f * input_height + y) * input_width + x]; + } + } + } + dYdata += input_frames * input_height * input_width; + dXdata += output_frames * output_height * output_width; + } + } + + return true; +} + +template <> +bool ResizeNearest3DGradientOp::RunOnDevice() { + switch (order_) { + case StorageOrder::NHWC: + CAFFE_THROW("Not implemented for storage order: ", order_); + case StorageOrder::NCHW: + return RunOnDeviceWithOrderNCHW(); + default: + CAFFE_THROW("Unknown Storage order: ", order_); + } +} +REGISTER_CPU_OPERATOR(ResizeNearest3D, ResizeNearest3DOp); +REGISTER_CPU_GRADIENT_OPERATOR( + ResizeNearest3DGradient, + ResizeNearest3DGradientOp); + +#ifdef CAFFE2_USE_MKLDNN +REGISTER_IDEEP_OPERATOR( + ResizeNearest3D, + IDEEPFallbackOp>); +#endif + +// Input: X, output: Y +OPERATOR_SCHEMA(ResizeNearest3D) + .NumInputs(1) + .NumOutputs(1) + .Arg("temporal_scale", "Scale along temporal dimension") + .Arg("width_scale", "Scale along width dimension") + .Arg("height_scale", "Scale along height dimension") + .SetDoc(R"DOC( +Resizes the spatial dimensions of the input tensor using nearest neighbor +interpolation. The `width_scale` and `height_scale` arguments +control the size of the output, which is given by: +output_width = floor(input_width * width_scale) +output_height = floor(output_height * height_scale) +Assumptions: + - Only resize height and width + - Both width_scale and height_scale scale are 2 +)DOC") + .Input(0, "X", "Input tensor") + .Output(0, "Y", "Output tensor"); + +// Input: dY, output: dX +GRADIENT_OPERATOR_SCHEMA(ResizeNearest3DGradient) + .NumInputs(2) + .NumOutputs(1) + .Arg("temporal_scale", "Scale along temporal dimension") + .Arg("width_scale", "Scale along width dimension") + .Arg("height_scale", "Scale along height dimension"); + +class GetResizeNearest3DGradient : public GradientMakerBase { + using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { + return SingleGradientDef( + "ResizeNearest3DGradient", + "", + vector{GO(0), I(0)}, + vector{GI(0)}); + } +}; +REGISTER_GRADIENT(ResizeNearest3D, GetResizeNearest3DGradient); + +} // namespace caffe2 + +using ResizeNearest3DOpFloatCPU = + caffe2::ResizeNearest3DOp; + +// clang-format off +C10_EXPORT_CAFFE2_OP_TO_C10_CPU( + ResizeNearest3D, + "_caffe2::ResizeNearest3D(" + "Tensor X, " + "str order, " + "float temporal_scale, " + "float width_scale, " + "float height_scale" + ") -> (Tensor Y)", + ResizeNearest3DOpFloatCPU); +// clang-format on diff --git a/caffe2/operators/resize_3d_op.cu b/caffe2/operators/resize_3d_op.cu new file mode 100644 index 0000000000000..378e3bec74a7f --- /dev/null +++ b/caffe2/operators/resize_3d_op.cu @@ -0,0 +1,219 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/resize_3d_op.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +namespace { + +__global__ void NearestNeighbor3DKernel( + const int size, + const int num_channels, + const int input_frames, + const int input_height, + const int input_width, + const int output_frames, + const int output_height, + const int output_width, + const float temporal_scale, + const float height_scale, + const float width_scale, + const float* X, + float* Y) { + CUDA_1D_KERNEL_LOOP(index, size) { + int indexTemp = index; + const int w = indexTemp % output_width; + indexTemp /= output_width; + const int h = indexTemp % output_height; + indexTemp /= output_height; + const int f = indexTemp % output_frames; + indexTemp 
/= output_frames; + const int c = indexTemp % num_channels; + indexTemp /= num_channels; + const int n = indexTemp; + + const int in_f = fminf(f / temporal_scale, input_frames - 1); + const int in_y = fminf(h / height_scale, input_height - 1); + const int in_x = fminf(w / width_scale, input_width - 1); + Y[index] = + X[(((n * num_channels + c) * input_frames + in_f) * input_height + in_y) + * input_width + in_x]; + } +} + +__global__ void NearestNeighbor3DGradientKernel( + const int size, + const int num_channels, + const int input_frames, + const int input_height, + const int input_width, + const int output_frames, + const int output_height, + const int output_width, + const float temporal_scale, + const float height_scale, + const float width_scale, + const float* dY, + float* dX) { + CUDA_1D_KERNEL_LOOP(index, size) { + int indexTemp = index; + const int x = indexTemp % input_width; + indexTemp /= input_width; + const int y = indexTemp % input_height; + indexTemp /= input_height; + const int f = indexTemp % input_frames; + indexTemp /= input_frames; + const int c = indexTemp % num_channels; + indexTemp /= num_channels; + const int n = indexTemp; + + const int out_f = fminf(f / temporal_scale, output_frames - 1); + const int out_y = fminf(y / height_scale, output_height - 1); + const int out_x = fminf(x / width_scale, output_width - 1); + const int out_index = + (((n * num_channels + c) * output_frames + out_f) * output_height + + out_y) * output_width + out_x; +#if __CUDA_ARCH__ >= 350 + atomicAdd(dX + out_index, __ldg(dY + index)); +#else + atomicAdd(dX + out_index, *(dY + index)); +#endif + } +} + +} // namespace + + +template <> +bool ResizeNearest3DOp::RunOnDeviceWithOrderNCHW() { + const auto& X = Input(0); + + const auto inputDims = X.sizes(); + CAFFE_ENFORCE_EQ(5, inputDims.size()); + const int batch_size = X.dim32(0), num_channels = X.dim32(1), + input_frames = X.dim32(2), input_height = X.dim32(3), + input_width = X.dim32(4); + + CAFFE_ENFORCE_EQ(InputSize(), 1); + + int output_frames = input_frames * temporal_scale_; + int output_height = input_height * height_scale_; + int output_width = input_width * width_scale_; + auto* Y = Output( + 0, + {batch_size, num_channels, output_frames, output_height, output_width}, + at::dtype()); + + const auto size = Y->numel(); + NearestNeighbor3DKernel<<< + CAFFE_GET_BLOCKS(size), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + size, + num_channels, + input_frames, + input_height, + input_width, + output_frames, + output_height, + output_width, + temporal_scale_, + height_scale_, + width_scale_, + X.data(), + Y->template mutable_data()); + + return true; +} + +template <> +bool ResizeNearest3DOp::RunOnDevice() { + switch (order_) { + case StorageOrder::NHWC: + CAFFE_THROW("Not implemented for storage order: ", order_); + case StorageOrder::NCHW: + return RunOnDeviceWithOrderNCHW(); + default: + CAFFE_THROW("Unknown Storage order: ", order_); + } +} + + +template <> +bool ResizeNearest3DGradientOp::RunOnDeviceWithOrderNCHW() { + const auto& dY = Input(0); + const auto& X = Input(1); + + const auto inputDims = dY.sizes(); + CAFFE_ENFORCE_EQ(5, inputDims.size()); + const int batch_size = dY.dim32(0), num_channels = dY.dim32(1), + input_frames = dY.dim32(2), input_height = dY.dim32(3), + input_width = dY.dim32(4); + + // X,dim32(2) can be different from int(input_frames / temporal_scale_) + // We choose to compute output_frames=int(input_frames / temporal_scale_) + + // const int output_frames = X,dim32(2); + // const int 
output_height = X.dim32(3); + // const int output_width = X.dim32(4); + + const int output_frames = int(input_frames / temporal_scale_); + const int output_height = int(input_height / height_scale_); + const int output_width = int(input_width / width_scale_); + + CAFFE_ENFORCE_EQ(InputSize(), 2); + + auto* dX = Output( + 0, + {batch_size, num_channels, output_frames, output_height, output_width}, + at::dtype()); + math::Set( + dX->numel(), 0.0f, dX->template mutable_data(), &context_); + + const auto size = dY.numel(); + NearestNeighbor3DGradientKernel<<< + CAFFE_GET_BLOCKS(size), + CAFFE_CUDA_NUM_THREADS, + 0, + context_.cuda_stream()>>>( + size, + num_channels, + input_frames, + input_height, + input_width, + output_frames, + output_height, + output_width, + temporal_scale_, + height_scale_, + width_scale_, + dY.data(), + dX->template mutable_data()); + + return true; +} + +template <> +bool ResizeNearest3DGradientOp::RunOnDevice() { + switch (order_) { + case StorageOrder::NHWC: + CAFFE_THROW("Not implemented for storage order: ", order_); + case StorageOrder::NCHW: + return RunOnDeviceWithOrderNCHW(); + default: + CAFFE_THROW("Unknown Storage order: ", order_); + } +} + +REGISTER_CUDA_OPERATOR(ResizeNearest3D, ResizeNearest3DOp); +REGISTER_CUDA_OPERATOR( + ResizeNearest3DGradient, + ResizeNearest3DGradientOp); + +} // namespace caffe2 + +using ResizeNearest3DOpFloatCUDA = + caffe2::ResizeNearest3DOp; + +C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(ResizeNearest3D, ResizeNearest3DOpFloatCUDA); diff --git a/caffe2/operators/resize_3d_op.h b/caffe2/operators/resize_3d_op.h new file mode 100644 index 0000000000000..1495940ef6637 --- /dev/null +++ b/caffe2/operators/resize_3d_op.h @@ -0,0 +1,92 @@ +#pragma once + +#include "caffe2/core/export_caffe2_op_to_c10.h" +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(ResizeNearest3D); + + +namespace caffe2 { + +template +class ResizeNearest3DOp final : public Operator { +public: + template + explicit ResizeNearest3DOp(Args&&... args) + : Operator(std::forward(args)...), + temporal_scale_(1), + height_scale_(1), + width_scale_(1), + order_(StringToStorageOrder( + this->template GetSingleArgument("order", "NCHW"))) { + if (HasArgument("temporal_scale")) { + temporal_scale_ = static_cast( + this->template GetSingleArgument("temporal_scale", 1)); + } + if (HasArgument("height_scale")) { + height_scale_ = static_cast( + this->template GetSingleArgument("height_scale", 1)); + } + if (HasArgument("width_scale")) { + width_scale_ = static_cast( + this->template GetSingleArgument("width_scale", 1)); + } + + CAFFE_ENFORCE_GT(temporal_scale_, 0); + CAFFE_ENFORCE_GT(height_scale_, 0); + CAFFE_ENFORCE_GT(width_scale_, 0); + + CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC); + } + + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + bool RunOnDeviceWithOrderNCHW(); + + protected: + T temporal_scale_; + T height_scale_; + T width_scale_; + StorageOrder order_; +}; + +template +class ResizeNearest3DGradientOp final : public Operator { + public: + template + explicit ResizeNearest3DGradientOp(Args&&... 
args) + : Operator(std::forward(args)...), + temporal_scale_(1), + height_scale_(1), + width_scale_(1), + order_(StringToStorageOrder( + this->template GetSingleArgument("order", "NCHW"))) { + temporal_scale_ = static_cast( + this->template GetSingleArgument("temporal_scale", 1)); + height_scale_ = static_cast( + this->template GetSingleArgument("height_scale", 1)); + width_scale_ = static_cast( + this->template GetSingleArgument("width_scale", 1)); + + CAFFE_ENFORCE_GT(temporal_scale_, 0); + CAFFE_ENFORCE_GT(height_scale_, 0); + CAFFE_ENFORCE_GT(width_scale_, 0); + + CAFFE_ENFORCE(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC); + } + + USE_OPERATOR_CONTEXT_FUNCTIONS; + + bool RunOnDevice() override; + bool RunOnDeviceWithOrderNCHW(); + + protected: + T temporal_scale_; + T height_scale_; + T width_scale_; + StorageOrder order_; +}; + +} // namespace caffe2 diff --git a/caffe2/python/layers/batch_lr_loss.py b/caffe2/python/layers/batch_lr_loss.py index f0d5a81f484fb..4ab03d80f928d 100644 --- a/caffe2/python/layers/batch_lr_loss.py +++ b/caffe2/python/layers/batch_lr_loss.py @@ -31,6 +31,8 @@ def __init__( uncertainty_penalty=1.0, focal_gamma=0.0, stop_grad_in_focal_factor=False, + task_gamma=1.0, + task_gamma_lb=0.1, **kwargs ): super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs) @@ -75,6 +77,44 @@ def __init__( self.focal_gamma = focal_gamma self.stop_grad_in_focal_factor = stop_grad_in_focal_factor + self.apply_exp_decay = False + if task_gamma < 1.0: + self.apply_exp_decay = True + self.task_gamma_cur = self.create_param( + param_name=('%s_task_gamma_cur' % self.name), + shape=[1], + initializer=( + 'ConstantFill', { + 'value': 1.0, + 'dtype': core.DataType.FLOAT + } + ), + optimizer=self.model.NoOptim, + ) + + self.task_gamma = self.create_param( + param_name=('%s_task_gamma' % self.name), + shape=[1], + initializer=( + 'ConstantFill', { + 'value': task_gamma, + 'dtype': core.DataType.FLOAT + } + ), + optimizer=self.model.NoOptim, + ) + + self.task_gamma_lb = self.create_param( + param_name=('%s_task_gamma_lb' % self.name), + shape=[1], + initializer=( + 'ConstantFill', { + 'value': task_gamma_lb, + 'dtype': core.DataType.FLOAT + } + ), + optimizer=self.model.NoOptim, + ) def init_weight(self, jsd_weight, homotopy_weighting): if homotopy_weighting: @@ -197,6 +237,21 @@ def add_ops(self, net): [xent, focal_factor], net.NextScopedBlob("focallossxent") ) + if self.apply_exp_decay: + net.Mul( + [self.task_gamma_cur, self.task_gamma], + self.task_gamma_cur + ) + + task_gamma_multiplier = net.Max( + [self.task_gamma_cur, self.task_gamma_lb], + net.NextScopedBlob("task_gamma_cur_multiplier") + ) + + xent = net.Mul( + [xent, task_gamma_multiplier], net.NextScopedBlob("expdecayxent") + ) + # fuse with JSD if self.jsd_fuse: jsd = net.BernoulliJSD( diff --git a/caffe2/python/regularizer.py b/caffe2/python/regularizer.py index 5214a2e016310..563033c509104 100644 --- a/caffe2/python/regularizer.py +++ b/caffe2/python/regularizer.py @@ -42,6 +42,9 @@ def _run_on_loss(self, net, param_init_net, param, grad=None): def _run_after_optimizer(self, net, param_init_net, param, grad): return None + def _feature_grouping(self, param, net): + return None + def _ensure_clipped( self, net, @@ -84,6 +87,52 @@ def _run_on_loss(self, net, param_init_net, param, grad=None): net.Scale([output_blob], [output_blob], scale=self.reg_lambda) return output_blob +class FCInputLpNorm(Regularizer): + def __init__(self, reg_lambda, p_value=0.5): + super(FCInputLpNorm, self).__init__() + 
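Stepping back from the diff for a moment: the `task_gamma` / `task_gamma_lb` blobs added to `BatchLRLoss` above implement an exponential decay of the cross-entropy term. Each `add_ops` pass multiplies `task_gamma_cur` by `task_gamma`, and the loss is then scaled by `max(task_gamma_cur, task_gamma_lb)`. A small NumPy sketch of the resulting schedule; the function name and default values are illustrative only:

```python
import numpy as np

def task_gamma_multiplier(steps, task_gamma=0.9999, task_gamma_lb=0.1):
    """Approximate scale applied to the xent blob after `steps` net runs.

    Mirrors the net.Mul / net.Max sequence in BatchLRLoss.add_ops:
    task_gamma_cur starts at 1.0, is multiplied by task_gamma once per
    run, and the effective multiplier is floored at task_gamma_lb.
    """
    task_gamma_cur = task_gamma ** steps
    return np.maximum(task_gamma_cur, task_gamma_lb)

steps = np.array([0, 1000, 10000, 100000])
print(task_gamma_multiplier(steps))  # decays from 1.0 toward the 0.1 floor
```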
assert reg_lambda >= 0, "factor ahead of regularization should be 0 or positive" + assert p_value >= 0, "p_value factor should be 0 or positive" + self.p_value = p_value + self.reg_lambda = reg_lambda + + def _feature_grouping(self, param, net): + # Possible alternative grouping method via summing over absolute values + # Compute l2norm over feature weights + # pow( sum_i { pow(theta_i, 2) } , 0.5) + param_mul = net.Mul([param, param], [net.NextScopedBlob("param_mul")]) + param_reduced = net.ReduceFrontSum( + [param_mul], [net.NextScopedBlob("param_reduced")] + ) + grouped_feature_weight_vec = net.Pow( + [param_reduced], + [net.NextScopedBlob("grouped_feature_weight_vec")], + exponent=0.5, + ) + + return grouped_feature_weight_vec + + def _run_on_loss(self, net, param_init_net, param, grad=None): + # TODO: the second dim (num of input nodes) of param is after feature preproc, + # and does not correspond to the original num of dense features. + # In the future, will want to create a util to reduce the input dim of param to + # match the num of dense features. + + output_blob = net.NextScopedBlob(param + "_dense_feature_regularization") + grouped_feature_weight_vec = self._feature_grouping(param, net) + + # Compute Lpnorm over l2norm: + # pow( sum_i { pow(theta_i, p) } , 1/p) + lp_vec_raised = net.Pow( + [grouped_feature_weight_vec], [net.NextScopedBlob("lp_vec_raised")], exponent=self.p_value + ) + lp_vec_summed = net.ReduceFrontSum( + [lp_vec_raised], [net.NextScopedBlob("lp_vec_summed")] + ) + lp_vec = net.Pow( + [lp_vec_summed], [net.NextScopedBlob("lp_vec")], exponent=(1 / self.p_value) + ) + net.Scale([lp_vec], [output_blob], scale=self.reg_lambda) + return output_blob class L1NormTrimmed(Regularizer): """ diff --git a/caffe2/quantization/server/caffe2_dnnlowp_utils.cc b/caffe2/quantization/server/caffe2_dnnlowp_utils.cc index 4e0e4cecbd40b..af671a0f596eb 100644 --- a/caffe2/quantization/server/caffe2_dnnlowp_utils.cc +++ b/caffe2/quantization/server/caffe2_dnnlowp_utils.cc @@ -71,7 +71,20 @@ TensorQuantizationParams GetInputTensorQuantizationParamsOf( float min, max; fbgemm::FindMinMax( tensor->template data(), &min, &max, tensor->numel()); - + auto activation_quantization_kind = qfactory->GetActivationKind(); + if (activation_quantization_kind != + QuantizationFactory::QuantizationKind::MIN_MAX_QUANTIZATION) { + LOG(WARNING) + << "DNNLOWP dynamic int8 FC uses min_max as the only activation_quantization kind. 
Qparams will be assigned based on min_max regardless of activation_quantization_kind args."; + } + if (is_weight) { + auto weight_quantization_kind = qfactory->GetWeightKind(); + if (weight_quantization_kind != + QuantizationFactory::QuantizationKind::MIN_MAX_QUANTIZATION) { + LOG(WARNING) + << "DNNLOWP dynamic int8 FC weight is not constant, assigning qparams to weight based on min_max, regardless of weight_quantization_kind args."; + } + } return qfactory->ChooseQuantizationParams(min, max, is_weight); } } diff --git a/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.cc b/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.cc new file mode 100644 index 0000000000000..5e20d62365f6e --- /dev/null +++ b/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.cc @@ -0,0 +1,64 @@ +#include "caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.h" + +namespace caffe2 { + +template +bool ResizeNearest3DDNNLowPOp::RunOnDevice() { + using namespace dnnlowp; + + this->ParseDNNLowPOperatorArguments_(); + + // Choose quantization params + in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get()); + + const auto& X = InputTensorCPU_(0); + auto* Y = OutputTensorCPU_(0); + + CAFFE_ENFORCE_EQ(X.ndim(), 5); + const int N = X.dim32(0); + // input frames + const int IF = X.dim32(1); + const int IH = X.dim32(2); + const int IW = X.dim32(3); + const int C = X.dim32(4); + const int OF = IF * temporal_scale_; + const int OH = IH * height_scale_; + const int OW = IW * width_scale_; + + vector buffer_shape{N, OF, OH, OW, C}; + Y->Resize(buffer_shape); + const T* X_data = X.template data(); + T* Y_data = Y->template mutable_data(); + +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (int n = 0; n < N; ++n) { + for (int t = 0; t < OF; ++t) { + const int in_f = std::min((int)(t / temporal_scale_), (IF - 1)); + for (int y = 0; y < OH; ++y) { + const int in_y = std::min((int)(y / height_scale_), (IH - 1)); + for (int x = 0; x < OW; ++x) { + const int in_x = std::min((int)(x / width_scale_), (IW - 1)); + std::memcpy( + &Y_data[((((n * OF) + t) * OH + y) * OW + x) * C], + &X_data[((((n * IF) + in_f) * IH + in_y) * IW + in_x) * C], + C * sizeof(T)); + } + } + } + } + // Even if there is a pre-chosen quantization parameters for the output, + // it is ignored because resize nearest output quantization should be same + // as the input. 
+ PropagateOutputTensorQuantizationParams(this, 0, in_qparams_[0]); + + return true; +} + +REGISTER_CPU_OPERATOR_WITH_ENGINE( + Int8ResizeNearest3D, + DNNLOWP, + ResizeNearest3DDNNLowPOp); + +} // namespace caffe2 diff --git a/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.h b/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.h new file mode 100644 index 0000000000000..e7e09d94abe59 --- /dev/null +++ b/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op.h @@ -0,0 +1,41 @@ +#pragma once + +#include "caffe2/operators/resize_3d_op.h" +#include "caffe2/quantization/server/dnnlowp_op.h" + +namespace caffe2 { + +using ResizeNearest3DFP32Op = ResizeNearest3DOp; + +template +class ResizeNearest3DDNNLowPOp final + : public DNNLowPOp { + public: + USE_OPERATOR_FUNCTIONS(CPUContext); + USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, ResizeNearest3DFP32Op); + + ResizeNearest3DDNNLowPOp(const OperatorDef& operator_def, Workspace* ws) + : BaseType(operator_def, ws), + temporal_scale_( + this->template GetSingleArgument("temporal_scale", 1)), + width_scale_(this->template GetSingleArgument("width_scale", 1)), + height_scale_( + this->template GetSingleArgument("height_scale", 1)) { + CAFFE_ENFORCE_GT(temporal_scale_, 0); + CAFFE_ENFORCE_GT(width_scale_, 0); + CAFFE_ENFORCE_GT(height_scale_, 0); + + const auto& order = StringToStorageOrder( + this->template GetSingleArgument("order", "NHWC")); + CAFFE_ENFORCE_EQ(order, StorageOrder::NHWC); + } + + bool RunOnDevice() override; + + private: + float temporal_scale_; + float width_scale_; + float height_scale_; +}; + +} // namespace caffe2 diff --git a/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op_test.py b/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op_test.py new file mode 100644 index 0000000000000..733e76cd58acd --- /dev/null +++ b/caffe2/quantization/server/resize_nearest_3d_dnnlowp_op_test.py @@ -0,0 +1,62 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +from caffe2.python import core, dyndep, workspace +from hypothesis import given + + +dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops") +workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"]) + + +class DNNLowPResizeNearest3DOpTest(hu.HypothesisTestCase): + @given( + N=st.integers(1, 3), + T=st.integers(1, 16), + H=st.integers(10, 300), + W=st.integers(10, 300), + C=st.integers(1, 32), + scale_t=st.floats(0.25, 4.0) | st.just(2.0), + scale_w=st.floats(0.25, 4.0) | st.just(2.0), + scale_h=st.floats(0.25, 4.0) | st.just(2.0), + **hu.gcs_cpu_only + ) + def test_resize_nearest(self, N, T, H, W, C, scale_t, scale_w, scale_h, gc, dc): + X = np.round(np.random.rand(N, T, H, W, C) * 255).astype(np.float32) + + quantize = core.CreateOperator("Quantize", ["X"], ["X_q"], engine="DNNLOWP") + resize_nearest = core.CreateOperator( + "Int8ResizeNearest3D", + ["X_q"], + ["Y_q"], + temporal_scale=scale_t, + width_scale=scale_w, + height_scale=scale_h, + engine="DNNLOWP", + ) + + net = core.Net("test_net") + net.Proto().op.extend([quantize, resize_nearest]) + + workspace.FeedBlob("X", X) + workspace.RunNetOnce(net) + X_q = workspace.FetchInt8Blob("X_q").data + Y_q = workspace.FetchInt8Blob("Y_q").data + + def resize_nearest_ref(X): + outT = np.int32(T * scale_t) + outH = np.int32(H * scale_h) + outW = np.int32(W * scale_w) + outT_idxs, outH_idxs, outW_idxs = np.meshgrid( + np.arange(outT), np.arange(outH), 
np.arange(outW), indexing="ij" + ) + inT_idxs = np.minimum(outT_idxs / scale_t, T - 1).astype(np.int32) + inH_idxs = np.minimum(outH_idxs / scale_h, H - 1).astype(np.int32) + inW_idxs = np.minimum(outW_idxs / scale_w, W - 1).astype(np.int32) + Y = X[:, inT_idxs, inH_idxs, inW_idxs, :] + return Y + + Y_q_ref = resize_nearest_ref(X_q) + np.testing.assert_allclose(Y_q, Y_q_ref) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 387304cf74679..f8fdcc8be71e6 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -171,29 +171,24 @@ if (INTERN_BUILD_ATEN_OPS) message(STATUS ${generated_cpp}) message(FATAL_ERROR "Failed to get generated_cpp list") endif() + # FIXME: the file/variable name lists cpp, but these list both cpp and .h files file(READ ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_cpp.txt generated_cpp) file(READ ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_cpp.txt-cuda cuda_generated_cpp) + file(READ ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_cpp.txt-core core_generated_cpp) file(GLOB_RECURSE all_templates "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/templates/*") - # these are files that are generated by the script and checked in -- the script checks - # that they are equivalent so it must be a dependency of the script - set(core_gen_checked_inputs - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/TensorBody.h - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/TensorMethods.h - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/OpsAlreadyMovedToC10.cpp) - file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/aten/src/ATen) - file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/aten/src/ATen/core_tmp) + file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/aten/src/ATen/core) - add_custom_command(OUTPUT ${generated_cpp} ${cuda_generated_cpp} + add_custom_command(OUTPUT ${generated_cpp} ${cuda_generated_cpp} ${core_generated_cpp} COMMAND ${GEN_COMMAND} - DEPENDS ${all_python} ${all_templates} ${cwrap_files} ${core_gen_checked_inputs}) + DEPENDS ${all_python} ${all_templates} ${cwrap_files}) # Generated headers used from a CUDA (.cu) file are # not tracked correctly in CMake. We make the libATen.so depend explicitly # on building the generated ATen files to workaround. 
- add_custom_target(ATEN_CPU_FILES_GEN_TARGET DEPENDS ${generated_cpp}) + add_custom_target(ATEN_CPU_FILES_GEN_TARGET DEPENDS ${generated_cpp} ${core_generated_cpp}) add_custom_target(ATEN_CUDA_FILES_GEN_TARGET DEPENDS ${cuda_generated_cpp}) add_library(ATEN_CPU_FILES_GEN_LIB INTERFACE) add_library(ATEN_CUDA_FILES_GEN_LIB INTERFACE) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index a8e9769536cb8..e6855f345b190 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -948,6 +948,11 @@ if(USE_ROCM) message(INFO "Compiling with HIP for AMD.") caffe2_update_option(USE_ROCM ON) + if (USE_NCCL AND NOT USE_SYSTEM_NCCL) + message(INFO "Forcing USE_SYSTEM_NCCL to ON since it's required by using RCCL") + caffe2_update_option(USE_SYSTEM_NCCL ON) + endif() + list(APPEND HIP_CXX_FLAGS -fPIC) list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_HCC__=1) list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1) @@ -978,12 +983,12 @@ if(USE_ROCM) endforeach() set(Caffe2_HIP_INCLUDE - ${thrust_INCLUDE_DIRS} ${hipcub_INCLUDE_DIRS} ${rocprim_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} $ ${Caffe2_HIP_INCLUDE}) + ${thrust_INCLUDE_DIRS} ${hipcub_INCLUDE_DIRS} ${rocprim_INCLUDE_DIRS} ${miopen_INCLUDE_DIRS} ${rocblas_INCLUDE_DIRS} ${rocrand_INCLUDE_DIRS} ${hiprand_INCLUDE_DIRS} ${roctracer_INCLUDE_DIRS} ${hip_INCLUDE_DIRS} ${hcc_INCLUDE_DIRS} ${hsa_INCLUDE_DIRS} $ ${Caffe2_HIP_INCLUDE}) # This is needed for library added by hip_add_library (same for hip_add_executable) hip_include_directories(${Caffe2_HIP_INCLUDE}) set(Caffe2_HIP_DEPENDENCY_LIBS - ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB}) + ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${PYTORCH_RCCL_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB} ${ROCM_ROCTX_LIB}) # Note [rocblas & rocfft cmake bug] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1006,17 +1011,20 @@ endif() # ---[ NCCL if(USE_NCCL) - if(NOT USE_CUDA) + if(NOT (USE_CUDA OR USE_ROCM)) message(WARNING - "Not using CUDA, so disabling NCCL. Suppress this warning with " + "Not using CUDA/ROCM, so disabling USE_NCCL. Suppress this warning with " "-DUSE_NCCL=OFF.") caffe2_update_option(USE_NCCL OFF) elseif(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") message(WARNING "NCCL is currently only supported under Linux.") caffe2_update_option(USE_NCCL OFF) - else() + elseif(USE_CUDA) include(${CMAKE_CURRENT_LIST_DIR}/External/nccl.cmake) list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS __caffe2_nccl) + elseif(USE_ROCM) + include(${CMAKE_CURRENT_LIST_DIR}/External/rccl.cmake) + list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS __caffe2_nccl) endif() endif() @@ -1058,7 +1066,7 @@ if(USE_GLOO) # Add explicit dependency since NCCL is built from third_party. # Without dependency, make -jN with N>1 can fail if the NCCL build # hasn't finished when CUDA targets are linked. - if(USE_NCCL) + if(USE_NCCL AND NOT USE_ROCM) add_dependencies(gloo_cuda nccl_external) endif() # Pick the right dependency depending on USE_CUDA diff --git a/cmake/External/rccl.cmake b/cmake/External/rccl.cmake new file mode 100644 index 0000000000000..bc9265d4a3cc2 --- /dev/null +++ b/cmake/External/rccl.cmake @@ -0,0 +1,18 @@ +if (NOT __NCCL_INCLUDED) + set(__NCCL_INCLUDED TRUE) + + if (USE_SYSTEM_NCCL) + # NCCL_ROOT, NCCL_LIB_DIR, NCCL_INCLUDE_DIR will be accounted in the following line. 
+ find_package(RCCL REQUIRED) + if (RCCL_FOUND) + message (STATUS "RCCL Found!") + add_library(__caffe2_nccl INTERFACE) + target_link_libraries(__caffe2_nccl INTERFACE ${PYTORCH_RCCL_LIBRARIES}) + target_include_directories(__caffe2_nccl INTERFACE ${RCCL_INCLUDE_DIRS}) + else() + message (STATUS "RCCL NOT Found!") + endif() + else() + message (STATUS "USE_SYSTEM_NCCL=OFF is not supported yet when using RCCL") + endif() +endif() diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 414b2be3afbae..fd5e97d3e1e26 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -80,6 +80,13 @@ ELSE() SET(MIOPEN_PATH $ENV{MIOPEN_PATH}) ENDIF() +# RCCL_PATH +IF(NOT DEFINED ENV{RCCL_PATH}) + SET(RCCL_PATH ${ROCM_PATH}/rccl) +ELSE() + SET(RCCL_PATH $ENV{RCCL_PATH}) +ENDIF() + # ROCPRIM_PATH IF(NOT DEFINED ENV{ROCPRIM_PATH}) SET(ROCPRIM_PATH ${ROCM_PATH}/rocprim) @@ -101,8 +108,15 @@ ELSE() SET(ROCTHRUST_PATH $ENV{ROCTHRUST_PATH}) ENDIF() +# ROCTRACER_PATH +IF(NOT DEFINED ENV{ROCTRACER_PATH}) + SET(ROCTRACER_PATH ${ROCM_PATH}/roctracer) +ELSE() + SET(ROCTRACER_PATH $ENV{ROCTRACER_PATH}) +ENDIF() + IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) - SET(PYTORCH_ROCM_ARCH gfx803;gfx900;gfx906) + SET(PYTORCH_ROCM_ARCH gfx803;gfx900;gfx906;gfx908) ELSE() SET(PYTORCH_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) ENDIF() @@ -145,6 +159,7 @@ IF(HIP_FOUND) set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen) set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft) set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse) + set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl) set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim) set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub) set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) @@ -155,6 +170,7 @@ IF(HIP_FOUND) find_package_and_print_version(miopen REQUIRED) find_package_and_print_version(rocfft REQUIRED) find_package_and_print_version(hipsparse REQUIRED) + find_package_and_print_version(rccl) find_package_and_print_version(rocprim REQUIRED) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) @@ -167,9 +183,14 @@ IF(HIP_FOUND) # TODO: miopen_LIBRARIES should return fullpath to the library file, # however currently it's just the lib name FIND_LIBRARY(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) + # TODO: rccl_LIBRARIES should return fullpath to the library file, + # however currently it's just the lib name + FIND_LIBRARY(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) # hiprtc is part of HIP FIND_LIBRARY(ROCM_HIPRTC_LIB hiprtc HINTS ${HIP_PATH}/lib) - + # roctx is part of roctracer + FIND_LIBRARY(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib) + set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include) # Necessary includes for building PyTorch since we include HIP headers that depend on hcc/hsa headers. 
set(hcc_INCLUDE_DIRS ${HCC_PATH}/include) diff --git a/docs/cpp/source/Doxyfile b/docs/cpp/source/Doxyfile index 8290c9c71d1b1..19661452f0624 100644 --- a/docs/cpp/source/Doxyfile +++ b/docs/cpp/source/Doxyfile @@ -33,7 +33,6 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../aten/src/ATen/Backend.h \ ../../../aten/src/ATen/core/ivalue.h \ ../../../aten/src/ATen/core/ScalarType.h \ - ../../../aten/src/ATen/core/Tensor.h \ ../../../aten/src/ATen/cuda/CUDAContext.h \ ../../../aten/src/ATen/cudnn/Descriptors.h \ ../../../aten/src/ATen/cudnn/Handles.h \ @@ -44,6 +43,7 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../aten/src/ATen/mkl/Descriptors.h \ ../../../aten/src/ATen/Scalar.h \ ../../../aten/src/ATen/TensorOptions.h \ + ../../../aten/src/ATen/core/Tensor.h \ ../../../build/aten/src/ATen/Functions.h \ ../../../c10/core/Device.h \ ../../../c10/core/DeviceType.h \ diff --git a/docs/cpp/source/index.rst b/docs/cpp/source/index.rst index f620ff4208c82..b7d914c8517b2 100644 --- a/docs/cpp/source/index.rst +++ b/docs/cpp/source/index.rst @@ -53,7 +53,7 @@ ATen ``Tensor`` class with capabilities concerning automatic differentiation. The autograd system records operations on tensors to form an *autograd graph*. Calling ``backwards()`` on a leaf variable in this graph performs reverse mode differentiation through the network of functions and tensors spanning the -autograd graph, ultimately yieldings gradients. The following example provides +autograd graph, ultimately yielding gradients. The following example provides a taste of this interface: .. code-block:: cpp @@ -68,7 +68,7 @@ a taste of this interface: The ``at::Tensor`` class in ATen is not differentiable by default. To add the differentiability of tensors the autograd API provides, you must use tensor -factory functions from the `torch::` namespace instead of the `at` namespace. +factory functions from the `torch::` namespace instead of the `at::` namespace. For example, while a tensor created with `at::ones` will not be differentiable, a tensor created with `torch::ones` will be. diff --git a/docs/source/jit.rst b/docs/source/jit.rst index beabafd55b324..f962b83c0e8ac 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -449,9 +449,12 @@ Example (a type mismatch) if x: ~~~~~... <--- HERE r = torch.rand(1) + else: + and was used here: else: r = 4 return r + ~ <--- HERE ... diff --git a/docs/source/tensorboard.rst b/docs/source/tensorboard.rst index e2a34ba03e81e..d3205e3ba5892 100644 --- a/docs/source/tensorboard.rst +++ b/docs/source/tensorboard.rst @@ -91,5 +91,6 @@ Expected result: .. automethod:: add_pr_curve .. automethod:: add_custom_scalars .. automethod:: add_mesh + .. automethod:: add_hparams .. automethod:: flush .. automethod:: close diff --git a/ios/README.md b/ios/README.md index 1dc4f65d8777d..cd2d31afa1247 100644 --- a/ios/README.md +++ b/ios/README.md @@ -17,7 +17,7 @@ For Objective-C developers, simply import the umbrella header #import ``` -For Swift developers, you need to create an Objective-C class as a bridge to call the C++ APIs. We highly recommend you to follow the [Image Classification](https://github.com/pytorch/examples) demo where you can find out how C++, Objective-C and Swift work together. +For Swift developers, you need to create an Objective-C class as a bridge to call the C++ APIs. We highly recommend you to follow the [Image Classification](https://github.com/pytorch/ios-demo-app/tree/master/PyTorchDemo) demo where you can find out how C++, Objective-C and Swift work together. 
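Circling back to the `add_hparams` entry added to `docs/source/tensorboard.rst` above, here is a minimal usage sketch; the hyperparameter names and metric values are made up for illustration:

```python
from torch.utils.tensorboard import SummaryWriter

with SummaryWriter() as writer:
    # Each call logs one hyperparameter configuration together with the
    # metrics it achieved; results appear in TensorBoard's HPARAMS view.
    writer.add_hparams(
        {"lr": 0.1, "batch_size": 32},
        {"hparam/accuracy": 0.91, "hparam/loss": 0.34},
    )
```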
### Disable Bitcode @@ -25,4 +25,4 @@ Since PyTorch is not yet built with bitcode support, you need to disable bitcode ## LICENSE -PyTorch is BSD-style licensed, as found in the LICENSE file. \ No newline at end of file +PyTorch is BSD-style licensed, as found in the LICENSE file. diff --git a/scripts/xcode_ios_x86_build.rb b/scripts/xcode_build.rb similarity index 80% rename from scripts/xcode_ios_x86_build.rb rename to scripts/xcode_build.rb index c786d9a17f743..20cb2d9df42af 100644 --- a/scripts/xcode_ios_x86_build.rb +++ b/scripts/xcode_build.rb @@ -10,6 +10,9 @@ opts.on('-x', '--xcodeproj ', 'path to the XCode project file') { |value| options[:xcodeproj] = value } + opts.on('-p', '--platform ', 'iOS platform for the current build') { |value| + options[:platform] = value + } end.parse! puts options.inspect @@ -36,6 +39,7 @@ config.build_settings['HEADER_SEARCH_PATHS'] = header_search_path config.build_settings['LIBRARY_SEARCH_PATHS'] = libraries_search_path config.build_settings['OTHER_LINKER_FLAGS'] = other_linker_flags + config.build_settings['ENABLE_BITCODE'] = 'No' end # link static libraries @@ -48,5 +52,15 @@ end project.save +sdk = nil +if options[:platform] == 'SIMULATOR' + sdk = 'iphonesimulator' +elsif options[:platform] == 'OS' + sdk = 'iphoneos' +else + puts "unsupported platform #{options[:platform]}" + exit(false) +end + # run xcodebuild -exec "xcodebuild clean build -project #{xcodeproj_path} -target #{target.name} -sdk iphonesimulator -configuration Release" +exec "xcodebuild clean build -project #{xcodeproj_path} -target #{target.name} -sdk #{sdk} -configuration Release" diff --git a/setup.py b/setup.py index 0e457541ebfe3..7ada2774a1aa9 100644 --- a/setup.py +++ b/setup.py @@ -184,7 +184,7 @@ from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN, check_env_flag, build_type) from tools.setup_helpers.cmake import CMake -from tools.setup_helpers.cuda import CUDA_HOME, CUDA_VERSION +from tools.setup_helpers.cuda import CUDA_HOME try: FileNotFoundError @@ -257,7 +257,7 @@ def report(*args): # Version, create_version_file, and package_name ################################################################################ package_name = os.getenv('TORCH_PACKAGE_NAME', 'torch') -version = '1.3.0a0' +version = open('version.txt', 'r').read().strip() sha = 'Unknown' try: @@ -280,15 +280,6 @@ def report(*args): # all the work we need to do _before_ setup runs def build_deps(): report('-- Building version ' + version) - version_path = os.path.join(cwd, 'torch', 'version.py') - with open(version_path, 'w') as f: - f.write("__version__ = '{}'\n".format(version)) - # NB: This is not 100% accurate, because you could have built the - # library code with DEBUG, but csrc without DEBUG (in which case - # this would claim to be a release build when it's not.) - f.write("debug = {}\n".format(repr(build_type.is_debug()))) - f.write("cuda = {}\n".format(repr(CUDA_VERSION))) - f.write("git_version = {}\n".format(repr(sha))) def check_file(f): if not os.path.exists(f): @@ -318,6 +309,18 @@ def check_file(f): rerun_cmake=RERUN_CMAKE, cmake_only=CMAKE_ONLY, cmake=cmake) + + version_path = os.path.join(cwd, 'torch', 'version.py') + with open(version_path, 'w') as f: + f.write("__version__ = '{}'\n".format(version)) + # NB: This is not 100% accurate, because you could have built the + # library code with DEBUG, but csrc without DEBUG (in which case + # this would claim to be a release build when it's not.) 
+ f.write("debug = {}\n".format(repr(build_type.is_debug()))) + cmake_cache_vars = defaultdict(lambda: None, cmake.get_cmake_cache_variables()) + f.write("cuda = {}\n".format(repr(cmake_cache_vars['CUDA_VERSION']))) + f.write("git_version = {}\n".format(repr(sha))) + if CMAKE_ONLY: report('Finished running cmake. Run "ccmake build" or ' '"cmake-gui build" to adjust build options and ' @@ -828,6 +831,7 @@ def print_box(msg): 'include/torch/csrc/api/include/torch/nn/modules/*.h', 'include/torch/csrc/api/include/torch/nn/modules/container/*.h', 'include/torch/csrc/api/include/torch/nn/parallel/*.h', + 'include/torch/csrc/api/include/torch/nn/utils/*.h', 'include/torch/csrc/api/include/torch/optim/*.h', 'include/torch/csrc/api/include/torch/serialize/*.h', 'include/torch/csrc/autograd/*.h', diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index f18438244334d..e1115f012e979 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -12,6 +12,7 @@ ('quantize', datetime.date(2019, 10, 1)), ('q_per_channel_axis', datetime.date(2019, 10, 1)), ('fbgemm_is_cpu_supported', datetime.date(2019, 10, 1)), + ('c10_experimental', datetime.date(2020, 1, 1)), ] diff --git a/test/common_device_type.py b/test/common_device_type.py index 101a303fa05ca..24a70410918b1 100644 --- a/test/common_device_type.py +++ b/test/common_device_type.py @@ -148,8 +148,8 @@ def _get_dtypes(cls, test): # Creates device-specific tests. @classmethod - def instantiate_test(cls, test): - test_name = test.__name__ + "_" + cls.device_type + def instantiate_test(cls, name, test): + test_name = name + "_" + cls.device_type dtypes = cls._get_dtypes(test) if dtypes is None: # Test has no dtype variants @@ -196,11 +196,10 @@ def get_all_devices(cls): primary_device_idx = int(cls.get_primary_device().split(':')[1]) num_devices = torch.cuda.device_count() - devices = [cls.get_primary_device()] + prim_device = cls.get_primary_device() cuda_str = 'cuda:{0}' non_primary_devices = [cuda_str.format(idx) for idx in range(num_devices) if idx != primary_device_idx] - devices.extend(non_primary_devices) - return devices + return [prim_device] + non_primary_devices @classmethod def setUpClass(cls): @@ -255,7 +254,6 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None): for name in generic_members: if name in generic_tests: # Instantiates test member - # Requires tests be a function for Python2 compat # (In Python2 tests are type checked methods wrapping functions) test = getattr(generic_test_class, name) @@ -264,7 +262,7 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None): assert inspect.isfunction(test), "Couldn't extract function from '{0}'".format(name) # Instantiates the device-specific tests - device_type_test_class.instantiate_test(test) + device_type_test_class.instantiate_test(name, test) else: # Ports non-test member assert not hasattr(device_type_test_class, name), "Redefinition of non-test member {0}".format(name) @@ -442,7 +440,7 @@ def wrap_fn(self, device, *args, **kwargs): if self.no_cudnn: reason = "cuDNN not available" raise unittest.SkipTest(reason) - if self.cudnn_version < version: + if self.cudnn_version is None or self.cudnn_version < version: reason = "cuDNN version {0} is available but {1} required".format(self.cudnn_version, version) raise unittest.SkipTest(reason) diff --git a/test/common_distributed.py 
b/test/common_distributed.py index cecd1ae6557f9..dd5cd715fd874 100644 --- a/test/common_distributed.py +++ b/test/common_distributed.py @@ -64,6 +64,18 @@ def requires_gloo(): "c10d was not compiled with the Gloo backend", ) +def requires_nccl_version(version, msg): + if not c10d.is_nccl_available(): + return unittest.skip( + "c10d was not compiled with the NCCL backend", + ) + else: + return unittest.skipIf( + torch.cuda.nccl.version() < version, + "Requires NCCL version greater than or equal to: {}, found: {}, reason: {}".format( + version, + torch.cuda.nccl.version(), msg), + ) def requires_nccl(): return unittest.skipUnless( diff --git a/test/common_nn.py b/test/common_nn.py index e5e35a2e9b27e..694cf5960075e 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -295,8 +295,8 @@ def bceloss_no_reduce_test(): lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), - check_gradgrad=False, - pickle=False) + pickle=False, + precision=7e-4) def bceloss_no_reduce_scalar_test(): @@ -307,7 +307,6 @@ def bceloss_no_reduce_scalar_test(): lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')), input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()), - check_gradgrad=False, pickle=False) @@ -321,8 +320,8 @@ def bceloss_weights_no_reduce_test(): weight=weights.type_as(i), reduction='none')), input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, - check_gradgrad=False, - pickle=False + pickle=False, + precision=3e-4 ) @@ -336,7 +335,6 @@ def bceloss_weights_no_reduce_scalar_test(): weight=weights.type_as(i), reduction='none')), input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2), reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights, - check_gradgrad=False, pickle=False ) diff --git a/test/common_quantization.py b/test/common_quantization.py index 08bd68f9869fe..d98b1af8758d8 100644 --- a/test/common_quantization.py +++ b/test/common_quantization.py @@ -15,7 +15,8 @@ from common_utils import TestCase from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ - propagate_qconfig_, convert, DEFAULT_DYNAMIC_MODULE_MAPPING + propagate_qconfig_, convert +from torch.quantization.default_mappings import DEFAULT_DYNAMIC_MODULE_MAPPING def test_only_eval_fn(model, calib_data): r""" @@ -595,3 +596,28 @@ def forward(self, x): out = out.view(-1, 3 * 2 * 2) out = self.fc(out) return out + +"""Model to make sure that the observers are not inserted into custom modules. 
+""" +class ModelWithNoQconfigPropagation(nn.Module): + class ListOutModule(nn.Module): + def __init__(self): + super(ModelWithNoQconfigPropagation.ListOutModule, self).__init__() + + def forward(self, x): + # returns a list of tensors, not supported by observers + return [x] + + def __init__(self): + super(ModelWithNoQconfigPropagation, self).__init__() + self.fc1 = nn.Linear(5, 5).to(dtype=torch.float) + self.quant = QuantStub() + self.dequant = DeQuantStub() + self.no_quant_module = self.ListOutModule() + + def forward(self, x): + x = self.quant(x) + x = self.fc1(x) + x = self.dequant(x) + x = self.no_quant_module(x) + return x diff --git a/test/common_utils.py b/test/common_utils.py index ef1342453670d..5be92588cbf91 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -42,7 +42,6 @@ import torch.backends.mkl -torch.set_default_tensor_type('torch.DoubleTensor') torch.backends.disable_global_flags() @@ -208,6 +207,26 @@ def wrapper(*args, **kwargs): return wrapper +def skipIfCompiledWithoutNumpy(fn): + # Even if the numpy module is present, if `USE_NUMPY=0` is used during the + # build, numpy tests will fail + numpy_support = TEST_NUMPY + if numpy_support: + try: + # The numpy module is present, verify that PyTorch is compiled with + # numpy support + torch.from_numpy(numpy.array([2, 2])) + except RuntimeError: + numpy_support = False + + @wraps(fn) + def wrapper(*args, **kwargs): + if not numpy_support: + raise unittest.SkipTest("PyTorch was compiled without numpy support") + else: + fn(*args, **kwargs) + return wrapper + def _test_function(fn, device): def run_test_function(self): @@ -278,6 +297,20 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +def default_floating_dtype(dtype): + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + old_type = torch.tensor(()).dtype + torch.set_default_dtype(dtype) + try: + return fn(*args, **kwargs) + finally: + torch.set_default_dtype(old_type) + + return wrapper + + return decorator def get_cpu_type(type_name): module, name = type_name.rsplit('.', 1) @@ -950,7 +983,6 @@ def runWithPytorchAPIUsageStderr(code): assertNotRegex = unittest.TestCase.assertNotRegexpMatches - def download_file(url, binary=True): if sys.version_info < (3,): from urlparse import urlsplit @@ -1014,9 +1046,9 @@ def prod_single_zero(dim_size): return result -def random_square_matrix_of_rank(l, rank): +def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'): assert rank <= l - A = torch.randn(l, l) + A = torch.randn(l, l, dtype=dtype, device=device) u, s, v = A.svd() for i in range(l): if i >= rank: @@ -1026,20 +1058,28 @@ def random_square_matrix_of_rank(l, rank): return u.mm(torch.diag(s)).mm(v.transpose(0, 1)) -def random_symmetric_matrix(l, *batches): - A = torch.randn(*(batches + (l, l))) +def random_symmetric_matrix(l, *batches, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device) A = (A + A.transpose(-2, -1)).div_(2) return A -def random_symmetric_psd_matrix(l, *batches): - A = torch.randn(*(batches + (l, l))) +def random_symmetric_psd_matrix(l, *batches, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batches + (l, l)), dtype=dtype, device=device) return torch.matmul(A, A.transpose(-2, -1)) -def random_symmetric_pd_matrix(matrix_size, *batch_dims): - A = torch.randn(*(batch_dims + (matrix_size, matrix_size))) - return torch.matmul(A, 
A.transpose(-2, -1)) + torch.eye(matrix_size) * 1e-5 +def random_symmetric_pd_matrix(matrix_size, *batch_dims, **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') + A = torch.randn(*(batch_dims + (matrix_size, matrix_size)), + dtype=dtype, device=device) + return torch.matmul(A, A.transpose(-2, -1)) \ + + torch.eye(matrix_size, dtype=dtype, device=device) * 1e-5 def make_nonzero_det(A, sign=None, min_singular_value=0.1): @@ -1060,48 +1100,20 @@ def make_nonzero_det(A, sign=None, min_singular_value=0.1): return A -def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims, **kwargs): +def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims, + **kwargs): + dtype = kwargs.get('dtype', torch.double) + device = kwargs.get('device', 'cpu') silent = kwargs.get("silent", False) if silent and not torch._C.has_lapack: - return torch.ones(matrix_size, matrix_size) + return torch.ones(matrix_size, matrix_size, dtype=dtype, device=device) - A = torch.randn(batch_dims + (matrix_size, matrix_size)) + A = torch.randn(batch_dims + (matrix_size, matrix_size), dtype=dtype, device=device) u, _, v = A.svd() - s = torch.arange(1., matrix_size + 1).mul_(1.0 / (matrix_size + 1)).diag() + s = torch.arange(1., matrix_size + 1, dtype=dtype, device=device).mul_(1.0 / (matrix_size + 1)).diag() return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).matmul(v.transpose(-2, -1))) -def lu_solve_test_helper(self, A_dims, b_dims, cast, pivot): - b = cast(torch.randn(*b_dims)) - A = cast(random_fullrank_matrix_distinct_singular_value(*A_dims)) - LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot) - self.assertEqual(info, torch.zeros_like(info)) - return b, A, LU_data, LU_pivots - - -def cholesky_solve_test_helper(A_dims, b_dims, cast, upper): - b = cast(torch.randn(*b_dims)) - A = cast(random_symmetric_pd_matrix(*A_dims)) - L = torch.cholesky(A, upper=upper) - return b, A, L - - -def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular): - triangle_function = torch.triu if upper else torch.tril - b = cast(torch.randn(*b_dims)) - A = cast(torch.randn(*A_dims)) - A_triangular = triangle_function(A) - if unitriangular: - A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) 
- return b, A_triangular - - -def solve_test_helper(A_dims, b_dims, cast): - b = cast(torch.randn(*b_dims)) - A = cast(random_fullrank_matrix_distinct_singular_value(*A_dims)) - return b, A - - def brute_pdist(inp, p=2): """Computes the same as torch.pdist using primitives""" n = inp.shape[-2] diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 2b933817f6cc7..8621a581e68fc 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -14,6 +14,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/module.cpp ${TORCH_API_TEST_DIR}/modulelist.cpp ${TORCH_API_TEST_DIR}/modules.cpp + ${TORCH_API_TEST_DIR}/nn_utils.cpp ${TORCH_API_TEST_DIR}/optim.cpp ${TORCH_API_TEST_DIR}/ordered_dict.cpp ${TORCH_API_TEST_DIR}/rnn.cpp diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 5d3f8226183f1..0be5fbffeef4f 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -15,7 +15,7 @@ TEST_F(FunctionalTest, MaxPool1d) { auto y = F::max_pool1d(x, MaxPool1dOptions(3).stride(2)); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1 ,2}))); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); } @@ -24,7 +24,7 @@ TEST_F(FunctionalTest, MaxPool2d) { auto y = F::max_pool2d(x, MaxPool2dOptions(3).stride(2)); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2 ,2}))); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); } @@ -67,7 +67,8 @@ TEST_F(FunctionalTest, AvgPool3d) { TEST_F(FunctionalTest, CosineSimilarity) { auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat); auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat); - auto output = F::cosine_similarity(input1, input2, CosineSimilarityOptions().dim(1)); + auto output = + F::cosine_similarity(input1, input2, CosineSimilarityOptions().dim(1)); auto expected = torch::tensor({0.8078, 0.8721}, torch::kFloat); ASSERT_TRUE(output.allclose(expected, 1e-04)); } @@ -75,7 +76,8 @@ TEST_F(FunctionalTest, CosineSimilarity) { TEST_F(FunctionalTest, PairwiseDistance) { auto input1 = torch::tensor({{1, 2, 3}, {4, 5, 6}}, torch::kFloat); auto input2 = torch::tensor({{1, 8, 3}, {2, 1, 6}}, torch::kFloat); - auto output = F::pairwise_distance(input1, input2, PairwiseDistanceOptions(1)); + auto output = + F::pairwise_distance(input1, input2, PairwiseDistanceOptions(1)); auto expected = torch::tensor({6, 6}, torch::kFloat); ASSERT_TRUE(output.allclose(expected)); } @@ -165,15 +167,18 @@ TEST_F(FunctionalTest, MaxUnpool1d) { auto y = F::max_unpool1d(x, indices, MaxUnpool1dOptions(3)); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 9})); x = torch::tensor({{{2, 4, 5}}}, torch::requires_grad()); indices = torch::tensor({{{1, 3, 4}}}, torch::kLong); - y = F::max_unpool1d(x, indices, MaxUnpool1dOptions(3), c10::IntArrayRef({1, 1, 9})); + y = F::max_unpool1d( + x, indices, MaxUnpool1dOptions(3), c10::IntArrayRef({1, 1, 9})); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); + ASSERT_TRUE(torch::allclose( + y, torch::tensor({{{0, 2, 0, 4, 5, 0, 0, 0, 0}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), 
torch::IntArrayRef({1, 1, 9})); x = torch::tensor({{{2, 4, 5}}}, torch::requires_grad()); @@ -181,7 +186,8 @@ TEST_F(FunctionalTest, MaxUnpool1d) { y = F::max_unpool1d(x, indices, MaxUnpool1dOptions(3).stride(2).padding(1)); ASSERT_EQ(y.ndimension(), 3); - ASSERT_TRUE(torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat))); + ASSERT_TRUE( + torch::allclose(y, torch::tensor({{{0, 2, 0, 4, 5}}}, torch::kFloat))); ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 5})); } @@ -254,3 +260,103 @@ TEST_F(FunctionalTest, Hardshrink) { ASSERT_TRUE(torch::allclose(y, y_exp)); } } + +TEST_F(FunctionalTest, OneHot) { + { // Test #1 + auto x = torch::arange(0, 5, torch::kLong); + auto y = F::one_hot(x % 3); + auto expected = torch::tensor( + {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}, {0, 1, 0}}, torch::kLong); + + ASSERT_EQ(y.ndimension(), 2); + ASSERT_TRUE(torch::allclose(y, expected)); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({5, 3})); + } + + { // Test #2 + auto x = torch::arange(0, 5, torch::kLong); + auto y = F::one_hot(x % 3, 5); + auto expected = torch::tensor( + {{1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0}, + {0, 0, 1, 0, 0}, + {1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0}}, + torch::kLong); + + ASSERT_EQ(y.ndimension(), 2); + ASSERT_TRUE(torch::allclose(y, expected)); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({5, 5})); + } + + { // Test #3 + auto x = torch::arange(0, 6, torch::kLong); + auto y = F::one_hot(x.view(torch::IntArrayRef({3, 2})) % 3); + auto expected = torch::tensor( + {{{1, 0, 0}, {0, 1, 0}}, + {{0, 0, 1}, {1, 0, 0}}, + {{0, 1, 0}, {0, 0, 1}}}, + torch::kLong); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, expected)); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 2, 3})); + } +} + +TEST_F(FunctionalTest, Hardtanh) { + const auto size = 3; + for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) { + for (const auto max_val : {0.0, 0.42, 1.0, 4.2}) { + for (const auto inplace : {false, true}) { + auto x = torch::linspace(-10.0, 10.0, size * size * size); + x.resize_({size, size, size}); + auto y_exp = (x < min_val) * min_val + + ((x >= min_val) * (x <= max_val)) * x + + (x > max_val) * max_val; + auto y = F::hardtanh(x,HardtanhOptions().min_val(min_val) + .max_val(max_val).inplace(inplace)); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size})); + ASSERT_TRUE(torch::allclose(y, y_exp)); + if (inplace) { + ASSERT_TRUE(torch::allclose(x, y_exp)); + } + } + } + } +} + +TEST_F(FunctionalTest, LeakyReLU) { + const auto size = 3; + for (const auto negative_slope : {0.0, 0.42, 1.0}) { + for (const auto inplace : {false, true}) { + auto x = torch::linspace(-10.0, 10.0, size * size * size); + x.resize_({size, size, size}); + auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x; + auto y = F::leaky_relu(x, LeakyReLUOptions() + .negative_slope(negative_slope).inplace(inplace)); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size})); + ASSERT_TRUE(torch::allclose(y, y_exp)); + if (inplace) { + ASSERT_TRUE(torch::allclose(x, y_exp)); + } + } + } +} + +TEST_F(FunctionalTest, LogSigmoid) { + const auto size = 3; + LogSigmoid model; + auto x = torch::linspace(-10.0, 10.0, size * size * size); + x.resize_({size, size, size}); + auto y = F::logsigmoid(x); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size})); + auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x)))); + ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); 
+} diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 7b8bfe6e22eb6..715bb50f3ab63 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -1044,6 +1044,78 @@ TEST_F(ModulesTest, Hardshrink) { } } +TEST_F(ModulesTest, Hardtanh) { + const auto size = 3; + for (const auto min_val : {-4.2, -1.0, -0.42, 0.0}) { + for (const auto max_val : {0.42, 1.0, 4.2}) { + Hardtanh model {HardtanhOptions().min_val(min_val).max_val(max_val)}; + auto x = torch::linspace(-10.0, 10.0, size * size * size); + x.resize_({size, size, size}).set_requires_grad(true); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(s.ndimension(), 0); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size})); + auto y_exp = (x < min_val) * min_val + + ((x >= min_val) * (x <= max_val)) * x + + (x > max_val) * max_val; + ASSERT_TRUE(torch::allclose(y, y_exp)); + } + } +} + +TEST_F(ModulesTest, HardtanhMinValGEMaxVal) { + ASSERT_THROWS_WITH(Hardtanh{HardtanhOptions().min_val(0.42).max_val(0.42)}, + "max_val must be greater than min_val"); + ASSERT_THROWS_WITH(Hardtanh{HardtanhOptions().min_val(0.42).max_val(-0.42)}, + "max_val must be greater than min_val"); + + Hardtanh ht {HardtanhOptions().min_val(-0.42).max_val(0.42)}; + ht->options.min_val(0.42); + ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val"); + ht->options.max_val(-0.42); + ASSERT_THROWS_WITH(ht->reset(), "max_val must be greater than min_val"); +} + +TEST_F(ModulesTest, LeakyReLU) { + const auto size = 3; + for (const auto negative_slope : {0.0, 0.42, 1.0}) { + LeakyReLU model {LeakyReLUOptions().negative_slope(negative_slope)}; + auto x = torch::linspace(-10.0, 10.0, size * size * size); + x.resize_({size, size, size}).set_requires_grad(true); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(s.ndimension(), 0); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size})); + auto y_exp = (x < 0) * x * negative_slope + (x >= 0) * x; + ASSERT_TRUE(torch::allclose(y, y_exp)); + } +} + +TEST_F(ModulesTest, LogSigmoid) { + const auto size = 3; + LogSigmoid model; + auto x = torch::linspace(-10.0, 10.0, size * size * size); + x.resize_({size, size, size}).set_requires_grad(true); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(s.ndimension(), 0); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({size, size, size})); + auto y_exp = torch::log(torch::ones_like(x)/(torch::ones_like(x) + torch::exp(torch::neg(x)))); + ASSERT_TRUE(torch::allclose(y, y_exp, 1e-4, 1e-7)); +} + TEST_F(ModulesTest, PrettyPrintIdentity) { ASSERT_EQ(c10::str(Identity()), "torch::nn::Identity()"); } @@ -1278,3 +1350,23 @@ TEST_F(ModulesTest, PrettyPrintHardshrink) { ASSERT_EQ(c10::str(Hardshrink(HardshrinkOptions().lambda(42.42))), "torch::nn::Hardshrink(42.42)"); } + +TEST_F(ModulesTest, PrettyPrintHardtanh) { + ASSERT_EQ(c10::str(Hardtanh()), + "torch::nn::Hardtanh(min_val=-1, max_val=1)"); + ASSERT_EQ(c10::str(Hardtanh( + HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true))), + "torch::nn::Hardtanh(min_val=-42.42, max_val=0.42, inplace=true)"); +} + +TEST_F(ModulesTest, PrettyPrintLeakyReLU) { + ASSERT_EQ(c10::str(LeakyReLU()), + "torch::nn::LeakyReLU(negative_slope=0.01)"); + ASSERT_EQ(c10::str(LeakyReLU( + LeakyReLUOptions().negative_slope(0.42).inplace(true))), + "torch::nn::LeakyReLU(negative_slope=0.42, inplace=true)"); +} + 
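The module tests above check Hardtanh, LeakyReLU, and LogSigmoid against closed-form expectations. The same identities can be spot-checked from Python with the existing functional ops; this is a minimal sketch for reference only and is not part of the patch:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-10.0, 10.0, 27).reshape(3, 3, 3)

# Hardtanh clamps x into [min_val, max_val], mirroring the piecewise
# expectation computed in the C++ tests.
min_val, max_val = -1.0, 1.0
expected = torch.where(x < min_val, torch.full_like(x, min_val),
                       torch.where(x > max_val, torch.full_like(x, max_val), x))
assert torch.allclose(F.hardtanh(x, min_val, max_val), expected)

# LeakyReLU scales negative inputs by negative_slope and passes the rest through.
negative_slope = 0.42
assert torch.allclose(F.leaky_relu(x, negative_slope),
                      torch.where(x < 0, x * negative_slope, x))

# LogSigmoid is log(1 / (1 + exp(-x))), the reference used in the C++ tests.
assert torch.allclose(F.logsigmoid(x), torch.log(torch.sigmoid(x)), atol=1e-6)
```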
+TEST_F(ModulesTest, PrettyPrintLogSigmoid) { + ASSERT_EQ(c10::str(LogSigmoid()), "torch::nn::LogSigmoid()"); +} diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp new file mode 100644 index 0000000000000..fa0693ec3ec3f --- /dev/null +++ b/test/cpp/api/nn_utils.cpp @@ -0,0 +1,103 @@ +#include + +#include + +#include + +using namespace torch::nn; +using namespace torch::test; + +struct NNUtilsTest : torch::test::SeedingFixture {}; + +TEST_F(NNUtilsTest, ClipGradNorm) { + auto linear_layer = Linear(10, 10); + float max_norm = 2; + auto compute_norm = [linear_layer](float norm_type) -> float { + float total_norm = 0.0; + if (norm_type != std::numeric_limits::infinity()) { + for (const auto& p : linear_layer->parameters()) { + total_norm += + p.grad().data().abs().pow(norm_type).sum().item().toFloat(); + } + return std::pow(total_norm, 1.0 / norm_type); + } else { + for (const auto& p : linear_layer->parameters()) { + auto param_max = p.grad().data().abs().max().item().toFloat(); + if (param_max > total_norm) { + total_norm = param_max; + } + } + return total_norm; + } + }; + auto compare_scaling = + [linear_layer](const std::vector& grads) -> torch::Tensor { + std::vector p_scale; + for (int i = 0; i < grads.size(); i++) { + auto param = linear_layer->parameters()[i]; + auto grad = grads[i]; + p_scale.push_back(param.grad().data().div(grad).view(-1)); + } + auto scale = torch::cat(p_scale); + return scale; // need to assert std is 0. + }; + + std::vector grads = { + torch::arange(1.0, 101).view({10, 10}), + torch::ones(10).div(1000), + }; + std::vector norm_types = { + 0.5, + 1.5, + 2.0, + 4.0, + std::numeric_limits::infinity(), + }; + for (auto norm_type : norm_types) { + for (int i = 0; i < grads.size(); i++) { + linear_layer->parameters()[i].grad() = + grads[i].clone().view_as(linear_layer->parameters()[i].data()); + } + auto norm_before = compute_norm(norm_type); + auto layer_params = linear_layer->parameters(); + auto norm = utils::clip_grad_norm_(layer_params, max_norm, norm_type); + auto norm_after = compute_norm(norm_type); + ASSERT_FLOAT_EQ(norm, norm_before); + ASSERT_FLOAT_EQ(norm_after, max_norm); + ASSERT_LE(norm_after, max_norm); + auto scaled = compare_scaling(grads); + ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7); + } + // Small gradients should be lefted unchanged + grads = { + torch::rand({10, 10}).div(10000), + torch::ones(10).div(500), + }; + for (auto norm_type : norm_types) { + for (int i = 0; i < grads.size(); i++) { + linear_layer->parameters()[i].grad().data().copy_(grads[i]); + } + auto norm_before = compute_norm(norm_type); + auto layer_params = linear_layer->parameters(); + auto norm = utils::clip_grad_norm_(layer_params, max_norm, norm_type); + auto norm_after = compute_norm(norm_type); + ASSERT_FLOAT_EQ(norm, norm_before); + ASSERT_FLOAT_EQ(norm_before, norm_after); + ASSERT_LE(norm_after, max_norm); + auto scaled = compare_scaling(grads); + ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7); + ASSERT_EQ(scaled[0].item().toFloat(), 1); + } + // should accept a single tensor as input + auto p1 = torch::randn({10, 10}); + auto p2 = torch::randn({10, 10}); + auto g = torch::arange(1., 101).view({10, 10}); + p1.grad() = g.clone(); + p2.grad() = g.clone(); + for (const auto norm_type : norm_types) { + utils::clip_grad_norm_(p1, max_norm, norm_type); + std::vector params = {p2}; + utils::clip_grad_norm_(params, max_norm, norm_type); + ASSERT_TRUE(torch::allclose(p1.grad(), p2.grad())); + } +} diff --git a/test/cpp/api/tensor.cpp 
b/test/cpp/api/tensor.cpp index 0a43bd855e73d..4a262cddccc0e 100644 --- a/test/cpp/api/tensor.cpp +++ b/test/cpp/api/tensor.cpp @@ -474,6 +474,13 @@ TEST(TensorTest, BackwardCreatesOnesGrad) { torch::ones_like(x))); } +TEST(TensorTest, BackwardNonScalarOutputs) { + auto x = torch::randn({5, 5}, torch::requires_grad()); + auto y = x * x; + ASSERT_THROWS_WITH(y.backward(), + "grad can be implicitly created only for scalar outputs"); +} + TEST(TensorTest, IsLeaf) { auto x = torch::tensor({5}, at::TensorOptions().requires_grad(true)); auto y = x * x; diff --git a/test/cpp/dist_autograd/test_dist_autograd.cpp b/test/cpp/dist_autograd/test_dist_autograd.cpp index 91bdbad3456f9..7af16e6a34628 100644 --- a/test/cpp/dist_autograd/test_dist_autograd.cpp +++ b/test/cpp/dist_autograd/test_dist_autograd.cpp @@ -1,10 +1,26 @@ #include #include +#include +#include #include +#include #include -TEST(DistAutogradTest, TestSendFunction) { +using namespace torch::distributed::autograd; +using namespace torch::distributed::rpc; + +class DistAutogradTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + autogradContainer_ = &DistAutogradContainer::init(0); + } + static DistAutogradContainer* autogradContainer_; +}; + +DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr; + +TEST_F(DistAutogradTest, TestSendFunction) { // Initialize input tensors requiring grad. auto options = at::TensorOptions().requires_grad(true); auto in1 = torch::ones({3, 3}, options); @@ -12,9 +28,12 @@ TEST(DistAutogradTest, TestSendFunction) { ASSERT_FALSE(in1.grad().defined()); ASSERT_FALSE(in2.grad().defined()); + autogradContainer_->newContext(); + DistAutogradContext& autogradContext = autogradContainer_->currentContext(); // Attach the send autograd function to tensors. - auto send_function = - torch::distributed::autograd::addSendRpcBackward({in1, in2}); + std::vector tensors = {in1, in2}; + addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors); + auto send_function = autogradContext.sendFunctions()[1]; ASSERT_NE(send_function, nullptr); // Build loss and attach it as input to send autograd function. @@ -33,14 +52,17 @@ TEST(DistAutogradTest, TestSendFunction) { ASSERT_TRUE(in2.grad().defined()); } -TEST(DistAutogradTest, TestSendFunctionInvalidInputs) { +TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) { auto options = at::TensorOptions().requires_grad(true); auto in1 = torch::ones({3, 3}, options); auto in2 = torch::ones({3, 3}, options); + autogradContainer_->newContext(); + DistAutogradContext& autogradContext = autogradContainer_->currentContext(); // Attach the send autograd function to tensors. - auto send_function = - torch::distributed::autograd::addSendRpcBackward({in1, in2}); + std::vector tensors = {in1, in2}; + addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors); + auto send_function = autogradContext.sendFunctions()[1]; // Build loss and attach it as input to send autograd function. 
auto loss = torch::autograd::Variable(torch::ones({3, 3})); diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 9f15573cde048..30bce094ca2b8 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -9,7 +9,7 @@ namespace torch { namespace jit { -void testLiteInterpreter() { +void testLiteInterpreterAdd() { script::Module m("m"); m.register_parameter("foo", torch::ones({}), false); // TODO: support default param val, which was pushed in @@ -43,5 +43,32 @@ void testLiteInterpreter() { AT_ASSERT(resd == refd); } +void testLiteInterpreterConv() { + std::vector inputs; + + script::Module m("m"); + m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false); + m.register_parameter("bias", torch::ones({20}), false); + m.define(R"( + def forward(self, input): + return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True) + )"); + + inputs.push_back(torch::ones({1, 1, 28, 28})); + + auto outputref = m.forward(inputs).toTensor(); + + std::stringstream ss; + m._save_for_mobile(ss); + mobile::Module bc = _load_for_mobile(ss); + IValue res; + for (int i = 0; i < 3; ++i) { + auto bcinputs = inputs; + res = bc.run_method("forward", bcinputs); + } + auto output = res.toTensor(); + AT_ASSERT(outputref.dim() == output.dim()); + AT_ASSERT(outputref[0][0][0][0].item() == output[0][0][0][0].item()); +} } // namespace torch } // namespace jit diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index d131d3b80bd87..27e5a376c7b42 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1049,7 +1049,7 @@ void testInsertAndEliminateRedundantGuards() { checkShape(*guard, {2, 3}, false); auto is_guard = [](Node* n) { return n->kind() == prim::Guard; }; int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard); - ASSERT_EQ(num_guards, 11); + ASSERT_EQ(num_guards, 12); // now eliminate as many guards as possible // we should be left with two guards on x and y's defs EliminateRedundantGuards(copy); diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 174e8bb115b93..5efb371a6b44e 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -65,7 +65,8 @@ namespace jit { _(ImportTooNew) \ _(ClassDerive) \ _(Inliner) \ - _(LiteInterpreter) + _(LiteInterpreterAdd) \ + _(LiteInterpreterConv) #define TH_FORALL_TESTS_CUDA(_) \ _(ArgumentSpec) \ diff --git a/test/data/test_cuda_ignores.txt b/test/data/test_cuda_ignores.txt deleted file mode 100644 index afad1c57906e5..0000000000000 --- a/test/data/test_cuda_ignores.txt +++ /dev/null @@ -1,112 +0,0 @@ -# List of functions that are not implemented on CUDA tensors -# These are skipped by test_cuda.py -torch.ByteTensor.dist -torch.ByteTensor.dot -torch.ByteTensor.lerp -torch.ByteTensor.lerp_ -torch.ByteTensor.mean -torch.ByteTensor.norm -torch.ByteTensor.renorm -torch.ByteTensor.renorm_ -torch.ByteTensor.std -torch.ByteTensor.var -torch.CharTensor.dist -torch.CharTensor.dot -torch.CharTensor.lerp -torch.CharTensor.lerp_ -torch.CharTensor.mean -torch.CharTensor.norm -torch.CharTensor.renorm -torch.CharTensor.renorm_ -torch.CharTensor.std -torch.CharTensor.var -torch.HalfTensor.chunk_ -torch.HalfTensor.clone_ -torch.HalfTensor.contiguous_ -torch.HalfTensor.cross_ -torch.HalfTensor.cumprod_ -torch.HalfTensor.cumsum_ -torch.HalfTensor.dim_ -torch.HalfTensor.dist_ -torch.HalfTensor.dot_ -torch.HalfTensor.element_size_ -torch.HalfTensor.equal_ -torch.HalfTensor.expand_ 
-torch.HalfTensor.expand_as_ -torch.HalfTensor.eye -torch.HalfTensor.eye_ -torch.HalfTensor.fill -torch.HalfTensor.geqrf -torch.HalfTensor.geqrf_ -torch.HalfTensor.inverse -torch.HalfTensor.inverse_ -torch.HalfTensor.is_contiguous_ -torch.HalfTensor.is_same_size_ -torch.HalfTensor.is_set_to_ -torch.HalfTensor.kthvalue_ -torch.HalfTensor.max_ -torch.HalfTensor.mean_ -torch.HalfTensor.min_ -torch.HalfTensor.mode_ -torch.HalfTensor.narrow_ -torch.HalfTensor.ndimension_ -torch.HalfTensor.nelement_ -torch.HalfTensor.nonzero_ -torch.HalfTensor.norm_ -torch.HalfTensor.numel_ -torch.HalfTensor.ones -torch.HalfTensor.ones_ -torch.HalfTensor.permute_ -torch.HalfTensor.prod_ -torch.HalfTensor.put__ -torch.HalfTensor.qr -torch.HalfTensor.qr_ -torch.HalfTensor.repeat_ -torch.HalfTensor.size_ -torch.HalfTensor.sort_ -torch.HalfTensor.split_ -torch.HalfTensor.std_ -torch.HalfTensor.sum_ -torch.HalfTensor.take_ -torch.HalfTensor.to_list -torch.HalfTensor.to_list_ -torch.HalfTensor.topk_ -torch.HalfTensor.trace_ -torch.HalfTensor.trigamma -torch.HalfTensor.trigamma_ -torch.HalfTensor.var_ -torch.HalfTensor.view_ -torch.HalfTensor.view_as_ -torch.HalfTensor.zero -torch.HalfTensor.zeros -torch.HalfTensor.zeros_ -torch.IntTensor.dist -torch.IntTensor.dot -torch.IntTensor.lerp -torch.IntTensor.lerp_ -torch.IntTensor.mean -torch.IntTensor.norm -torch.IntTensor.renorm -torch.IntTensor.renorm_ -torch.IntTensor.std -torch.IntTensor.var -torch.LongTensor.dist -torch.LongTensor.dot -torch.LongTensor.lerp -torch.LongTensor.lerp_ -torch.LongTensor.mean -torch.LongTensor.norm -torch.LongTensor.renorm -torch.LongTensor.renorm_ -torch.LongTensor.std -torch.LongTensor.var -torch.ShortTensor.dist -torch.ShortTensor.dot -torch.ShortTensor.lerp -torch.ShortTensor.lerp_ -torch.ShortTensor.mean -torch.ShortTensor.norm -torch.ShortTensor.renorm -torch.ShortTensor.renorm_ -torch.ShortTensor.std -torch.ShortTensor.var diff --git a/test/dist_autograd_test.py b/test/dist_autograd_test.py index 6305cd526b841..d77100887bbdf 100644 --- a/test/dist_autograd_test.py +++ b/test/dist_autograd_test.py @@ -1,65 +1,75 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import sys -import torch.distributed as dist -import torch.distributed.autograd as dist_autograd -from functools import wraps -import six +import time import unittest + import torch +import torch.distributed as dist +import torch.distributed.autograd as dist_autograd +from dist_utils import INIT_METHOD_TEMPLATE, dist_init -if not dist.is_available(): - print("c10d not available, skipping tests") - sys.exit(0) - -def dist_init(func): - """ - We use this decorator for setting up and tearing down state since - MultiProcessTestCase runs each `test*` method in a separate process and - each process just runs the `test*` method without actually calling - 'setUp' and 'tearDown' methods of unittest. 
- """ - @wraps(func) - def wrapper(self): - self.worker_id = self.rank - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group(backend='gloo', rank=self.rank, - world_size=self.world_size, store=store) - dist.init_model_parallel('worker%d' % self.rank) - func(self) - dist.join_rpc() - - return wrapper - -@unittest.skipIf(not six.PY3, "Pytorch distributed autograd package " - "does not support python2") -class DistAutogradTest(object): +prev_rank_rpc_done = False +prev_rank_context_id = 0 + + +def _set_rpc_done(context_id): + global prev_rank_rpc_done + global prev_rank_context_id + prev_rank_rpc_done = True + prev_rank_context_id = context_id + + +@unittest.skipIf( + not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2" +) +class DistAutogradTest(object): @property def world_size(self): return 4 + @property + def init_method(self): + return INIT_METHOD_TEMPLATE.format( + file_name=self.file_name, rank=self.rank, world_size=self.world_size + ) + @dist_init def test_autograd_context(self): + # Verify max possible id. + max_auto_increment = 281474976710655 + self.assertEqual( + max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id() + ) + context_ids = [] for i in range(1000): with dist_autograd.context() as context_id: - self.assertEqual(context_id, dist_autograd._retrieve_context(context_id)._context_id()) + self.assertEqual( + context_id, + dist_autograd._retrieve_context(context_id)._context_id(), + ) # First 16 bits should be worker_id. self.assertEqual(self.worker_id, context_id >> 48) context_ids.append(context_id) for context_id in context_ids: - with self.assertRaisesRegex(RuntimeError, 'Could not find autograd context with id: {}'.format(context_id)): + with self.assertRaisesRegex( + RuntimeError, + "Could not find autograd context with id: {}".format(context_id), + ): dist_autograd._retrieve_context(context_id) @dist_init - def test_autograd_send_function(self): + def test_autograd_functions(self): dst_rank = (self.rank + 1) % self.world_size with dist_autograd.context() as context_id: t1 = torch.ones(3, 3, requires_grad=True) t2 = torch.zeros(3, 3, requires_grad=True) - ret = dist.rpc_sync('worker{}'.format(dst_rank), torch.add, args=(t1, t2)) + ret = dist.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2)) + dist.rpc_sync( + "worker{}".format(dst_rank), _set_rpc_done, args=(context_id,) + ) # Get send function. ctx = dist_autograd._current_context() @@ -68,17 +78,54 @@ def test_autograd_send_function(self): self.assertEqual(1, len(send_functions)) # Retrieve the next functions in the graph. - next_funcs = send_functions[0].next_functions + next_funcs = list(send_functions.values())[0].next_functions self.assertEqual(2, len(next_funcs)) # We should now hit t1 and t2 in the autograd graph. - self.assertEqual('torch::autograd::AccumulateGrad', next_funcs[0][0].name()) + self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name()) self.assertEqual(t1, next_funcs[0][0].variable) self.assertEqual(0, next_funcs[0][1]) - self.assertEqual('torch::autograd::AccumulateGrad', next_funcs[1][0].name()) + self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name()) self.assertEqual(t2, next_funcs[1][0].variable) self.assertEqual(0, next_funcs[1][1]) + # Test recv functions. 
+ recv_functions = ctx._recv_functions() + self.assertEqual(1, len(recv_functions)) + self.assertEqual(ret.grad_fn, list(recv_functions.values())[0]) + + # We should have send/recv functions from the previous rank, get all + # contexts in this node to find them. + + # Wait for the prev rank to be done with rpc. + while not prev_rank_rpc_done: + time.sleep(0.1) + pass + + # Now verify the autograd graph. + ctx = dist_autograd._retrieve_context(prev_rank_context_id) + + # Get the send function. + send_functions = ctx._send_functions() + self.assertEqual(1, len(send_functions)) + + # Verify next function is AddBackward0 + next_funcs = list(send_functions.values())[0].next_functions + self.assertEqual(1, len(next_funcs)) + add_backward_fn = next_funcs[0][0] + self.assertEqual("AddBackward0", add_backward_fn.name()) + + # Verify the next two functions are the same recv backward function. + next_funcs = add_backward_fn.next_functions + self.assertEqual(2, len(next_funcs)) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name() + ) + self.assertEqual( + "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name() + ) + self.assertEqual(next_funcs[0][0], next_funcs[1][0]) + # autograd context should be cleaned up by now. with self.assertRaises(RuntimeError): ctx = dist_autograd._retrieve_context(context_id) @@ -95,15 +142,21 @@ def test_rpc_complex_args(self): tensors = [] for i in range(num_tensors): tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0))) - ret = dist.rpc_sync('worker{}'.format(dst_rank), torch.stack, args=(tensors,)) + ret = dist.rpc_sync( + "worker{}".format(dst_rank), torch.stack, args=(tensors,) + ) self.assertEqual(torch.stack(tensors), ret) # Verify appropriate tensors have been attached the autograd graph. - next_funcs = dist_autograd._current_context()._send_functions()[0].next_functions + next_funcs = list( + dist_autograd._current_context()._send_functions().values() + )[0].next_functions idx = 0 for i in range(num_tensors): if i % 2 == 0: - self.assertEqual('torch::autograd::AccumulateGrad', next_funcs[i][0].name()) + self.assertEqual( + "torch::autograd::AccumulateGrad", next_funcs[i][0].name() + ) self.assertEqual(tensors[i], next_funcs[i][0].variable) else: self.assertIsNone(next_funcs[i][0]) diff --git a/test/dist_utils.py b/test/dist_utils.py new file mode 100644 index 0000000000000..fcdfdefa22343 --- /dev/null +++ b/test/dist_utils.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +from functools import wraps +from os import getenv + +import torch.distributed as dist +from torch.distributed.rpc_api import RpcBackend + + +if not dist.is_available(): + print("c10d not available, skipping tests") + sys.exit(0) + + +class TestConfig: + __slots__ = ['backend'] + + def __init__(self, *args, **kwargs): + assert len(args) == 0, "TestConfig only takes kwargs." + for k, v in kwargs.items(): + setattr(self, k, v) + + +TEST_CONFIG = TestConfig(backend=getenv("RPC_BACKEND", RpcBackend.PROCESS_GROUP)) +INIT_METHOD_TEMPLATE = "file://{file_name}?rank={rank}&world_size={world_size}" + + +def dist_init(test_method): + """ + We use this decorator for setting up and tearing down state since + MultiProcessTestCase runs each `test*` method in a separate process and + each process just runs the `test*` method without actually calling + 'setUp' and 'tearDown' methods of unittest. 
+ """ + + @wraps(test_method) + def wrapper(self, *arg, **kwargs): + self.worker_id = self.rank + dist.init_process_group(backend="gloo", init_method=self.init_method) + dist.init_model_parallel( + self_name="worker%d" % self.rank, + backend=TEST_CONFIG.backend, + self_rank=self.rank, + init_method=self.init_method, + ) + test_method(self, *arg, **kwargs) + dist.join_rpc() + + return wrapper diff --git a/test/onnx/expect/TestOperators.test_fmod.expect b/test/onnx/expect/TestOperators.test_fmod.expect new file mode 100644 index 0000000000000..0fd8adb615cb9 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_fmod.expect @@ -0,0 +1,77 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.3" +graph { + node { + input: "0" + input: "1" + output: "2" + op_type: "Mod" + attribute { + name: "fmod" + i: 1 + type: INT + } + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/test/onnx/expect/TestOperators.test_remainder.expect b/test/onnx/expect/TestOperators.test_remainder.expect new file mode 100644 index 0000000000000..fe6826b151f2c --- /dev/null +++ b/test/onnx/expect/TestOperators.test_remainder.expect @@ -0,0 +1,89 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.3" +graph { + node { + input: "0" + input: "1" + output: "2" + op_type: "Div" + } + node { + input: "2" + output: "3" + op_type: "Floor" + } + node { + input: "3" + input: "1" + output: "4" + op_type: "Mul" + } + node { + input: "0" + input: "4" + output: "5" + op_type: "Sub" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "5" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_unfold.expect b/test/onnx/expect/TestOperators.test_unfold.expect new file mode 100644 index 0000000000000..653ca7accc654 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_unfold.expect @@ -0,0 +1,121 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.3" +graph { + node { + input: "0" + output: "1" + op_type: "Slice" + attribute { + name: "axes" + ints: 2 + type: INTS + } + attribute { + name: "ends" + ints: 2 + type: INTS + } + attribute { + name: "starts" + ints: 0 + type: INTS + } + } + node { + input: "0" + output: "2" + op_type: "Slice" + attribute { + name: "axes" + ints: 2 + type: INTS + } + attribute { + name: "ends" + ints: 4 + type: INTS + } + attribute { + name: "starts" + ints: 2 + type: INTS + } + } + node { + input: "1" + output: "3" + op_type: "Unsqueeze" + attribute { + name: "axes" + ints: 2 + type: INTS + } + } + node { + input: "2" + output: "4" + op_type: 
"Unsqueeze" + attribute { + name: "axes" + ints: 2 + type: INTS + } + } + node { + input: "3" + input: "4" + output: "5" + op_type: "Concat" + attribute { + name: "axis" + i: 2 + type: INT + } + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "5" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 4efb6d9b11a85..83c798f78101b 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -771,6 +771,20 @@ def test_frobenius_norm(self): x = torch.randn(2, 3, 4).float() self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x) + def test_unfold(self): + x = torch.randn(2, 3, 4, requires_grad=True) + self.assertONNX(lambda x: x.unfold(dimension=2, size=2, step=2), x) + + def test_remainder(self): + x = torch.randn(2, 3, 4) + y = torch.randn(2, 1, 4) + self.assertONNX(lambda x, y: torch.remainder(x, y), (x, y)) + + def test_fmod(self): + x = torch.randn(2, 3, 4) + y = torch.randn(2, 1, 4) + self.assertONNX(lambda x, y: torch.fmod(x, y), (x, y), opset_version=10) + def test_gelu(self): x = torch.randn(2, 3, 4, 5, requires_grad=True) self.assertONNX(lambda x: torch.nn.functional.gelu(x), x) diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index db616f2659458..32ff6d97c56f6 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -2218,6 +2218,29 @@ def forward(self, x): x = torch.arange(16).view(2, 2, 4).to(torch.float32) self.run_model_test(MaskedFillModel2(), input=(x, ), train=False, batch_size=BATCH_SIZE) + def test_remainder(self): + class RemainderModel(torch.nn.Module): + def forward(self, input, other): + return torch.remainder(input, other) + + x = torch.randn(4, 2, 3) + y = torch.randn(1, 2, 1) + model = RemainderModel() + outputs = model(x, y) + self.run_model_test(model, train=False, input=(x, y), batch_size=BATCH_SIZE, + example_outputs=(outputs,)) + + def test_remainder_scalar(self): + class RemainderModel(torch.nn.Module): + def forward(self, input): + return torch.remainder(input, 2.55) + + inputs = torch.randint(10, (2, 3)) + model = RemainderModel() + outputs = model(inputs) + self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE, + example_outputs=(outputs,)) + def test_baddbmm(self): class MyModule(torch.nn.Module): def forward(self, input, batch1, batch2): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index fc6335def9e9c..3c741fb8845d9 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -10,6 +10,7 @@ import numpy as np import io import itertools +import copy from torch.nn.utils import rnn as rnn_utils from model_defs.lstm_flattening_result import LstmFlatteningResult @@ -57,13 +58,17 @@ def run_model_test(self, model, batch_size=2, state_dict=None, with torch.no_grad(): if isinstance(input, torch.Tensor): input = (input,) - output = model(*input) + # In-place operators will update input tensor data as well. + # Thus inputs are replicated before every forward call. 
+ input_copy = copy.deepcopy(input) + output = model(*input_copy) if isinstance(output, torch.Tensor): output = (output,) # export the model to ONNX f = io.BytesIO() - torch.onnx._export(model, input, f, + input_copy = copy.deepcopy(input) + torch.onnx._export(model, input_copy, f, opset_version=self.opset_version, example_outputs=output, do_constant_folding=do_constant_folding, @@ -74,7 +79,8 @@ def run_model_test(self, model, batch_size=2, state_dict=None, # compute onnxruntime output prediction ort_sess = onnxruntime.InferenceSession(f.getvalue()) - ort_test_with_input(ort_sess, input, output, rtol, atol) + input_copy = copy.deepcopy(input) + ort_test_with_input(ort_sess, input_copy, output, rtol, atol) # if addiional test inputs are provided run the onnx # model with these inputs and check the outputs @@ -82,7 +88,8 @@ def run_model_test(self, model, batch_size=2, state_dict=None, for test_input in test_with_inputs: if isinstance(test_input, torch.Tensor): test_input = (test_input,) - output = model(*test_input) + test_input_copy = copy.deepcopy(test_input) + output = model(*test_input_copy) if isinstance(output, torch.Tensor): output = (output,) ort_test_with_input(ort_sess, test_input, output, rtol, atol) @@ -1107,6 +1114,50 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + def test_unfold(self): + class UnfoldModel(torch.nn.Module): + def forward(self, x): + return x.unfold(dimension=2, size=2, step=2) + + x = torch.randn(4, 2, 3, requires_grad=True) + self.run_test(UnfoldModel(), x) + + def test_remainder(self): + class RemainderModel(torch.nn.Module): + def forward(self, input, other): + return torch.remainder(input, other) + + x = torch.randn(4, 2, 3) + y = torch.randn(1, 2, 1) + self.run_test(RemainderModel(), (x, y)) + + def test_remainder_scalar(self): + class RemainderModel(torch.nn.Module): + def forward(self, input): + return torch.remainder(input, 2.55) + + x = torch.randint(10, (2, 3)) + self.run_test(RemainderModel(), x) + + @skipIfUnsupportedMinOpsetVersion(10) + def test_fmod(self): + class FModModel(torch.nn.Module): + def forward(self, input, other): + return torch.fmod(input, other) + + x = torch.randn(4, 2, 3) + y = torch.randn(1, 2, 1) + self.run_test(FModModel(), (x, y)) + + @skipIfUnsupportedMinOpsetVersion(10) + def test_fmod_scalar(self): + class FModModel(torch.nn.Module): + def forward(self, input): + return torch.fmod(input, 2.55) + + x = torch.randint(10, (2, 3)) + self.run_test(FModModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_gelu(self): class GeluModel(torch.nn.Module): @@ -1116,19 +1167,28 @@ def forward(self, x): x = torch.randn(2, 4, 5, 6, requires_grad=True) self.run_test(GeluModel(), x) + def test_add_inplace(self): + class InplaceAddModel(torch.nn.Module): + def forward(self, x): + x += 12 + return x + + x = torch.randn(4, 2, 3, requires_grad=True) + self.run_test(InplaceAddModel(), x) + def test_rsqrt(self): class RsqrtModel(torch.nn.Module): def forward(self, x): return x.rsqrt() - x = torch.randn(4, 2, 3, requires_grad=True).to(dtype=torch.float64) + x = torch.randn(4, 2, 3, requires_grad=True, dtype=torch.float64) self.run_test(RsqrtModel(), x) def test_rsqrt_zeros(self): class RsqrtModel(torch.nn.Module): def forward(self, x): return x.rsqrt() - x = torch.zeros(4, 2, 3, requires_grad=True).to(dtype=torch.float64) + x = torch.zeros(4, 2, 3, requires_grad=True, dtype=torch.float64) self.run_test(RsqrtModel(), x) # TODO: enable opset 11 test once ORT support for unique is in diff 
--git a/test/rpc_test.py b/test/rpc_test.py index ba02aff28616e..c5ccb4eb94ac0 100644 --- a/test/rpc_test.py +++ b/test/rpc_test.py @@ -1,35 +1,42 @@ #!/usr/bin/env python3 from __future__ import absolute_import, division, print_function, unicode_literals -import functools +import concurrent.futures import sys import unittest -from os import getenv +from collections import namedtuple from unittest import mock import torch import torch.distributed as dist import torch.distributed.rpc_backend_registry as rpc_backend_registry -from collections import namedtuple -from torch.distributed.internal_rpc_utils import _internal_rpc_pickler, PythonUDF +from common_utils import load_tests +from dist_utils import INIT_METHOD_TEMPLATE, TEST_CONFIG, dist_init +from torch.distributed import ProcessGroupAgent +from torch.distributed.internal_rpc_utils import PythonUDF, _internal_rpc_pickler +from torch.distributed.rpc_api import RpcBackend -if not dist.is_available(): - print("c10d not available, skipping tests") - sys.exit(0) +def requires_process_group_agent(func): + from torch.distributed.rpc_api import _agent -from common_utils import load_tests -from torch.distributed.rpc_api import RpcBackend + return unittest.skipUnless( + isinstance(_agent, ProcessGroupAgent), + "Only ProcessGroupAgent supports global termination detection", + ) -BACKEND = getenv("RPC_BACKEND", RpcBackend.PROCESS_GROUP) -RPC_INIT_URL = getenv("RPC_INIT_URL", "") +VALUE_FUTURE = concurrent.futures.Future() def stub_init_rpc_backend_handler(self_rank, self_name, init_method): return mock.Mock() # RpcAgent. +def set_value(value): + VALUE_FUTURE.set_result(value) + + # it is used to test python user defined function over rpc # classes and functions are used to test python user defined class and # methods over rpc @@ -42,9 +49,8 @@ def __init__(self): def __getstate__(self): (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize( - PythonUDF(my_tensor_function, - (torch.ones(2, 2), torch.ones(2, 2)), - None)) + PythonUDF(my_tensor_function, (torch.ones(2, 2), torch.ones(2, 2)), None) + ) return (pickled_python_udf, tensors) def __setstate__(self, obj): @@ -81,7 +87,7 @@ def build_complex_tensors(): b = [a, a] c = [b, b] d = [a, b] - e = {a : d} + e = {a: d} return [a, b, c, d, e] @@ -103,6 +109,10 @@ def my_complex_tensor_function(list_input, tensor_class_input, dict_input): return (res, complex_tensors[0], complex_tensors[1], complex_tensors[2]) +def my_rref_function(rref_a, rref_b): + return rref_a.to_here() + rref_b.to_here() + + def no_result(): print("do nothing") @@ -110,24 +120,49 @@ def no_result(): def nested_rpc(dst): return dist.rpc_sync(dst, torch.add, args=(torch.ones(2, 2), 1)) + def multi_layer_nested_async_rpc(dst, world_size, ttl): # this method returns immediately without blocking the callee, but will # generate additional requests. 
if ttl > 0: current_dst = "worker{}".format(dst) next_dst = (dst + 1) % world_size - dist.rpc( + dist.rpc_async( current_dst, multi_layer_nested_async_rpc, - args=( - next_dst, - world_size, - ttl - 1 - ), - async_call=True + args=(next_dst, world_size, ttl - 1), ) return 0 + +def nested_rref(dst): + return ( + dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1)), + dist.remote(dst, torch.add, args=(torch.ones(2, 2), 2)), + ) + + +def nested_remote(dst): + rref = dist.remote(dst, torch.add, args=(torch.ones(2, 2), 3)) + return rref.to_here() + + +def rref_forward_chain(dst, world_size, rref, ttl): + if ttl > 0: + current_dst = "worker{}".format(dst) + next_dst = (dst + 1) % world_size + ret_rref = dist.remote( + current_dst, rref_forward_chain, args=(next_dst, world_size, rref, ttl - 1) + ) + return [ret_rref] + else: + return rref.to_here() + + +def rpc_return_rref(dst): + return dist.remote(dst, torch.add, args=(torch.ones(2, 2), 1)) + + def light_rpc(): return 0 @@ -148,32 +183,6 @@ def raise_func(): load_tests = load_tests -def _wrap_with_rpc(test_method): - """ - We use this decorator for setting up and tearing down state since - MultiProcessTestCase runs each `test*` method in a separate process and - each process just runs the `test*` method without actually calling - 'setUp' and 'tearDown' methods of unittest. - """ - - @functools.wraps(test_method) - def wrapper(self, *arg, **kwargs): - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="gloo", rank=self.rank, world_size=self.world_size, store=store - ) - dist.init_model_parallel( - self_name="worker%d" % self.rank, - backend=BACKEND, - self_rank=self.rank, - init_method=RPC_INIT_URL, - ) - test_method(self, *arg, **kwargs) - dist.join_rpc() - - return wrapper - - @unittest.skipIf( sys.version_info < (3, 0), "Pytorch distributed rpc package " "does not support python2", @@ -183,28 +192,34 @@ class RpcTest(object): def world_size(self): return 4 - @_wrap_with_rpc + @property + def init_method(self): + return INIT_METHOD_TEMPLATE.format( + file_name=self.file_name, rank=self.rank, world_size=self.world_size + ) + + @dist_init def test_worker_id(self): n = self.rank + 1 peer_rank = n % self.world_size - self_worker_id = dist.get_worker_id() - peer_worker_id = dist.get_worker_id("worker{}".format(peer_rank)) + self_worker_info = dist.get_worker_info() + peer_worker_info = dist.get_worker_info("worker{}".format(peer_rank)) - self.assertEqual(self_worker_id.name, "worker{}".format(self.rank)) - self.assertEqual(peer_worker_id.name, "worker{}".format(peer_rank)) + self.assertEqual(self_worker_info.name, "worker{}".format(self.rank)) + self.assertEqual(peer_worker_info.name, "worker{}".format(peer_rank)) with self.assertRaisesRegex(RuntimeError, "Unknown destination worker"): - unknown_worker_id = dist.get_worker_id("WorkerUnknown") + unknown_worker_id = dist.get_worker_info("WorkerUnknown") - @_wrap_with_rpc + @dist_init def test_self_add(self): - self_worker_id = dist.get_worker_id() + self_worker_info = dist.get_worker_info() self_worker_name = "worker{}".format(self.rank) with self.assertRaisesRegex( RuntimeError, "does not support making RPC calls to self" ): - dist.rpc_sync(self_worker_id, torch.add, args=(torch.ones(2, 2), 1)) + dist.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) with self.assertRaisesRegex( RuntimeError, "does not support making RPC calls to self" @@ -212,7 +227,7 @@ def test_self_add(self): dist.rpc_sync(self_worker_name, torch.add, 
args=(torch.ones(2, 2), 1)) @mock.patch.object(torch.distributed.autograd, "_init") - @mock.patch.object(torch.distributed.rpc_api, "init_rref_context") + @mock.patch.object(torch.distributed.rpc_api, "_init_rref_context") def test_register_rpc_backend_and_init_rpc_backend( self, mock_init_rref_context, mock_dist_autograd_init ): @@ -223,40 +238,34 @@ def test_register_rpc_backend_and_init_rpc_backend( dist.init_model_parallel(self_name="worker1", backend=backend_name, self_rank=1) @unittest.skipIf( - BACKEND != RpcBackend.PROCESS_GROUP, + TEST_CONFIG.backend != RpcBackend.PROCESS_GROUP, "PROCESS_GROUP rpc backend specific test, skip", ) def test_duplicate_name(self): - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="gloo", rank=self.rank, world_size=self.world_size, store=store - ) + dist.init_process_group(backend="gloo", init_method=self.init_method) with self.assertRaisesRegex(RuntimeError, "is not unique"): dist.init_model_parallel( self_name="duplicate_name", - backend=BACKEND, + backend=TEST_CONFIG.backend, self_rank=self.rank, - init_method=RPC_INIT_URL, + init_method=self.init_method, ) dist.join_rpc() def test_reinit(self): - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="gloo", rank=self.rank, world_size=self.world_size, store=store - ) + dist.init_process_group(backend="gloo", init_method=self.init_method) dist.init_model_parallel( self_name="worker{}".format(self.rank), - backend=BACKEND, + backend=TEST_CONFIG.backend, self_rank=self.rank, - init_method=RPC_INIT_URL, + init_method=self.init_method, ) with self.assertRaisesRegex(RuntimeError, "is already initialized"): dist.init_model_parallel( self_name="worker{}".format(self.rank), - backend=BACKEND, + backend=TEST_CONFIG.backend, self_rank=self.rank, - init_method=RPC_INIT_URL, + init_method=self.init_method, ) dist.join_rpc() @@ -266,15 +275,12 @@ def test_init_invalid_backend(self): self_name="worker{}".format(self.rank), backend="invalid", self_rank=self.rank, - init_method=RPC_INIT_URL, + init_method=self.init_method, ) @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/25912") def test_invalid_names(self): - store = dist.FileStore(self.file_name, self.world_size) - dist.init_process_group( - backend="gloo", rank=self.rank, world_size=self.world_size, store=store - ) + dist.init_process_group(backend="gloo", init_method=self.init_method) with self.assertRaisesRegex(RuntimeError, "Worker name must match"): dist.init_model_parallel(self_name="abc*") @@ -286,17 +292,17 @@ def test_invalid_names(self): dist.init_model_parallel(self_name="") # If the number in the message does not match, it is likely that the - # value of MAX_NAME_LEN in RPC WorkerId has changed. + # value of MAX_NAME_LEN in RPC WorkerInfo has changed. 
with self.assertRaisesRegex(RuntimeError, "shorter than 128"): dist.init_model_parallel( self_name="".join(["a" for _ in range(500)]), - backend=BACKEND, + backend=TEST_CONFIG.backend, self_rank=self.rank, - init_method=RPC_INIT_URL, + init_method=self.init_method, ) dist.join_rpc() - @_wrap_with_rpc + @dist_init def test_add(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -307,18 +313,18 @@ def test_add(self): ) self.assertEqual(ret, torch.ones(n, n) * 2) - @_wrap_with_rpc + @dist_init def test_add_with_id(self): n = self.rank + 1 dst_rank = n % self.world_size - workder_id = dist.get_worker_id("worker{}".format(dst_rank)) + workder_info = dist.get_worker_info("worker{}".format(dst_rank)) ret = dist.rpc_sync( - workder_id, torch.add, args=(torch.ones(n, n), torch.ones(n, n)) + workder_info, torch.add, args=(torch.ones(n, n), torch.ones(n, n)) ) self.assertEqual(ret, torch.ones(n, n) * 2) - @_wrap_with_rpc + @dist_init def test_scalar_add(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -327,7 +333,7 @@ def test_scalar_add(self): ) self.assertEqual(ret, (torch.ones(n, n) + n)) - @_wrap_with_rpc + @dist_init def test_async_add(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -338,7 +344,7 @@ def test_async_add(self): ) self.assertEqual(fut.wait(), torch.ones(n, n) * 2) - @_wrap_with_rpc + @dist_init def test_nonzero(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -347,7 +353,7 @@ def test_nonzero(self): ret = dist.rpc_sync("worker{}".format(dst_rank), torch.nonzero, args=(x,)) self.assertEqual(ret, x.nonzero()) - @_wrap_with_rpc + @dist_init def test_multi_rpc(self): dst_rank = (self.rank + 1) % self.world_size for i in range(20): @@ -359,7 +365,7 @@ def test_multi_rpc(self): ) self.assertEqual(ret, torch.ones(n, n) * 2) - @_wrap_with_rpc + @dist_init def test_sync_rpc(self): dst_rank = (self.rank + 1) % self.world_size for i in range(20): @@ -378,7 +384,7 @@ def test_sync_rpc(self): self.assertEqual(ret1, torch.ones(n, n) * 2) self.assertEqual(ret2, torch.ones(n, n) * 3) - @_wrap_with_rpc + @dist_init def test_join_rpc(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -400,14 +406,22 @@ def test_join_rpc(self): # it's safe to call join_rpc() multiple times dist.join_rpc() - @_wrap_with_rpc + @dist_init + def test_expected_src(self): + dst_rank = (self.rank + 1) % self.world_size + expected_src_rank = (self.rank - 1) % self.world_size + ret = dist.rpc_sync("worker{}".format(dst_rank), set_value, args=(self.rank,)) + value = VALUE_FUTURE.result() + self.assertEqual(value, expected_src_rank) + + @dist_init def test_py_built_in(self): n = self.rank + 1 dst_rank = n % self.world_size ret = dist.rpc_sync("worker{}".format(dst_rank), min, args=(n, n + 1, n + 2)) self.assertEqual(ret, min(n, n + 1, n + 2)) - @_wrap_with_rpc + @dist_init def test_py_user_defined(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -418,14 +432,14 @@ def test_py_user_defined(self): ) self.assertEqual(ret, my_function(n, n + 1, n + 2)) - @_wrap_with_rpc + @dist_init def test_py_class_constructor(self): n = self.rank + 1 dst_rank = n % self.world_size ret = dist.rpc_sync("worker{}".format(dst_rank), MyClass, args=(n,)) self.assertEqual(ret.a, n) - @_wrap_with_rpc + @dist_init def test_py_class_instance_method(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -434,7 +448,7 @@ def test_py_class_instance_method(self): ) self.assertEqual(ret, MyClass(2).my_instance_method(n)) - @_wrap_with_rpc + @dist_init def test_py_class_method(self): n = self.rank + 1 
dst_rank = n % self.world_size @@ -443,7 +457,7 @@ def test_py_class_method(self): ) self.assertEqual(ret, MyClass.my_class_method(n, n + 1)) - @_wrap_with_rpc + @dist_init def test_py_class_static_method(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -452,86 +466,89 @@ def test_py_class_static_method(self): ) self.assertEqual(ret, MyClass.my_static_method(n + 10)) - @_wrap_with_rpc + @dist_init def test_py_multi_async_call(self): n = self.rank + 1 dst_rank = n % self.world_size - dst_worker_id = dist.get_worker_id("worker{}".format(dst_rank)) - fut1 = dist.rpc_async(dst_worker_id, MyClass.my_static_method, args=(n + 10,)) - fut2 = dist.rpc_async(dst_worker_id, min, args=(n, n + 1, n + 2)) + dst_worker_info = dist.get_worker_info("worker{}".format(dst_rank)) + fut1 = dist.rpc_async(dst_worker_info, MyClass.my_static_method, args=(n + 10,)) + fut2 = dist.rpc_async(dst_worker_info, min, args=(n, n + 1, n + 2)) self.assertEqual(fut1.wait(), MyClass.my_static_method(n + 10)) self.assertEqual(fut2.wait(), min(n, n + 1, n + 2)) - @_wrap_with_rpc + @dist_init def test_py_no_return_result(self): n = self.rank + 1 dst_rank = n % self.world_size ret = dist.rpc_sync("worker{}".format(dst_rank), no_result) self.assertEqual(ret, no_result()) - @_wrap_with_rpc + @dist_init def test_py_tensors(self): n = self.rank + 1 dst_rank = n % self.world_size - ret = dist.rpc("worker{}".format(dst_rank), - my_tensor_function, - args=(torch.ones(n, n), torch.ones(n, n))) - self.assertEqual(ret, - my_tensor_function(torch.ones(n, n), - torch.ones(n, n))) - - @_wrap_with_rpc + ret = dist.rpc_sync( + "worker{}".format(dst_rank), + my_tensor_function, + args=(torch.ones(n, n), torch.ones(n, n)), + ) + self.assertEqual(ret, my_tensor_function(torch.ones(n, n), torch.ones(n, n))) + + @dist_init def test_py_tensors_multi_async_call(self): futs = [] n = self.rank + 1 dst_rank = n % self.world_size for i in range(100): - fut = dist.rpc("worker{}".format(dst_rank), - my_tensor_function, - args=(torch.ones(i, i), torch.ones(i, i)), - async_call=True) + fut = dist.rpc_async( + "worker{}".format(dst_rank), + my_tensor_function, + args=(torch.ones(i, i), torch.ones(i, i)), + ) futs.append(fut) j = 0 for fut in futs: - self.assertEqual(fut.wait(), - my_tensor_function(torch.ones(j, j), - torch.ones(j, j))) + self.assertEqual( + fut.wait(), my_tensor_function(torch.ones(j, j), torch.ones(j, j)) + ) j += 1 - @_wrap_with_rpc + @dist_init def test_py_tensors_in_container(self): n = self.rank + 1 dst_rank = n % self.world_size a = [torch.ones(n, n), torch.ones(n, n)] b = TensorClass(build_complex_tensors()) c = {"foo": torch.ones(n, n), "bar": torch.ones(n, n)} - ret = dist.rpc("worker{}".format(dst_rank), - my_complex_tensor_function, - args=(a, b, c)) + ret = dist.rpc_sync( + "worker{}".format(dst_rank), my_complex_tensor_function, args=(a, b, c) + ) self.assertEqual(ret, my_complex_tensor_function(a, b, c)) - @_wrap_with_rpc + @dist_init def test_py_nested_pickle(self): n = self.rank + 1 dst_rank = n % self.world_size - ret = dist.rpc("worker{}".format(dst_rank), - run_nested_pickle, - args=(MyPickleClass(), torch.ones(2, 2))) + ret = dist.rpc_sync( + "worker{}".format(dst_rank), + run_nested_pickle, + args=(MyPickleClass(), torch.ones(2, 2)), + ) m = MyPickleClass() m.set(my_tensor_function(torch.ones(2, 2), torch.ones(2, 2))) self.assertEqual(ret, run_nested_pickle(m, torch.ones(2, 2))) - @_wrap_with_rpc + @dist_init def test_py_function_exception(self): n = self.rank + 1 dst_rank = n % self.world_size with 
self.assertRaisesRegex(Exception, "TypeError"): ret = dist.rpc_sync("worker{}".format(dst_rank), no_result, args=(10,)) - @_wrap_with_rpc + @dist_init def test_py_raise_in_user_func(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -539,7 +556,7 @@ def test_py_raise_in_user_func(self): with self.assertRaisesRegex(Exception, "ValueError"): fut.wait() - @_wrap_with_rpc + @dist_init def test_nested_rpc(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -570,15 +587,15 @@ def _stress_test_rpc(self, f, repeat=1000, args=()): ) ) - @_wrap_with_rpc + @dist_init def test_stress_light_rpc(self): self._stress_test_rpc(light_rpc) - @_wrap_with_rpc + @dist_init def test_stress_heavy_rpc(self): self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) - @_wrap_with_rpc + @dist_init def test_builtin_remote_ret(self): n = self.rank + 1 dst_rank = n % self.world_size @@ -589,8 +606,7 @@ def test_builtin_remote_ret(self): ) self.assertEqual(rref.to_here(), torch.ones(n, n) * 2) - @_wrap_with_rpc - def test_multi_builtin_remote_ret(self): + def _test_multi_remote_call(self, fn, args_fn=lambda x: (), kwargs_fn=lambda x: {}): m = 10 n = self.rank + 1 dst_rank = n % self.world_size @@ -601,16 +617,149 @@ def test_multi_builtin_remote_ret(self): rrefs.append( dist.remote( "worker{}".format(dst_rank), - torch.add, - args=(torch.ones(n, n), torch.ones(n, n)), + fn, + args=args_fn(n), + kwargs=kwargs_fn(n), ) ) - expected.append(torch.ones(n, n) * 2) + expected.append(fn(*args_fn(n), **kwargs_fn(n))) for i in range(m): self.assertEqual(rrefs[i].to_here(), expected[i]) - @_wrap_with_rpc + @dist_init + @requires_process_group_agent + def test_multi_builtin_remote_ret(self): + def args_fn(n): + return (torch.ones(n, n), torch.ones(n, n)) + + self._test_multi_remote_call(torch.add, args_fn=args_fn) + + @dist_init + @requires_process_group_agent + def test_py_udf_remote(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref = dist.remote( + "worker{}".format(dst_rank), + my_function, + kwargs={"a": n, "b": n + 1, "c": n + 2}, + ) + self.assertEqual(rref.to_here(), my_function(n, n + 1, n + 2)) + + @dist_init + @requires_process_group_agent + def test_multi_py_udf_remote(self): + def kwargs_fn(n): + return {"a": torch.ones(n, n), "b": torch.ones(n, n), "c": torch.ones(n, n)} + + self._test_multi_remote_call(my_function, kwargs_fn=kwargs_fn) + + @dist_init + @requires_process_group_agent + def test_py_rref_args(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_a = dist.remote( + "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2) + ) + rref_b = dist.remote( + "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1) + ) + rref_c = dist.remote( + "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) + + @dist_init + @requires_process_group_agent + def test_py_rref_args_user_share(self): + n = self.rank + 1 + owner_rank = n % self.world_size + user_rank = (n + 1) % self.world_size + rref_a = dist.remote( + "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 2, 0) + ) + rref_b = dist.remote( + "worker{}".format(owner_rank), my_function, args=(torch.ones(n, n), 1, 0) + ) + rref_c = dist.remote( + "worker{}".format(user_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) + + @dist_init + @requires_process_group_agent + def test_py_rpc_rref_args(self): + n = self.rank + 1 + dst_rank = n % 
self.world_size + rref_a = dist.remote( + "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 2, 0) + ) + rref_b = dist.remote( + "worker{}".format(dst_rank), my_function, args=(torch.ones(n, n), 1, 0) + ) + + c = dist.rpc_sync( + "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b) + ) + + self.assertEqual(c, torch.ones(n, n) + 4) + + @dist_init + @requires_process_group_agent + def test_nested_remote(self): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + rref = dist.remote( + "worker{}".format(dst_rank1), + nested_remote, + args=("worker{}".format(dst_rank2),), + ) + self.assertEqual(rref.to_here(), torch.ones(2, 2) + 3) + + @dist_init + @requires_process_group_agent + def test_nested_rref(self): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + rref_of_rrefs = dist.remote( + "worker{}".format(dst_rank1), + nested_rref, + args=("worker{}".format(dst_rank2),), + ) + rrefs = rref_of_rrefs.to_here() + self.assertEqual(len(rrefs), 2) + self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1) + self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2) + + @dist_init + @requires_process_group_agent + def test_nested_rref_stress(self): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + all_rrefs = [] + for _ in range(20): + all_rrefs.append( + dist.remote( + "worker{}".format(dst_rank1), + nested_rref, + args=("worker{}".format(dst_rank2),), + ) + ) + + for i in range(20): + rref_of_rrefs = all_rrefs[i] + rrefs = rref_of_rrefs.to_here() + self.assertEqual(len(rrefs), 2) + self.assertEqual(rrefs[0].to_here(), torch.ones(2, 2) + 1) + self.assertEqual(rrefs[1].to_here(), torch.ones(2, 2) + 2) + + @dist_init + @requires_process_group_agent def test_multi_layer_nested_async_rpc(self): # This test will exit right away, but there will be a chain of async # RPCs. The termination algorithm should detect those messages properly. 
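The tests above all use the @dist_init decorator imported from dist_utils in place of the deleted _wrap_with_rpc helper. dist_utils itself is not part of this diff, so the following is only a plausible sketch of what such a decorator does, reconstructed from the removed wrapper and from the init_method / TEST_CONFIG usage elsewhere in this file; the real implementation may differ:

    import functools

    import torch.distributed as dist
    from dist_utils import TEST_CONFIG  # same module the test file imports from

    def dist_init(test_method):
        # Hypothetical sketch of dist_utils.dist_init, not the actual implementation.
        @functools.wraps(test_method)
        def wrapper(self, *args, **kwargs):
            dist.init_process_group(backend="gloo", init_method=self.init_method)
            dist.init_model_parallel(
                self_name="worker%d" % self.rank,
                backend=TEST_CONFIG.backend,
                self_rank=self.rank,
                init_method=self.init_method,
            )
            test_method(self, *args, **kwargs)
            dist.join_rpc()
        return wrapper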
@@ -621,3 +770,61 @@ def test_multi_layer_nested_async_rpc(self): dst_rank = n % self.world_size multi_layer_nested_async_rpc(dst_rank, self.world_size, ttl) + + @dist_init + @requires_process_group_agent + def test_remote_with_exception(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref = dist.remote("worker{}".format(dst_rank), raise_func) + with self.assertRaisesRegex(Exception, "ValueError"): + rref.to_here() + + @dist_init + @requires_process_group_agent + def test_rpc_return_rref(self): + n = self.rank + 1 + dst_rank1 = n % self.world_size + dst_rank2 = (n + 1) % self.world_size + rref = dist.rpc_sync( + "worker{}".format(dst_rank1), + rpc_return_rref, + args=("worker{}".format(dst_rank2),), + ) + self.assertEqual(rref.to_here(), torch.ones(2, 2) + 1) + + @dist_init + @requires_process_group_agent + def test_rref_forward_chain(self): + ttl = 8 + n = self.rank + 1 + dst_rank = n % self.world_size + + rref = dist.remote( + "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1) + ) + + ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl) + + for i in range(ttl): + self.assertEqual(len(ret_rref), 1) + ret_rref = ret_rref[0].to_here() + + ret = ret_rref + self.assertEqual(ret, torch.add(torch.ones(n, n), 1)) + + @dist_init + @requires_process_group_agent + def test_remote_same_worker(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref_a = dist.remote( + "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 2) + ) + rref_b = dist.remote( + "worker{}".format(dst_rank), torch.add, args=(torch.ones(n, n), 1) + ) + rref_c = dist.remote( + "worker{}".format(dst_rank), my_rref_function, args=(rref_a, rref_b) + ) + self.assertEqual(rref_c.to_here(), torch.ones(n, n) + 4) diff --git a/test/run_test.py b/test/run_test.py index f32e09c65c0b5..7788ec4e7a82c 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -28,7 +28,6 @@ 'cuda_primary_ctx', 'dataloader', 'dist_autograd_fork', - 'dist_autograd_spawn', 'distributed', 'distributions', 'docs_coverage', @@ -64,12 +63,19 @@ 'function_schema', ] -# skip < 3.6 b/c fstrings added in 3.6 +# skip < 3.6 b/c fstrings added in 3.6 for jit_py3 +# skip < 3.6 for rpc_spawn and dist_autograd_spawn temporarily because +# a segmenation fault was triggered on python 3.5, +# rpc_spawn and dist_autograd_spawn tests were added in +# https://github.com/pytorch/pytorch/pull/25656 +# skip < 3.6 for rpc_fork as it imports mock that is only available in 3.6, mock +# was added to rpc_fork in https://github.com/pytorch/pytorch/pull/26997 if PY36: TESTS.extend([ 'jit_py3', 'rpc_fork', 'rpc_spawn', + 'dist_autograd_spawn', ]) WINDOWS_BLACKLIST = [ diff --git a/test/test_autograd.py b/test/test_autograd.py index 6e17cc737ad5d..1854b60144524 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4,7 +4,6 @@ import math import tempfile import time -import torch import unittest import warnings from copy import deepcopy @@ -12,6 +11,12 @@ from itertools import product from operator import mul from functools import reduce +import torch + +# TODO: remove this global setting +# Autograd tests use double as the default dtype +torch.set_default_dtype(torch.double) + from torch import nn from torch._six import inf, nan, istuple from torch.autograd.gradcheck import gradgradcheck, gradcheck diff --git a/test/test_c10d.py b/test/test_c10d.py index 5e9584f94f5dc..509c2478fc5f9 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -26,9 +26,9 @@ from torch.nn.parallel import DistributedDataParallel from 
common_distributed import MultiProcessTestCase, \ - requires_gloo, requires_nccl, \ + requires_gloo, requires_nccl, requires_nccl_version, \ skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout -from common_utils import TestCase, load_tests, run_tests +from common_utils import TestCase, load_tests, run_tests, default_floating_dtype from common_utils import retry_on_address_already_in_use_error # load_tests from common_utils is used to automatically filter tests for @@ -2070,6 +2070,7 @@ def test_dist_broadcast_coalesced_gloo(self): @requires_gloo() @skip_if_not_multigpu + @default_floating_dtype(torch.double) def test_sync_params_no_buffers(self): store = c10d.FileStore(self.file_name, self.world_size) options = c10d.ProcessGroupGloo.Options() @@ -2097,6 +2098,7 @@ def test_sync_params_no_buffers(self): @requires_gloo() @skip_if_not_multigpu + @default_floating_dtype(torch.double) def test_sync_params_with_buffers(self): store = c10d.FileStore(self.file_name, self.world_size) options = c10d.ProcessGroupGloo.Options() @@ -2952,10 +2954,9 @@ def test_multi_limit_multi_dtype(self): result = dist._compute_bucket_assignment_by_size(tensors, [200, 400]) self.assertEqual([[0], [1], [2, 4], [3, 5]], result) - -class CommTest(MultiProcessTestCase): +class NcclErrorHandlingTest(MultiProcessTestCase): def setUp(self): - super(CommTest, self).setUp() + super(NcclErrorHandlingTest, self).setUp() # Need to skip return code checking for these tests since the child # processes don't exit cleanly. self.skip_return_code_checks = [ @@ -2966,17 +2967,8 @@ def setUp(self): ] self._fork_processes() - def _get_wrapped_func(self, func): - # Get the original function which was wrapped in the decorator. - if hasattr(func, '__wrapped__'): - # py3 way. - return func.__wrapped__ - else: - # py2 way. - return func.func_closure[0].cell_contents - def tearDown(self): - super(CommTest, self).tearDown() + super(NcclErrorHandlingTest, self).tearDown() try: os.remove(self.file_name) except OSError: @@ -2990,38 +2982,20 @@ def op_timeout_sec(self): def world_size(self): return 2 - def _test_broadcast_coalesced(self, process_group, device): - half = torch.float16 - - # No support for float16 for CPU tensors - if device == torch.device('cpu'): - half = torch.float32 - - target = torch.arange(60, dtype=half, device=device).chunk(5) - target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) - target += torch.arange(60, dtype=half, device=device).chunk(5) - target += torch.arange(60, dtype=torch.float64, device=device).chunk(5) - target += torch.arange(60, dtype=half, device=device).chunk(5) - target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) - - # The tensors to pass to broadcast are idential to the target - # only on the process that is the root of the broadcast. - if self.rank == 0: - tensors = list(tensor.clone() for tensor in target) + def _get_wrapped_func(self, func): + # Get the original function which was wrapped in the decorator. + if hasattr(func, '__wrapped__'): + # py3 way. + return func.__wrapped__ else: - tensors = list(torch.empty_like(tensor) for tensor in target) - - c10d._broadcast_coalesced( - process_group, - tensors, - buffer_size=256) - - self.assertEqual(tensors, target) + # py2 way. 
+ return func.func_closure[0].cell_contents def _run_all_reduce(self, pg): pg.allreduce(torch.rand(10).cuda(self.rank)) @requires_nccl() + @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_not_multigpu def test_nccl_errors_nonblocking(self): store = c10d.FileStore(self.file_name, self.world_size) @@ -3060,26 +3034,31 @@ def _test_nccl_errors_blocking(self, func): func() @requires_nccl() + @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_not_multigpu def test_nccl_errors_blocking_clean_exit(self): self._test_nccl_errors_blocking(lambda : sys.exit(0)) @requires_nccl() + @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_not_multigpu def test_nccl_errors_blocking_nonzero_exit(self): self._test_nccl_errors_blocking(lambda : sys.exit(1)) @requires_nccl() + @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_not_multigpu def test_nccl_errors_blocking_abort(self): self._test_nccl_errors_blocking(lambda : os.abort()) @requires_nccl() + @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_not_multigpu def test_nccl_errors_blocking_sigkill(self): self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGKILL)) @requires_nccl() + @requires_nccl_version(2400, "Need NCCL 2.4+ for error checking") @skip_if_not_multigpu def test_nccl_errors_blocking_sigterm(self): self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGTERM)) @@ -3098,6 +3077,81 @@ def test_invalid_nccl_blocking_wait_env(self): self._run_invalid_nccl_blocking_wait_env('2147483647') self._run_invalid_nccl_blocking_wait_env('4294967295') + @requires_nccl() + @skip_if_not_multigpu + def test_nccl_timeout(self): + store = c10d.FileStore(self.file_name, self.world_size) + os.environ["NCCL_BLOCKING_WAIT"] = "1" + + # Initialize process_group. + timeout = 1 + c10d.distributed_c10d.init_process_group( + backend=dist.Backend.NCCL, store=store, world_size=2, rank=self.rank, + timeout=timedelta(seconds=timeout)) + c10d.distributed_c10d.all_reduce(torch.rand(10).cuda(self.rank)) + + if self.rank == 0: + # This should timeout in about 1 second. + start = time.time() + with self.assertRaises(RuntimeError): + c10d.distributed_c10d.all_reduce(torch.rand(10).cuda(self.rank)) + + total_time = time.time() - start + + self.assertLess(abs(total_time - timeout), 0.5) + else: + # Ensure the other rank sleeps to trigger timeout. + time.sleep(2 * timeout) + + +class CommTest(MultiProcessTestCase): + def setUp(self): + super(CommTest, self).setUp() + self._fork_processes() + + def tearDown(self): + super(CommTest, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def op_timeout_sec(self): + return 1 + + @property + def world_size(self): + return 2 + + def _test_broadcast_coalesced(self, process_group, device): + half = torch.float16 + + # No support for float16 for CPU tensors + if device == torch.device('cpu'): + half = torch.float32 + + target = torch.arange(60, dtype=half, device=device).chunk(5) + target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) + target += torch.arange(60, dtype=half, device=device).chunk(5) + target += torch.arange(60, dtype=torch.float64, device=device).chunk(5) + target += torch.arange(60, dtype=half, device=device).chunk(5) + target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) + + # The tensors to pass to broadcast are idential to the target + # only on the process that is the root of the broadcast. 
+ if self.rank == 0: + tensors = list(tensor.clone() for tensor in target) + else: + tensors = list(torch.empty_like(tensor) for tensor in target) + + c10d._broadcast_coalesced( + process_group, + tensors, + buffer_size=256) + + self.assertEqual(tensors, target) + @requires_nccl() @skip_if_not_multigpu def test_broadcast_coalesced_nccl(self): diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py index 289b6ff6ba82d..ee4a557e10d7a 100644 --- a/test/test_cpp_extensions.py +++ b/test/test_cpp_extensions.py @@ -9,6 +9,7 @@ import glob import common_utils as common +from common_utils import default_floating_dtype import torch import torch.backends.cudnn import torch.utils.cpp_extension @@ -65,6 +66,7 @@ def test_extension_function(self): z = cpp_extension.sigmoid_add(x, y) self.assertEqual(z, x.sigmoid() + y.sigmoid()) + @default_floating_dtype(torch.double) def test_extension_module(self): mm = cpp_extension.MatrixMultiplier(4, 8) weights = torch.rand(8, 4) @@ -72,6 +74,7 @@ def test_extension_module(self): result = mm.forward(weights) self.assertEqual(expected, result) + @default_floating_dtype(torch.double) def test_backward(self): mm = cpp_extension.MatrixMultiplier(4, 8) weights = torch.rand(8, 4, requires_grad=True) @@ -474,6 +477,7 @@ def compile(code): @dont_wipe_extensions_build_folder @common.skipIfRocm + @default_floating_dtype(torch.double) def test_cpp_frontend_module_has_same_output_as_python(self): extension = torch.utils.cpp_extension.load( name="cpp_frontend_extension", diff --git a/test/test_cuda.py b/test/test_cuda.py index 4842e434e8237..0b0d8b862fe1b 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -21,9 +21,9 @@ from common_methods_invocations import tri_tests_args, tri_large_tests_args, \ _compare_trilu_indices, _compare_large_trilu_indices -from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \ +from common_utils import TestCase, get_gpu_type, freeze_rng_state, run_tests, \ PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, \ - TEST_WITH_ROCM, load_tests, slowTest, skipCUDANonDefaultStreamIf + load_tests, slowTest, skipCUDANonDefaultStreamIf, default_floating_dtype # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. 
This line silences flake warnings @@ -49,25 +49,6 @@ TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9 TEST_MEDIUM_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 6e9 -floating_set = {torch.FloatTensor, torch.DoubleTensor, torch.cuda.FloatTensor, - torch.cuda.DoubleTensor, torch.HalfTensor, torch.cuda.HalfTensor} - - -def is_floating(t): - if not isinstance(t, type): - raise TypeError('t should be an instance of type') - assert t != torch.autograd.Variable - return t in floating_set - - -def is_half(t): - if isinstance(t, torch.Tensor): - return t.dtype == torch.float16 - assert isinstance(t, type) - assert t != torch.autograd.Variable - return t in [torch.HalfTensor, torch.cuda.HalfTensor] - - types = [ torch.FloatTensor, torch.DoubleTensor, @@ -79,54 +60,6 @@ def is_half(t): torch.HalfTensor, ] -signed_types = [ - torch.FloatTensor, - torch.DoubleTensor, - torch.LongTensor, - torch.IntTensor, - torch.ShortTensor, - torch.CharTensor, -] - -unsigned_types = [ - torch.ByteTensor, -] - -float_types = [ - torch.FloatTensor, - torch.DoubleTensor, - torch.HalfTensor, -] - -float_types_no_half = [ - torch.FloatTensor, - torch.DoubleTensor, -] - - -def number(floating, integer, t): - return floating if is_floating(t) else integer - - -def cast_tensor(tensor, t): - return t(tensor.size()).copy_(tensor) - -S = 10 -M = 50 -G = 275000000 - - -def make_tensor(t, *sizes): - if 'Half' in t.__name__: - return t(*sizes).copy_(torch.randn(*sizes)) - else: - tensor = t(*sizes) - if tensor.is_floating_point(): - return tensor.normal_() - else: - return tensor.random_(0, 10) - - def make_sparse_tensor(t, n, *sizes): assert t.is_sparse tensor = t() @@ -137,478 +70,8 @@ def make_sparse_tensor(t, n, *sizes): v = v.new(n).copy_(torch.randn(n)) return t(i, v, torch.Size(sizes)) - -def tensor_clamp(t, min, max): - if is_half(t): - return t.float().clamp(min, max).half() - else: - return t.clamp(min, max) - - -def tensor_mul(t, scale): - if is_half(t): - return t.float().mul(scale).half() - else: - return t.mul(scale) - - -def tensor_abs_(t): - if is_half(t): - return t.float().abs_().half() - else: - return t.abs_() - - -def constant_tensor_sub(a, b): - # helper function to address const - torch.HalfTensor where it doesn't - # have resize_as() - if is_half(b): - return (a - b.float()).half() - else: - return a - b - - -def constant_tensor_add(a, b): - # helper function to address const + torch.HalfTensor where it doesn't - # have add() - if is_half(b): - return (a + b.float()).half() - else: - return a + b - - -def small_0d(t): - return make_tensor(t, (1,)).squeeze() - - -def small_2d(t): - return make_tensor(t, S, S) - - -def small_2d_scaled(t, scale=10): - return tensor_mul(make_tensor(t, S, S), scale) - - -def small_2d_oneish(t): - if is_floating(t): - return tensor_clamp(make_tensor(t, S, S), min=0.99, max=1.01) - else: - return t(S, S).fill_(1) - - -def small_3d(t): - return make_tensor(t, S, S, S) - - -def medium_1d(t): - return make_tensor(t, M) - - -def medium_2d(t): - return make_tensor(t, M, M) - - -def medium_2d_expanded(t): - return t(1).expand(M, M) - - -def medium_2d_scaled(t, scale=10): - return tensor_mul(make_tensor(t, M, M), scale) - - -def small_3d_ones(t): - return t(S, S, S).copy_(torch.ones(S, S, S)) - - -def small_3d_positive(t): - # In div_tensor(), half cannot achieve float precision - min_val = 1e-3 if is_floating(t) and not is_half(t) else 2 - return tensor_clamp(make_tensor(t, S, S, S), min_val, 120) - - -def small_3d_unique(t): - return t(S, S, 
S).copy_(torch.arange(1, S * S * S + 1).view(S, S, S)) - - -def small_1d_lapack(t): - return t(1, 3).copy_(torch.arange(1, 4).view(3)) - - -def small_2d_lapack(t): - return t(3, 3).copy_(torch.arange(1, 10).view(3, 3)) - - -def small_2d_lapack_skinny(t): - return t(3, 4).copy_(torch.arange(1, 13).view(3, 4)) - - -def small_2d_lapack_fat(t): - return t(4, 3).copy_(torch.arange(1, 13).view(4, 3)) - - -def large_2d_lapack(t): - return t(1000, 1000).normal_() - - -def giant_1d_ones(t): - return t(G).copy_(torch.ones(G)) - - -def long_type(t): - return torch.cuda.LongTensor if 'cuda' in t.__module__ else torch.LongTensor - - -def new_t(*sizes): - def tmp(t): - return t(*sizes).copy_(torch.randn(*sizes)) - return tmp - -# Content of each tuple: -# - function name -# - constructor for the tensor, signature: fn(tensor_type) -> tensor -# - constructor for the arguments, signature: fn(tensor_type) -> list -# - postfix name for the test (must be unique for a given function) (default='') -# - tensor types to use (default=types) -# - disable inplace test, if set to True, no inplace test will be done (default=False) -# - decorator, e.g., unittest.skipIf (default is no decorator) -tests = [ - ('add', small_3d, lambda t: [number(3.14, 3, t)]), - ('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor'), - ('add', small_3d, lambda t: [number(0.2, 2, t), small_3d_positive(t)], 'scalar_tensor'), - ('sub', small_3d, lambda t: [number(3.14, 3, t)]), - ('sub', small_3d, lambda t: [small_3d_positive(t)], 'tensor'), - ('mul', small_3d, lambda t: [number(3.14, 3, t)]), - ('mul', small_3d, lambda t: [small_3d_positive(t)], 'tensor'), - ('mul', small_0d, lambda t: [small_0d(torch.IntTensor)], 'scalar', types, True), - ('div', small_3d, lambda t: [number(3.14, 3, t)]), - ('div', small_3d, lambda t: [small_3d_positive(t)], 'tensor'), - ('pow', small_3d, lambda t: [number(3.14, 3, t)], None, float_types), - ('pow', small_3d, lambda t: [number(1., 1, t)], 'pow1'), - ('pow', small_3d, lambda t: [number(2., 2, t)], 'pow2'), - ('pow', small_3d, lambda t: [number(3., 3, t)], 'pow3'), - ('pow', small_3d, lambda t: [number(-1., -1, t)], 'pow-1', float_types), - # HalfTensor gives bad result at pow-2 with data sampled from torch.randn - ('pow', small_3d, lambda t: [number(-2., -2, t)], 'pow-2', float_types_no_half, False, - "skipIfRocm:FloatTensor"), - ('pow', small_3d, lambda t: [tensor_abs_(small_3d(t))], 'tensor', float_types), - ('addbmm', small_2d, lambda t: [small_3d(t), small_3d(t)], None, float_types), - ('addbmm', small_2d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'), - ('addbmm', small_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'), - ('baddbmm', small_3d, lambda t: [small_3d(t), small_3d(t)],), - ('baddbmm', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'), - ('baddbmm', small_3d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), small_3d(t), small_3d(t)], 'two_scalars'), - ('bmm', small_3d, lambda t: [small_3d(t)], '', float_types_no_half), - ('addcdiv', small_2d_lapack, lambda t: [tensor_mul(small_2d_lapack(t), 2), small_2d_lapack(t)]), - ('addcdiv', small_2d_lapack, lambda t: [number(2.8, 1, t), tensor_mul(small_2d_lapack(t), 2), small_2d_lapack(t)], - 'scalar'), - ('addcmul', small_3d, lambda t: [small_3d(t), small_3d(t)]), - ('addcmul', small_3d, lambda t: [number(0.4, 2, t), small_3d(t), small_3d(t)], 'scalar'), - ('addmm', medium_2d, lambda t: [medium_2d(t), medium_2d(t)]), - ('addmm', medium_2d, lambda t: 
[number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'scalar'), - ('addmm', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_2d(t)], 'two_scalars'), - ('addmv', medium_1d, lambda t: [medium_2d(t), medium_1d(t)],), - ('addmv', medium_1d, lambda t: [number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'scalar'), - ('addmv', medium_1d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_2d(t), medium_1d(t)], 'two_scalars'), - ('addr', medium_2d, lambda t: [medium_1d(t), medium_1d(t)]), - ('addr', medium_2d, lambda t: [number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'scalar'), - ('addr', medium_2d, lambda t: [number(0.5, 3, t), number(0.4, 2, t), medium_1d(t), medium_1d(t)], 'two_scalars'), - ('atan2', medium_2d, lambda t: [medium_2d(t)], None, float_types + [torch.HalfTensor]), - ('fmod', small_3d, lambda t: [3], 'value',), - ('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor'), - ('chunk', medium_2d, lambda t: [4],), - ('chunk', medium_2d, lambda t: [4, 1], 'dim'), - ('chunk', medium_2d, lambda t: [4, -2], 'neg_dim'), - ('clamp', medium_2d_scaled, lambda t: [-1, 5], None, signed_types), - ('clamp', medium_2d_scaled, lambda t: [1, 5], None, unsigned_types), - ('clone', medium_2d, lambda t: [],), - ('contiguous', medium_2d, lambda t: [],), - ('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)],), - ('cumprod', small_3d, lambda t: [1]), - ('cumprod', small_3d, lambda t: [-1], 'neg_dim'), - ('cumsum', small_3d, lambda t: [1]), - ('cumsum', small_3d, lambda t: [-1], 'neg_dim'), - ('dim', small_3d, lambda t: [],), - ('dist', small_2d, lambda t: [small_2d(t)]), - ('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm'), - ('dist', small_2d, lambda t: [small_2d(t), 2.5], '2_5_norm'), - ('dot', medium_1d, lambda t: [medium_1d(t)], '', types, False, "skipIfRocm:HalfTensor"), - ('element_size', medium_1d, lambda t: [],), - ('eq', small_3d_ones, lambda t: [small_3d(t)],), - ('eq', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'), - ('ne', small_3d_ones, lambda t: [small_3d(t)],), - ('ne', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'), - ('equal', small_3d_ones, lambda t: [small_3d_ones(t)], 'equal'), - ('equal', small_3d_ones, lambda t: [small_3d(t)],), - ('expand', new_t(M, 1, M), lambda t: [M, 4, M],), - ('expand_as', new_t(M, 1, M), lambda t: [new_t(M, 4, M)(t)],), - ('fill', medium_2d, lambda t: [number(3.14, 3, t)]), - ('ge', medium_2d, lambda t: [medium_2d(t)],), - ('le', medium_2d, lambda t: [medium_2d(t)],), - ('gt', medium_2d, lambda t: [medium_2d(t)],), - ('lt', medium_2d, lambda t: [medium_2d(t)],), - ('is_contiguous', medium_2d, lambda t: [],), - # TODO: can't check negative case - GPU copy will be contiguous - ('is_same_size', medium_2d, lambda t: [small_3d(t)], 'negative'), - ('is_same_size', medium_2d, lambda t: [medium_2d(t)], 'positive'), - ('is_set_to', medium_2d, lambda t: [medium_2d(t)],), - # TODO: positive case - ('kthvalue', small_3d_unique, lambda t: [3],), - ('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'), - ('kthvalue', small_3d_unique, lambda t: [3, -1], 'neg_dim'), - ('lerp', small_3d, lambda t: [small_3d(t), 0.3]), - ('max', small_3d_unique, lambda t: []), - ('max', small_3d_unique, lambda t: [1], 'dim'), - ('max', small_3d_unique, lambda t: [-1], 'neg_dim'), - ('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'), - ('min', small_3d_unique, lambda t: []), - ('min', small_3d_unique, lambda t: [1], 'dim'), - ('min', small_3d_unique, lambda t: [-1], 'neg_dim'), - ('min', medium_2d, lambda t: [medium_2d(t)], 
'elementwise'), - ('mean', small_3d, lambda t: []), - ('mean', small_3d, lambda t: [-1], 'neg_dim'), - ('mean', small_3d, lambda t: [1], 'dim'), - ('mean', giant_1d_ones, lambda t: [], '64bit_indexing', - # Double here because otherwise the CPU result will be - # wrong. - [torch.DoubleTensor]), - ('mode', small_3d, lambda t: []), - ('mode', small_3d, lambda t: [1], 'dim'), - ('mode', small_3d, lambda t: [-1], 'neg_dim'), - ('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.1, 10), lambda t: [1], '2d_p=1', float_types_no_half), - ('mvlgamma', lambda t: tensor_clamp(small_2d(t), 0.6, 10), lambda t: [2], '2d_p=2', float_types_no_half), - ('remainder', small_3d, lambda t: [3], 'value',), - ('remainder', small_3d, lambda t: [-3], 'negative_value', signed_types), - ('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'), - ('remainder', small_3d, lambda t: [constant_tensor_sub(0, small_3d_positive(t))], 'negative_tensor', signed_types), - ('std', small_3d, lambda t: []), - ('std', small_3d, lambda t: [1], 'dim', types, False), - ('std', small_3d, lambda t: [-1], 'neg_dim', types, False), - ('var', small_3d, lambda t: []), - ('var', small_3d, lambda t: [1], 'dim'), - ('var', small_3d, lambda t: [-1], 'neg_dim'), - ('ndimension', small_3d, lambda t: [],), - ('nelement', small_3d, lambda t: [],), - ('numel', small_3d, lambda t: [],), - ('narrow', small_3d, lambda t: [1, 3, 2],), - ('narrow', small_3d, lambda t: [-1, 3, 2], 'neg_dim'), - ('nonzero', small_3d, lambda t: [], '', types, False), - ('norm', small_3d, lambda t: []), - ('norm', small_3d, lambda t: [3], '3_norm'), - ('norm', small_3d, lambda t: [3, 0], '3_norm_dim'), - ('norm', small_3d, lambda t: [3, -2], '3_norm_neg_dim'), - ('ones', small_3d, lambda t: [1, 2, 3, 4, 5],), - ('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],), - ('put_', new_t(2, 5, 3), lambda t: [long_type(t)([[0], [-2]]), t([[3], [4]])], '', types, False), - ('put_', new_t(2, 3), lambda t: [long_type(t)([]), t([])], 'empty'), - ('put_', new_t(2, 2), lambda t: [long_type(t)([[1], [-3]]), t([[1], [2]]), True], 'accumulate'), - ('prod', small_2d_oneish, lambda t: []), - ('prod', small_3d, lambda t: [1], 'dim'), - ('prod', small_3d, lambda t: [-1], 'neg_dim'), - ('sum', small_2d, lambda t: []), - ('sum', small_3d, lambda t: [1], 'dim'), - ('sum', small_3d, lambda t: [-1], 'neg_dim'), - ('renorm', small_3d, lambda t: [2, 1, 1], '2_norm'), - ('renorm', small_3d, lambda t: [2, -1, 1], '2_norm_neg_dim'), - ('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm'), - ('repeat', small_2d, lambda t: [2, 2, 2],), - ('size', new_t(1, 2, 3, 4), lambda t: [],), - ('size', new_t(1, 2, 3, 4), lambda t: [1], 'dim'), - ('size', new_t(1, 2, 3, 4), lambda t: [-2], 'neg_dim'), - ('sort', small_3d_unique, lambda t: [], ''), - ('sort', small_3d_unique, lambda t: [1], 'dim'), - ('sort', small_3d_unique, lambda t: [-1], 'neg_dim'), - ('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'), - ('sort', small_3d_unique, lambda t: [-1, True], 'neg_dim_descending'), - ('split', small_3d, lambda t: [2],), - ('split', small_3d, lambda t: [2, 1], 'dim'), - ('split', small_3d, lambda t: [2, -3], 'neg_dim'), - ('squeeze', new_t(1, 2, 1, 4), lambda t: [],), - ('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'), - ('squeeze', new_t(1, 2, 1, 4), lambda t: [-2], 'neg_dim'), - ('t', new_t(1, 2), lambda t: [],), - ('take', new_t(3, 4), lambda t: [long_type(t)([[0], [-2]])], '', types, False), - ('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],), - ('transpose', new_t(1, 2, 3, 4), lambda 
t: [-1, -2], 'neg_dim'), - ('to_list', small_3d, lambda t: [],), - ('topk', small_3d_unique, lambda t: [2, 1, False, True], 'dim_sort',), - ('topk', small_3d_unique, lambda t: [2, -1, False, True], 'neg_dim_sort',), - ('topk', small_3d_unique, lambda t: [2, 1, True, True], 'dim_desc_sort',), - ('trace', medium_2d, lambda t: []), - ('tril', medium_2d, lambda t: [],), - ('tril', medium_2d_expanded, lambda t: [], 'zero_stride', types, True), - ('tril', medium_2d, lambda t: [2], 'positive'), - ('tril', medium_2d, lambda t: [-2], 'negative'), - ('triu', medium_2d, lambda t: [],), - ('triu', medium_2d_expanded, lambda t: [], 'zero_stride', types, True), - ('triu', medium_2d, lambda t: [2], 'positive'), - ('triu', medium_2d, lambda t: [-2], 'negative'), - ('unsqueeze', new_t(2, 3, 4), lambda t: [2],), - ('unsqueeze', new_t(2, 3, 4), lambda t: [-2], 'neg_dim'), - ('view', small_3d, lambda t: [100, 10], 'contiguous'), - ('view_as', small_3d, lambda t: [make_tensor(t, 100, 10)],), - ('zero', small_3d, lambda t: [],), - ('zeros', small_3d, lambda t: [1, 2, 3, 4],), - ('eye', small_2d, lambda t: [3, 4],), - ('flip', small_3d, lambda t: [0], 'd0', types, True), - ('flip', small_3d, lambda t: [0, 1, 2], 'd012', types, True), - ('flip', small_3d, lambda t: [0, 2], 'd02', types, True), - ('flip', small_3d, lambda t: [2, 0], 'd20', types, True), - ('flip', small_3d, lambda t: [-1], 'neg_d', types, True), - ('rot90', small_2d, lambda t: [1, [0, 1]], 'k1_d01', types, True), - ('rot90', small_3d, lambda t: [1, [1, 2]], 'k1_d12', types, True), - ('rot90', small_3d, lambda t: [1, [1, -1]], 'k1_neg_d', types, True), - ('rot90', small_3d, lambda t: [], 'default', types, True), - ('rsqrt', lambda t: constant_tensor_add(1, small_3d(t)), lambda t: [], None, float_types), - ('sinh', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), - ('tan', lambda t: tensor_clamp(small_3d(t), -1, 1), lambda t: [], None, float_types), - ('__lshift__', lambda t: torch.pow(2, cast_tensor(torch.arange(1, 5), t)), - lambda t: [2], None, signed_types), - ('__rshift__', lambda t: torch.pow(2, cast_tensor(torch.arange(3, 7), t)), - lambda t: [2], None, signed_types), - # lapack tests - ('qr', small_2d_lapack, lambda t: [], 'square', float_types, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('qr', small_2d_lapack_skinny, lambda t: [], 'skinny', float_types, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('qr', large_2d_lapack, lambda t: [], 'big', float_types, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('geqrf', new_t(20, 20), lambda t: [], None, float_types, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('svd', new_t(10, 10), lambda t: [], 'square', float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('svd', lambda t: new_t(10, 10)(t).t(), lambda t: [True], 'square_col_maj', - float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('svd', new_t(20, 5), lambda t: [True], 'tall_some', float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('svd', new_t(20, 5), lambda t: [False], 'tall_all', float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('svd', lambda t: new_t(5, 20)(t).t(), lambda t: [True], - 
'tall_some_col_maj', float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('svd', lambda t: new_t(5, 20)(t).t(), lambda t: [False], - 'tall_all_col_maj', float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), - ('eig', new_t(10, 10), lambda t: [True], 'with_eigvec', float_types_no_half, False, - unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")), -] - -# TODO: random functions, cat, gather, scatter, index*, masked*, -# resize, resizeAs, storage_offset, storage, stride, unfold - -custom_precision = { - 'addbmm': 1e-4, - 'addmm': 1e-4, - 'addmv': 1e-4, - 'addr': 1e-4, - 'baddbmm': 1e-4, - 'rsqrt': 1e-4, - 'cumprod': 1e-4, - 'qr': 3e-4, - 'digamma': 1e0, # large values lead to large absolute error but small relative error -} - -custom_half_precision = { - 'add': 1e-2, - 'acos': 1e-3, - 'addbmm': 1e-1, - 'addcdiv': 1e-2, - 'addcmul': 1e-2, - 'addmm': 1e-1, - 'addmv': 1e-2, - 'addr': 1e-2, - 'asin': 1e-3, - 'atan2': 1e-3, - 'atan': 1e-3, - 'baddbmm': 1e-2, - 'cos': 1e-3, - 'cosh': 1e-2, - 'cross': 1e-2, - 'cumprod': 1e-2, - 'cumsum': 1e-2, - 'dist': 1e-2, - 'div': 1e-3, - 'dot': 1e-2, - 'erf': 1e-3, - 'erfc': 1e-3, - 'exp': 1e-2, - 'expm1': 1e-2, - 'fill': 1e-3, - 'lerp': 1e-2, - 'lgamma': 1e-2, - 'log': 1e-2, - 'log10': 1e-2, - 'log1p': 1e-3, - 'log2': 1e-2, - 'mean': 1e-3, - 'mul': 1e-2, - 'norm': 1e-1, - 'pow': 1e-1, - 'prod': 1e-3, - 'reciprocal': 1e-1, - 'remainder': 1e-3, - 'renorm': 1e-3, - 'rsqrt': 1e-2, - 'sigmoid': 1e-3, - 'sin': 1e-3, - 'sinh': 1e-3, - 'sqrt': 1e-3, - 'std': 1e-3, - 'sub': 1e-2, - 'sum': 1e-2, - 'tan': 1e-3, - 'tanh': 1e-3, - 'trace': 1e-3, - 'var': 1e-3, - '__lshift__': 1e-3, - '__rshift__': 1e-3, -} - -simple_pointwise = [ - 'abs', - 'sign', -] -for fn in simple_pointwise: - tests.append((fn, small_3d, lambda t: [])) - -simple_pointwise_float = [ - 'log', - 'log10', - 'log1p', - 'log2', - 'sigmoid', - 'sin', - 'sqrt', - 'tanh', - 'acos', - 'asin', - 'atan', - 'cos', - 'cosh', - 'erf', - 'erfc', - 'exp', - 'expm1', - 'reciprocal', - 'floor', - 'frac', - 'neg', - 'round', - 'trunc', - 'ceil', - 'lgamma', - 'digamma', - 'trigamma', -] - -for fn in simple_pointwise_float: - tests.append((fn, small_3d, lambda t: [], None, float_types)) - _cycles_per_ms = None - def get_cycles_per_ms(): """Approximate number of cycles per millisecond for torch.cuda._sleep""" global _cycles_per_ms @@ -623,43 +86,6 @@ def get_cycles_per_ms(): return _cycles_per_ms -def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5): - def tmp(self): - cpu_tensor = tensor_constructor(t) - gpu_tensor = to_gpu(cpu_tensor) - cpu_args = arg_constructor(t) - gpu_args = [to_gpu(arg) for arg in cpu_args] - if is_half(t): - cpu_tensor = cpu_tensor.float() - cpu_args = [arg.float() if isinstance(arg, torch.Tensor) and is_half(arg) else arg for arg in cpu_args] - cpu_result = getattr(cpu_tensor, fn)(*cpu_args) - try: - gpu_result = getattr(gpu_tensor, fn)(*gpu_args) - except RuntimeError as e: - reason = e.args[0] - data_type_reasons = {'only supports floating-point types', - 'unimplemented data type', - 'not implemented for'} - if any(data_type_reason in reason for data_type_reason in data_type_reasons): - raise unittest.SkipTest('unimplemented data type') - raise - except AttributeError as e: - reason = e.args[0] - if 'object has no attribute' in reason: - raise unittest.SkipTest('unimplemented data type') - raise - # If one changes, another should change as well - self.assertEqual(cpu_tensor, 
gpu_tensor, precision) - self.assertEqual(cpu_args, gpu_args, precision) - # Compare results - if fn == 'element_size' and t.__name__ == 'HalfTensor': - # Workaround since cpu_result is float - self.assertEqual(2, gpu_result) - else: - self.assertEqual(cpu_result, gpu_result, precision) - return tmp - - class TestCuda(TestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True @@ -1007,9 +433,9 @@ def test_serialization_array_with_storage(self): self.assertEqual(q_copy, q, 0) q_copy[0].fill_(5) self.assertEqual(q_copy[0], q_copy[2], 0) - self.assertTrue(isinstance(q_copy[0], torch.cuda.DoubleTensor)) + self.assertTrue(isinstance(q_copy[0], torch.cuda.FloatTensor)) self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor)) - self.assertTrue(isinstance(q_copy[2], torch.cuda.DoubleTensor)) + self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor)) self.assertTrue(isinstance(q_copy[3], torch.cuda.IntStorage)) q_copy[1].fill_(10) self.assertTrue(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) @@ -1017,14 +443,14 @@ def test_serialization_array_with_storage(self): def test_type_conversions(self): x = torch.randn(5, 5) self.assertIsInstance(x.float(), torch.FloatTensor) - self.assertIsInstance(x.cuda(), torch.cuda.DoubleTensor) + self.assertIsInstance(x.cuda().double(), torch.cuda.DoubleTensor) self.assertIsInstance(x.cuda().float(), torch.cuda.FloatTensor) self.assertIsInstance(x.cuda().float().cpu(), torch.FloatTensor) self.assertIsInstance(x.cuda().float().cpu().int(), torch.IntTensor) y = x.storage() self.assertIsInstance(y.float(), torch.FloatStorage) - self.assertIsInstance(y.cuda(), torch.cuda.DoubleStorage) + self.assertIsInstance(y.cuda().double(), torch.cuda.DoubleStorage) self.assertIsInstance(y.cuda().float(), torch.cuda.FloatStorage) self.assertIsInstance(y.cuda().float().cpu(), torch.FloatStorage) self.assertIsInstance(y.cuda().float().cpu().int(), torch.IntStorage) @@ -2147,6 +1573,7 @@ def _test_multinomial_invalid_probs_cuda(probs): except RuntimeError as e: return e + @slowTest @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ don't support multiprocessing with spawn start method") @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows') @@ -2214,6 +1641,7 @@ def test_nvtx(self): torch.cuda.nvtx.mark("bar") torch.cuda.nvtx.range_pop() + @default_floating_dtype(torch.double) def test_bincount_ext(self): # ensure CUDA code coverage input_size = (5000,) @@ -2406,94 +1834,5 @@ def worker(rank): """]) -def load_ignore_file(): - from os.path import join, dirname - global ignores - path = join(dirname(__file__), 'data', 'test_cuda_ignores.txt') - with open(path, 'r') as f: - ignores = {l for l in f.read().splitlines() if not l.startswith('#')} - - -def generate_tests(): - for decl in tests: - for t in types: - tensor = t() - - # Default values - desc = '' - type_subset = types - no_inplace = False - decorator = None - if len(decl) == 3: - name, constr, arg_constr = decl - elif len(decl) == 4: - name, constr, arg_constr, desc = decl - elif len(decl) == 5: - name, constr, arg_constr, desc, type_subset = decl - elif len(decl) == 6: - name, constr, arg_constr, desc, type_subset, no_inplace = decl - elif len(decl) == 7: - name, constr, arg_constr, desc, type_subset, no_inplace, decorator = decl - - if t not in type_subset: - continue - if TEST_WITH_ROCM and decorator is not None: - if (isinstance(decorator, str)): - tensor_type_name = str(t.__name__) - decorator_list = decorator.split(":") - skip_type_list = 
decorator_list[1].split(",") - if (("ByteTensor" in skip_type_list) and tensor_type_name == "ByteTensor") \ - or (("CharTensor" in skip_type_list) and tensor_type_name == "CharTensor") \ - or (("DoubleTensor" in skip_type_list) and tensor_type_name == "DoubleTensor") \ - or (("FloatTensor" in skip_type_list) and tensor_type_name == "FloatTensor") \ - or (("HalfTensor" in skip_type_list) and tensor_type_name == "HalfTensor") \ - or (("IntTensor" in skip_type_list) and tensor_type_name == "IntTensor") \ - or (("LongTensor" in skip_type_list) and tensor_type_name == "LongTensor") \ - or (("ShortTensor" in skip_type_list) and tensor_type_name == "ShortTensor"): - decorator = skipIfRocm - else: - decorator = None - elif ((not TEST_WITH_ROCM) and (decorator is not None)): - if (isinstance(decorator, str)): - decorator = None - - precision = custom_precision.get(name, TestCuda.precision) - if is_half(t): - precision = custom_half_precision.get(name, precision) - - for inplace in (True, False): - if inplace and no_inplace: - continue - if inplace: - name_inner = name + '_' - else: - name_inner = name - - if t != torch.HalfTensor and not hasattr(tensor, name_inner): - # torch.HalfTensor doesn't support most operations, - # but we use torch.FloatTensor as cpu baseline - continue - full_name = '{}.{}'.format(tensor.type(), name_inner) - if full_name in ignores: - continue - - test_name = 'test_' + t.__name__ + '_' + name_inner - if desc: - test_name += '_' + desc - - assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name - - test_fn = compare_cpu_gpu(constr, arg_constr, name_inner, t, precision) - - if decorator is not None: - test_fn = decorator(test_fn) - - setattr(TestCuda, test_name, test_fn) - - if __name__ == '__main__': - if TEST_CUDA: - load_ignore_file() - generate_tests() - run_tests() diff --git a/test/test_dataloader.py b/test/test_dataloader.py index d450e6005dc85..ca8c9021bb902 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -17,7 +17,7 @@ from torch._utils import ExceptionWrapper from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, PY3, IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, - load_tests, TEST_WITH_TSAN) + load_tests, TEST_WITH_TSAN, default_floating_dtype) try: import psutil @@ -1485,6 +1485,7 @@ def check_len(dl, expected): check_len(DataLoader(self.dataset, batch_size=3), 34) @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") + @default_floating_dtype(torch.double) def test_numpy_scalars(self): import numpy as np diff --git a/test/test_dist_autograd_fork.py b/test/test_dist_autograd_fork.py index f64c13ec2e96e..e57b4c3803b57 100644 --- a/test/test_dist_autograd_fork.py +++ b/test/test_dist_autograd_fork.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 from __future__ import absolute_import, division, print_function, unicode_literals +import torch + +# dist_autograd_fork tests use double as the default dtype +torch.set_default_dtype(torch.double) + from dist_autograd_test import DistAutogradTest from common_distributed import MultiProcessTestCase from common_utils import run_tests diff --git a/test/test_dist_autograd_spawn.py b/test/test_dist_autograd_spawn.py index 409e1ac08ba7a..e8d8af119aaf0 100644 --- a/test/test_dist_autograd_spawn.py +++ b/test/test_dist_autograd_spawn.py @@ -3,11 +3,11 @@ from dist_autograd_test import DistAutogradTest from common_distributed import MultiProcessTestCase -from common_utils import run_tests +from common_utils import TEST_WITH_ASAN, run_tests import unittest 
-@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/27157") +@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues") class DistAutogradTestWithSpawn(MultiProcessTestCase, DistAutogradTest): def setUp(self): diff --git a/test/test_distributions.py b/test/test_distributions.py index 09073e970090b..8b3265a2b7f2e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -30,6 +30,11 @@ from random import shuffle import torch + +# TODO: remove this global setting +# Distributions tests use double as the default dtype +torch.set_default_dtype(torch.double) + from torch._six import inf from common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests from common_cuda import TEST_CUDA diff --git a/test/test_jit.py b/test/test_jit.py index a9903e5c37ab4..a21e17ced3716 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1,5 +1,10 @@ # -*- coding: UTF-8 -*- from __future__ import division +import torch + +# TODO: remove this global setting +# JIT tests use double as the default dtype +torch.set_default_dtype(torch.double) # Torch from torch import Tensor @@ -10,7 +15,6 @@ from torch.jit.frontend import NotSupportedError from torch.onnx import OperatorExportTypes from torch.testing import FileCheck -import torch import torch.cuda import torch.jit import torch.jit._logging @@ -26,7 +30,8 @@ # Testing utils from common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \ - freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName + freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName, skipIfCompiledWithoutNumpy, \ + default_floating_dtype from jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ _trace, enable_cpu_fuser_if, enable_profiling_mode, do_input_map, \ execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ @@ -89,24 +94,6 @@ PY35 = sys.version_info >= (3, 5) -def default_tensor_type(type): - type_str = torch.typename(type) - - def decorator(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - old_type = torch.Tensor().type() - torch.set_default_tensor_type(type_str) - try: - return fn(*args, **kwargs) - finally: - torch.set_default_tensor_type(old_type) - - return wrapper - - return decorator - - def LSTMCellF(input, hx, cx, *params): return LSTMCell(input, (hx, cx), *params) @@ -1793,6 +1780,7 @@ def to_tensor(x, y): x, y = torch.randn(2, 2), torch.randn(1, 10) self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y)) + @skipIfCompiledWithoutNumpy def test_trace_warn(self): def fn(x): int(x) # Warning 1. 
@@ -3442,6 +3430,46 @@ def forward(self, x): mod = torch.jit.script(MyMod()) FileCheck().check_dag("NamedTuple").check_dag("Exception").run(mod.forward.graph) + def test_eval_python(self): + def _test(m): + self.assertTrue(m(torch.ones(2, 2))) + self.assertTrue(m.training) + self.assertTrue(m._c._get_attribute('training')) + + m.eval() + + self.assertFalse(m.training) + self.assertFalse(m._c._get_attribute('training')) + self.assertFalse(m(torch.ones(2, 2))) + + if not PY2: + buffer = io.BytesIO() + torch.jit.save(m, buffer) + buffer.seek(0) + + loaded = torch.jit.load(buffer) + + self.assertFalse(loaded.training) + self.assertFalse(loaded._c._get_attribute('training')) + + class M(nn.Module): + def __init__(self): + super(M, self).__init__() + + def forward(self, x): + return self.training + + class OldM(torch.jit.ScriptModule): + def __init__(self): + super(OldM, self).__init__() + + @torch.jit.script_method + def forward(self, x): + return self.training + + _test(torch.jit.script(M())) + _test(OldM()) + def test_inherit_method(self): class A(torch.jit.ScriptModule): def __init__(self): @@ -3954,7 +3982,7 @@ def forward(self, xyz): bytesio = io.BytesIO(buffer) scripted = torch.jit.load(bytesio) - fc = FileCheck().check(':6:11') + fc = FileCheck().check(':7:11') fc.run(scripted.graph) fc.run(str(scripted.graph)) @@ -4002,7 +4030,7 @@ def forward(self): _, lineno = inspect.getsourcelines(FooTest2) - with self.assertRaisesRegex(torch._C.JITException, 'test_jit.py:{}'.format(lineno + 3)): + with self.assertRaisesRegex(torch.jit.Error, 'test_jit.py:{}'.format(lineno + 3)): ft = FooTest2() loaded = self.getExportImportCopy(ft) loaded() @@ -4597,7 +4625,7 @@ def test_method_on_number(self): def func(): c = 1 return c.add(1) - with self.assertRaisesRegex(RuntimeError, 'Cannot call methods on numbers'): + with self.assertRaisesRegex(RuntimeError, 'nonexistent attribute or method'): torch.jit.script(func) # testing implicit conversion of tensors to scalars to match function arguments @@ -5006,18 +5034,11 @@ def def_in_one_branch(x, z): a = torch.rand(2, 3) with enable_profiling_mode(): - # the first calls are profiled - def_in_one_branch(a, False) - # check prim::profile are inserted - profiled_graph_str = str(def_in_one_branch.graph_for(a, True)) + # the first call is profiled + profiled_graph_str = str(def_in_one_branch.graph_for(a, False)) FileCheck().check_count("prim::profile", 4).run(profiled_graph_str) + # the second call is optimized def_in_one_branch(a, False) - def_in_one_branch(a, False) - # this call is optimized for - # the given shape of (2, 3) - def_in_one_branch(a, False) - # change shape to (3) - # so we go down a bailout path a = torch.ones(3) # check prim::BailOuts are inserted bailout_graph_str = str(def_in_one_branch.graph_for(a, True)) @@ -7854,7 +7875,7 @@ def test_binary_op_shape(self): self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0) self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3) - @default_tensor_type(torch.FloatTensor) + @default_floating_dtype(torch.float) def test_wrapped_number(self): # Scalar's get converted to 'wrapped' tensors of default tensor type. 
# Wrapped tensors behave differently in certain promotion operations: @@ -8322,11 +8343,12 @@ def __init__(self): @torch.jit.export def __getstate__(self): - return (3,) + return (3, self.training) @torch.jit.export def __setstate__(self, state): self.a = state[0] + self.training = state[1] def forward(self, x): return x + self.a @@ -8352,11 +8374,12 @@ def __init__(self): @torch.jit.export def __getstate__(self): - return (3,) + return (3, self.training) @torch.jit.export def __setstate__(self, state): self.a = state[0] + self.training = state[1] def forward(self, x): return x + self.a @@ -8819,6 +8842,7 @@ def forward(self, rep): m = M2() m(torch.zeros(4, 3)) + @skipIfCompiledWithoutNumpy def test_pack_padded_pad_packed_trace(self): from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence T, B, C = 3, 5, 7 @@ -8939,6 +8963,7 @@ def foo(a): v = torch.rand(10, 3) self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v)) + @skipIfCompiledWithoutNumpy def test_rnn_trace_override(self): from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence num_layers = 3 @@ -10672,7 +10697,7 @@ def forward(self, x): ReassignSelfRHS() def test_unknown_builtin(self): - with self.assertRaisesRegex(RuntimeError, 'Unknown builtin op'): + with self.assertRaisesRegex(RuntimeError, 'nonexistent attribute or method'): @torch.jit.script def unknown_builtin(x): return x.splork(3) @@ -11970,12 +11995,12 @@ def f(x): self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True) - with self.assertRaisesRegex(RuntimeError, "Unknown attribute to named tuple"): + with self.assertRaisesRegex(RuntimeError, "nonexistent attribute"): @torch.jit.script def g1(x): return x.max(dim=1).unknown_symbol - with self.assertRaisesRegex(RuntimeError, "Getting attributes of tuples is not supported"): + with self.assertRaisesRegex(RuntimeError, "nonexistent attribute"): @torch.jit.script def g2(x): print((x, x, x).__doc__) @@ -13480,12 +13505,13 @@ def __init__(self, number): @torch.jit.script_method def __getstate__(self): - return (self.buffer1, self.buffer2, 74) + return (self.buffer1, self.buffer2, 74, self.training) @torch.jit.script_method def __setstate__(self, state): self.buffer1 = state[0] + 10 self.buffer2 = state[1] + 10 + self.training = state[3] class M(torch.jit.ScriptModule): @@ -13500,13 +13526,14 @@ def __init__(self, number, submodule): @torch.jit.script_method def __getstate__(self): - return (self.buffer1, self.buffer2, 74, self.submodule) + return (self.buffer1, self.buffer2, 74, self.submodule, self.training) @torch.jit.script_method def __setstate__(self, state): self.buffer1 = state[0] + 10 self.buffer2 = state[1] + 10 self.submodule = state[3] + self.training = state[4] with TemporaryFileName() as fname: m = M(23, submodule=Root(99)) @@ -13537,12 +13564,13 @@ def forward(self): @torch.jit.export def __getstate__(self): - return 5 + return (5, self.training) @torch.jit.export def __setstate__(self, state): - self.buffer1 = torch.ones(2, 2) + state + self.buffer1 = torch.ones(2, 2) + state[0] self.buffer2 = torch.ones(2, 2) + 10 + self.training = state[1] with TemporaryFileName() as fname: m = torch.jit.script(NoArgState()) @@ -14636,7 +14664,7 @@ def forward(self, x, use_ignore_path): buffer.seek(0) loaded = torch.jit.load(buffer) - with self.assertRaisesRegex(torch._C.JITException, "annotated to be ignored and cannot be run"): + with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): loaded(torch.tensor(.5), True) def test_module_error(self): @@ 
-15168,12 +15196,13 @@ def __init__(self): @torch.jit.export def __getstate__(self): - return (self.tensor,) + return (self.tensor, self.training) @torch.jit.export def __setstate__(self, state): - # type: (Tuple[Tensor]) + # type: (Tuple[Tensor, bool]) self.tensor = state[0] + self.training = state[1] def forward(self, x): return x + self.tensor @@ -15253,7 +15282,7 @@ def __init__(self, in_features, out_features): @torch.jit.export def __getstate__(self): - return torch.ops.quantized.linear_unpack(self._packed_weight)[0] + return (torch.ops.quantized.linear_unpack(self._packed_weight)[0], self.training) def forward(self): return self._packed_weight @@ -15261,7 +15290,8 @@ def forward(self): @torch.jit.export def __setstate__(self, state): self._packed_weight.set_( - torch.ops.quantized.linear_prepack(state)) + torch.ops.quantized.linear_prepack(state[0])) + self.training = state[1] @property def weight(self): @@ -19265,10 +19295,10 @@ def set_non_initialized(self, y): self.bar = y # can't assign to non-initialized attr def test_schema_human_readable(self): - """ + """ Make sure that the schema is human readable, ie the mode parameter should read "nearest" instead of being displayed in octal - aten::__interpolate(Tensor input, int? size=None, float[]? scale_factor=None, - str mode='\156\145\141\162\145\163\164', bool? align_corners=None) -> (Tensor): + aten::__interpolate(Tensor input, int? size=None, float[]? scale_factor=None, + str mode='\156\145\141\162\145\163\164', bool? align_corners=None) -> (Tensor): Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'. """ with self.assertRaisesRegex(RuntimeError, "nearest"): @@ -19723,8 +19753,46 @@ def wrong4(x): # type: (OneTwoWrong) -> int return as_interface(x) - # TODO test: interface-interface class-interface inheritance errors, - # NamedTuple inheritance errors + # Test interface/class python assignment + class TestPyAssign(nn.Module): + def __init__(self): + super(TestPyAssign, self).__init__() + self.proxy_mod = Foo() + + def forward(self, x): + return self.proxy_mod.two(x) + + TestPyAssign.__annotations__ = {'proxy_mod': OneTwo} + + input = torch.rand(3, 4) + scripted_pyassign_mod = torch.jit.script(TestPyAssign()) + imported_mod = self.getExportImportCopy(scripted_pyassign_mod) + self.assertEqual(scripted_pyassign_mod(input), imported_mod(input)) + + class TestPyAssignError(nn.Module): + def __init__(self, obj): + super(TestPyAssignError, self).__init__() + self.proxy_mod = obj + + def forward(self, x): + return self.proxy_mod.two(x) + + TestPyAssignError.__annotations__ = {'proxy_mod': OneTwoThree} + + with self.assertRaisesRegex(RuntimeError, + "is not compatible with interface __torch__"): + torch.jit.script(TestPyAssignError(Foo())) + + # test pure python object assignment to interface fails + class PyClass(object): + def __init__(self): + pass + + with self.assertRaisesRegex(RuntimeError, + "the value is not a TorchScript compatible type"): + torch.jit.script(TestPyAssignError(PyClass())) + # TODO test: interface-interface class-interface inheritance errors, + # NamedTuple inheritance errors def test_overloaded_fn(self): @torch.jit.script @@ -19876,7 +19944,7 @@ def call(): # noqa: E306 for func in ops: self.checkScript(func, ()) - with self.assertRaisesRegex(RuntimeError, "nonexistent attribute __add__. 
Did you forget to initialize it"): + with self.assertRaisesRegex(RuntimeError, "nonexistent attribute"): @torch.jit.script def test(): return Foo(torch.tensor(1)) + Foo(torch.tensor(1)) diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index b7ab2b89bdfd4..3911f601bcac4 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -13,7 +13,7 @@ import torch.utils.hooks from torch.nn import Parameter from common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN, - load_tests, slowTest) + load_tests, slowTest, TEST_WITH_TSAN) # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -221,6 +221,7 @@ def _has_shm_files(self): return False +@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment") class TestMultiprocessing(TestCase): def tearDown(self): diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 4377bdbd8f92f..015ab594fa08a 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1,5 +1,5 @@ import unittest -from common_utils import TestCase, run_tests, TEST_NUMPY +from common_utils import TestCase, run_tests, TEST_NUMPY, default_floating_dtype from common_cuda import TEST_CUDA from collections import namedtuple, OrderedDict import itertools @@ -275,6 +275,7 @@ def test_no_multiprocessing_support(self): with self.assertRaisesRegex(RuntimeError, "NYI"): ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor) + @default_floating_dtype(torch.double) def test_big_tensor_repr(self): def check_repr(named_tensor): unnamed_tensor = named_tensor.rename(None) diff --git a/test/test_nn.py b/test/test_nn.py index e6dd049ddba71..96f7a6ee509d9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -16,6 +16,11 @@ import threading import torch + +# TODO: remove this global setting +# NN tests use double as the default dtype +torch.set_default_dtype(torch.double) + from torch._six import inf, nan import torch.backends.cudnn as cudnn import torch.nn as nn @@ -31,7 +36,7 @@ from torch.nn.parallel._functions import Broadcast from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \ - get_function_arglist, load_tests + get_function_arglist, load_tests, default_floating_dtype from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \ module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \ @@ -187,6 +192,7 @@ def test_to(self): padded, lengths, enforce_sorted=enforce_sorted).cpu() self.assertIs(a, a.to('cpu')) + self.assertIs(a, a.cpu()) self.assertIs(a, a.to('cpu', dtype=torch.int32)) self.assertEqual(a.long(), a.to(torch.int64)) @@ -194,6 +200,7 @@ def test_to(self): for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: b = a.cuda(device=cuda) self.assertIs(b, b.to(cuda)) + self.assertIs(b, b.cuda()) self.assertEqual(a, b.to('cpu')) self.assertEqual(b, a.to(cuda)) self.assertEqual(a, b.to('cpu', dtype=torch.int32)) @@ -201,24 +208,6 @@ def test_to(self): self.assertEqual(b.long(), b.to(dtype=torch.int64)) -def default_tensor_type(type): - type_str = torch.typename(type) - - def decorator(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - old_type = torch.Tensor().type() - torch.set_default_tensor_type(type_str) - try: - 
return fn(*args, **kwargs) - finally: - torch.set_default_tensor_type(old_type) - - return wrapper - - return decorator - - def _assertGradAndGradgradChecks(test_case, apply_fn, inputs): # call assert function rather than returning a bool since it's nicer # if we get whether this failed on the gradcheck or the gradgradcheck. @@ -5494,7 +5483,7 @@ def compare_cpu_gpu(outputs_cpu, outputs_gpu): compare_cpu_gpu(outputs_cpu, outputs_gpu) @unittest.skipIf(not TEST_CUDNN, "needs cudnn") - @default_tensor_type(torch.FloatTensor) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented + @default_floating_dtype(torch.float) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented def test_RNN_cpu_vs_cudnn_no_dropout(self): self._test_RNN_cpu_vs_cudnn(0) @@ -5520,7 +5509,7 @@ def test_RNN_cudnn_weight_norm(self): self.assertEqual(m(input), expected_output) @unittest.skipIf(not (TEST_CUDNN and (TEST_CUDNN_VERSION if TEST_CUDNN_VERSION else 0) >= 5103), "needs cudnn >= 5.1") - @default_tensor_type(torch.FloatTensor) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented + @default_floating_dtype(torch.float) # FIXME: just until torch.cuda.DoubleTensor.sum() implemented def test_RNN_cpu_vs_cudnn_with_dropout(self): # Because of dropout randomness, can only compare dropout=0 and dropout=1 self._test_RNN_cpu_vs_cudnn(1) @@ -8987,7 +8976,7 @@ def test_device_mask(self, device): packed = rnn_utils.pack_padded_sequence( padded, lengths, enforce_sorted=enforce_sorted) self.assertFalse(packed.is_cuda) - packed = packed.to(device=device) + packed = packed.to(device) self.assertTrue(packed.is_cuda) unpacked, _ = rnn_utils.pad_packed_sequence(packed) self.assertEqual(unpacked.type(), cuda_type_str) diff --git a/test/test_optim.py b/test/test_optim.py index 1bbd88c152cf8..8f93732fe023b 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -3,7 +3,6 @@ import unittest import functools from copy import deepcopy -from bisect import bisect_right import torch from torch._six import inf import torch.optim as optim @@ -11,10 +10,11 @@ from torch.optim import SGD from torch.autograd import Variable from torch import sparse -from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \ - ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \ - CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR -from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests +from torch.optim.lr_scheduler import LambdaLR, StepLR, \ + MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ + _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR +from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ + skipIfRocm # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. 
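# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this diff): the test_optim.py
# changes that follow migrate the scheduler tests from the deprecated
# `scheduler.step(epoch)` call to the current pattern of calling
# `optimizer.step()` first and `scheduler.step()` with no argument afterwards,
# and they query the rate via the new `get_last_lr()`. The optimizer and
# parameter names below are made up for the example:
import torch
from torch.optim.lr_scheduler import StepLR

params = [torch.nn.Parameter(torch.randn(2, 2))]
opt = torch.optim.SGD(params, lr=0.05)
scheduler = StepLR(opt, step_size=3, gamma=0.1)

for epoch in range(10):
    # ... forward / backward for this epoch ...
    opt.step()                            # optimizer first
    scheduler.step()                      # then the scheduler, with no epoch argument
    current_lr = scheduler.get_last_lr()  # replaces querying get_lr() directly
# ---------------------------------------------------------------------------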
This line silences flake warnings @@ -28,7 +28,7 @@ def rosenbrock(tensor): def drosenbrock(tensor): x, y = tensor - return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2))) + return torch.Tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2))) class TestOptim(TestCase): @@ -62,12 +62,12 @@ def eval(params, sparse_grad, w): if w: i = torch.LongTensor([[0, 0]]) x = grad[0] - v = torch.DoubleTensor([x / 4., x - x / 4.]) + v = torch.Tensor([x / 4., x - x / 4.]) else: i = torch.LongTensor([[1, 1]]) y = grad[1] - v = torch.DoubleTensor([y - y / 4., y / 4.]) - x = sparse.DoubleTensor(i, v, torch.Size([2])) + v = torch.Tensor([y - y / 4., y / 4.]) + x = sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype) with torch.no_grad(): if sparse_grad: params.grad = x @@ -343,6 +343,8 @@ def test_sparse_adam(self): with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0)) + # ROCm precision is too low to pass this test + @skipIfRocm def test_adadelta(self): self._test_basic_cases( lambda weight, bias: optim.Adadelta([weight, bias]) @@ -499,36 +501,6 @@ def __eq__(self, other): return False -class LegacyStepLR(StepLR): - def get_lr(self): - return [base_lr * self.gamma ** (self.last_epoch // self.step_size) - for base_lr in self.base_lrs] - - -class LegacyMultiStepLR(MultiStepLR): - def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): - self.milestones = sorted(milestones) - self.gamma = gamma - super(MultiStepLR, self).__init__(optimizer, last_epoch) - - def get_lr(self): - return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) - for base_lr in self.base_lrs] - - -class LegacyExponentialLR(ExponentialLR): - def get_lr(self): - return [base_lr * self.gamma ** self.last_epoch - for base_lr in self.base_lrs] - - -class LegacyCosineAnnealingLR(CosineAnnealingLR): - def get_lr(self): - return [self.eta_min + (base_lr - self.eta_min) * - (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 - for base_lr in self.base_lrs] - - class TestLRScheduler(TestCase): def setUp(self): super(TestLRScheduler, self).setUp() @@ -571,7 +543,7 @@ def test_old_pattern_warning(self): self.assertTrue(len(ws) == 0, "No warning should be raised") def old_pattern(): - for e in range(epochs): + for _ in range(epochs): scheduler.step() self.opt.step() @@ -585,8 +557,8 @@ def test_old_pattern_warning_with_arg(self): self.assertTrue(len(ws) == 0, "No warning should be raised") def old_pattern2(): - for e in range(epochs): - scheduler.step(e) + for _ in range(epochs): + scheduler.step() self.opt.step() self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate') @@ -602,7 +574,7 @@ def test_old_pattern_warning_resuming(self): self.assertTrue(len(ws) == 0, "No warning should be raised") def old_pattern(): - for e in range(epochs): + for _ in range(epochs): scheduler.step() self.opt.step() @@ -619,13 +591,13 @@ def test_old_pattern_warning_resuming_with_arg(self): self.assertTrue(len(ws) == 0, "No warning should be raised") def old_pattern2(): - for e in range(epochs): - scheduler.step(e) + for _ in range(epochs): + scheduler.step() self.opt.step() self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate') - def test_old_pattern_warning_with_overriden_optim_step(self): + def test_old_pattern_warning_with_overridden_optim_step(self): epochs = 35 for i, group in enumerate(self.opt.param_groups): group['initial_lr'] = 0.01 @@ -635,7 +607,7 @@ def 
test_old_pattern_warning_with_overriden_optim_step(self): scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) self.assertTrue(len(ws) == 0, "No warning should be raised") - # emulate use-case with optimizer.step overriden + # emulate use-case with optimizer.step overridden import types old_step = self.opt.step @@ -647,8 +619,8 @@ def new_step(o, *args, **kwargs): self.opt.step = types.MethodType(new_step, self.opt) def old_pattern2(): - for e in range(epochs): - scheduler.step(e) + for _ in range(epochs): + scheduler.step() self.opt.step() self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate') @@ -662,7 +634,7 @@ def test_new_pattern_no_warning(self): with warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always") # allow any warning to be raised - for e in range(epochs): + for _ in range(epochs): self.opt.step() scheduler.step() self.assertTrue(len(ws) == 0, "No warning should be raised") @@ -676,19 +648,19 @@ def test_new_pattern_no_warning_with_arg(self): with warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always") # allow any warning to be raised - for e in range(epochs): + for _ in range(epochs): self.opt.step() - scheduler.step(e) + scheduler.step() self.assertTrue(len(ws) == 0, "No warning should be raised") - def test_new_pattern_no_warning_with_overriden_optim_step(self): + def test_new_pattern_no_warning_with_overridden_optim_step(self): epochs = 35 with warnings.catch_warnings(record=True) as ws: warnings.simplefilter("always") # allow any warning to be raised scheduler = StepLR(self.opt, gamma=0.1, step_size=3) self.assertTrue(len(ws) == 0, "No warning should be raised") - # emulate use-case with optimizer.step overriden + # emulate use-case with optimizer.step overridden import types old_step = self.opt.step @@ -706,6 +678,23 @@ def new_pattern(): self.assertWarnsRegex(new_pattern, r'`optimizer.step\(\)` has been overridden') + def _test_lr_is_constant_for_constant_epoch(self, scheduler): + l = [] + + # warnings.filterwarnings("ignore", category=DeprecationWarning) + for _ in range(10): + scheduler.step(2) + l.append(self.opt.param_groups[0]['lr']) + self.assertAlmostEqual(min(l), max(l)) + + def test_step_lr_is_constant_for_constant_epoch(self): + scheduler = StepLR(self.opt, 2) + self._test_lr_is_constant_for_constant_epoch(scheduler) + + def test_exponential_lr_is_constant_for_constant_epoch(self): + scheduler = ExponentialLR(self.opt, gamma=0.9) + self._test_lr_is_constant_for_constant_epoch(scheduler) + def test_step_lr(self): # lr = 0.05 if epoch < 3 # lr = 0.005 if 30 <= epoch < 6 @@ -716,6 +705,25 @@ def test_step_lr(self): scheduler = StepLR(self.opt, gamma=0.1, step_size=3) self._test(scheduler, targets, epochs) + def test_get_last_lr_step_lr(self): + from torch.nn import Parameter + epochs = 10 + optimizer = torch.optim.SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1) + targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]] + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1) + self._test_get_last_lr(scheduler, targets, epochs) + + def test_get_last_lr_multi_step_lr(self): + # lr = 0.05 if epoch < 2 + # lr = 0.005 if 2 <= epoch < 5 + # lr = 0.0005 if 5 <= epoch < 9 + # lr = 0.00005 if 9 <= epoch + epochs = 10 + single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1 + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) + self._test_get_last_lr(scheduler, 
targets, epochs) + def test_multi_step_lr(self): # lr = 0.05 if epoch < 2 # lr = 0.005 if 2 <= epoch < 5 @@ -744,28 +752,28 @@ def test_cos_anneal_lr(self): scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) self._test(scheduler, targets, epochs) - def test_legacy_step_lr(self): + def test_closed_form_step_lr(self): scheduler = StepLR(self.opt, gamma=0.1, step_size=3) - legacy_scheduler = LegacyStepLR(self.opt, gamma=0.1, step_size=3) - self._test_against_legacy(scheduler, legacy_scheduler, 20) + closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) + self._test_against_closed_form(scheduler, closed_form_scheduler, 20) - def test_legacy_multi_step_lr(self): + def test_closed_form_multi_step_lr(self): scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) - legacy_scheduler = LegacyMultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) - self._test_against_legacy(scheduler, legacy_scheduler, 20) + closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) + self._test_against_closed_form(scheduler, closed_form_scheduler, 20) - def test_legacy_exp_lr(self): + def test_closed_form_exp_lr(self): scheduler = ExponentialLR(self.opt, gamma=0.9) - legacy_scheduler = LegacyExponentialLR(self.opt, gamma=0.9) - self._test_against_legacy(scheduler, legacy_scheduler, 20) + closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9) + self._test_against_closed_form(scheduler, closed_form_scheduler, 20) - def test_legacy_cos_anneal_lr(self): + def test_closed_form_cos_anneal_lr(self): eta_min = 1e-10 epochs = 20 T_max = 5 scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) - legacy_scheduler = LegacyCosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) - self._test_against_legacy(scheduler, legacy_scheduler, epochs) + closed_form_scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) + self._test_against_closed_form(scheduler, closed_form_scheduler, epochs) def test_reduce_lr_on_plateau1(self): epochs = 10 @@ -847,6 +855,145 @@ def test_reduce_lr_on_plateau8(self): threshold=0.1, patience=5, cooldown=5) self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) + def test_compound_step_and_multistep_lr(self): + epochs = 10 + schedulers = [None] * 2 + schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) + schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) + targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]] + self._test(schedulers, targets, epochs) + + def test_compound_step_and_exp_lr(self): + epochs = 10 + schedulers = [None] * 2 + single_targets = [0.05 * (0.9 ** x) for x in range(3)] + single_targets += [0.005 * (0.9 ** x) for x in range(3, 6)] + single_targets += [0.0005 * (0.9 ** x) for x in range(6, 9)] + single_targets += [0.00005 * (0.9 ** x) for x in range(9, 12)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) + schedulers[1] = ExponentialLR(self.opt, gamma=0.9) + self._test(schedulers, targets, epochs) + + def test_compound_exp_and_multistep_lr(self): + epochs = 10 + schedulers = [None] * 2 + single_targets = [0.05 * (0.9 ** x) for x in range(2)] + single_targets += [0.005 * (0.9 ** x) for x in range(2, 5)] + single_targets += [0.0005 * (0.9 ** x) for x in range(5, 9)] + single_targets += [0.00005 * (0.9 ** x) for x in range(9, 11)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + schedulers[0] = 
MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) + schedulers[1] = ExponentialLR(self.opt, gamma=0.9) + self._test(schedulers, targets, epochs) + + def test_compound_cosanneal_and_step_lr(self): + epochs = 10 + eta_min = 1e-10 + single_targets = [eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs)] + single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + schedulers = [None] * 2 + schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) + schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) + self._test(schedulers, targets, epochs) + + def test_compound_cosanneal_and_multistep_lr(self): + epochs = 10 + eta_min = 1e-10 + single_targets = [eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs)] + multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] + single_targets = [x * y for x, y in zip(single_targets, multipliers)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + schedulers = [None] * 2 + schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) + schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) + self._test(schedulers, targets, epochs) + + def test_compound_cosanneal_and_exp_lr(self): + epochs = 10 + eta_min = 1e-10 + single_targets = [eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs)] + multipliers = [0.1 ** i for i in range(epochs)] + single_targets = [x * y for x, y in zip(single_targets, multipliers)] + targets = [single_targets, list(map(lambda x: x * epochs, single_targets))] + schedulers = [None] * 2 + schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) + schedulers[1] = ExponentialLR(self.opt, gamma=0.1) + self._test(schedulers, targets, epochs) + + def test_compound_reduce_lr_on_plateau1(self): + epochs = 10 + for param_group in self.opt.param_groups: + param_group['lr'] = 0.5 + single_targets = [0.5] * 20 + multipliers = [0.1 ** (i // 3) for i in range(20)] + single_targets = [x * y for x, y in zip(multipliers, single_targets)] + targets = [single_targets] + targets = targets[1:] # test runs step before checking lr + metrics = [10 - i * 0.0167 for i in range(20)] + schedulers = [None, None] + schedulers[0] = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', + threshold=0.01, patience=5, cooldown=5) + schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) + self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + + def test_compound_reduce_lr_on_plateau2(self): + epochs = 22 + for param_group in self.opt.param_groups: + param_group['lr'] = 0.5 + single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 + multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10 + single_targets = [x * y for x, y in zip(single_targets, multipliers)] + targets = [single_targets] + targets = targets[1:] # test runs step before checking lr + metrics = [10 - i * 0.0165 for i in range(22)] + schedulers = [None] * 2 + schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', + mode='min', threshold=0.1) + schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]) + self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + + def test_compound_reduce_lr_on_plateau3(self): + epochs = 22 + for param_group in self.opt.param_groups: + param_group['lr'] 
= 0.5 + single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4 + multipliers = [0.1 ** i for i in range(epochs)] + single_targets = [x * y for x, y in zip(multipliers, single_targets)] + targets = [single_targets] + targets = targets[1:] # test runs step before checking lr + metrics = [-0.8] * 2 + [-0.234] * 20 + schedulers = [None, None] + schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, + threshold_mode='abs') + schedulers[1] = ExponentialLR(self.opt, gamma=0.1) + self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + + def test_compound_reduce_lr_on_plateau4(self): + epochs = 20 + for param_group in self.opt.param_groups: + param_group['lr'] = 0.05 + epochs = 10 + eta_min = 1e-10 + single_targets = [eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs)] + targets = [single_targets] + targets = targets[1:] # test runs step before checking lr + metrics = [1.5 * (1.025 ** i) for i in range(20)] # 1.025 > 1.1**0.25 + schedulers = [None, None] + schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=3, + threshold_mode='rel', threshold=0.1) + schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) + self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + def test_cycle_lr_invalid_mode(self): with self.assertRaises(ValueError): scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") @@ -1220,17 +1367,30 @@ def _check_scheduler_state_dict(self, constr, constr2, epochs=10): for key in scheduler.__dict__.keys(): if key != 'optimizer': self.assertAlmostEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) - self.assertAlmostEqual(scheduler.get_lr(), scheduler_copy.get_lr()) + self.assertAlmostEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) + + def _test_get_last_lr(self, schedulers, targets, epochs=10): + if isinstance(schedulers, _LRScheduler): + schedulers = [schedulers] + for epoch in range(epochs): + result = [scheduler.get_last_lr() for scheduler in schedulers] + [scheduler.step() for scheduler in schedulers] + target = [[t[epoch] for t in targets]] * len(schedulers) + # print(target) + for t, r in zip(target, result): + self.assertAlmostEqual(target, result, + msg='LR is wrong in epoch {}: expected {}, got {}'.format( + epoch, t, r), delta=1e-5) def _test(self, schedulers, targets, epochs=10): if isinstance(schedulers, _LRScheduler): schedulers = [schedulers] for epoch in range(epochs): - [scheduler.step(epoch) for scheduler in schedulers] for param_group, target in zip(self.opt.param_groups, targets): self.assertAlmostEqual(target[epoch], param_group['lr'], msg='LR is wrong in epoch {}: expected {}, got {}'.format( epoch, target[epoch], param_group['lr']), delta=1e-5) + [scheduler.step() for scheduler in schedulers] def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): for index, epoch in enumerate(torch.arange(0, epochs, 0.1)): @@ -1249,15 +1409,16 @@ def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epoc msg='LR is wrong in epoch {}: expected {}, got {}'.format( epoch, target[index], param_group['lr']), delta=1e-5) - def _test_against_legacy(self, scheduler, legacy_scheduler, epochs=10): + def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10): self.setUp() targets = [] + # warnings.filterwarnings("ignore", category=DeprecationWarning) for epoch in range(epochs): - legacy_scheduler.step(epoch) + closed_form_scheduler.step(epoch) targets.append([group['lr'] 
for group in self.opt.param_groups]) self.setUp() for epoch in range(epochs): - scheduler.step(epoch) + scheduler.step() for i, param_group in enumerate(self.opt.param_groups): self.assertAlmostEqual(targets[epoch][i], param_group['lr'], msg='LR is wrong in epoch {}: expected {}, got {}'.format( @@ -1271,7 +1432,7 @@ def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, ve if isinstance(scheduler, ReduceLROnPlateau): scheduler.step(metrics[epoch]) else: - scheduler.step(epoch) + scheduler.step() if verbose: print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr'])) for param_group, target in zip(self.opt.param_groups, targets): @@ -1281,7 +1442,6 @@ def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, ve def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False): for batch_num in range(batch_iterations): - scheduler.step(batch_num) if verbose: if 'momentum' in self.opt.param_groups[0].keys(): print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'], @@ -1308,6 +1468,7 @@ def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iteratio momentum_target[batch_num], param_group['momentum'], msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format( batch_num, momentum_target[batch_num], param_group['momentum']), delta=1e-5) + scheduler.step() def test_cosine_then_cyclic(self): # https://github.com/pytorch/pytorch/issues/21965 diff --git a/test/test_qat.py b/test/test_qat.py index 1ea91558ff343..14b7639a8404c 100644 --- a/test/test_qat.py +++ b/test/test_qat.py @@ -5,7 +5,7 @@ import torch from torch.nn import Conv2d, BatchNorm2d, ReLU -from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d +from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d from torch.quantization.QConfig import default_qat_qconfig import torch.backends.mkldnn from common_utils import TestCase, run_tests @@ -99,9 +99,9 @@ def test_conv_bn_relu( ).to(dtype=torch.double) qat_op.apply(torch.quantization.disable_fake_quant) if freeze_bn: - qat_op.apply(torch.nn._intrinsic.qat.freeze_bn_stats) + qat_op.apply(torch.nn.intrinsic.qat.freeze_bn_stats) else: - qat_op.apply(torch.nn._intrinsic.qat.update_bn_stats) + qat_op.apply(torch.nn.intrinsic.qat.update_bn_stats) # align inputs and internal parameters input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True) diff --git a/test/test_quantization.py b/test/test_quantization.py index a72e1c8173915..d996ebd500283 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -3,14 +3,16 @@ import torch import torch.nn as nn import torch.nn.quantized as nnq -import torch.nn._intrinsic as nni -import torch.nn._intrinsic.quantized as nniq -import torch.nn._intrinsic.qat as nniqat +import torch.nn.intrinsic as nni +import torch.nn.intrinsic.quantized as nniq +import torch.nn.intrinsic.qat as nniqat from torch.quantization import \ QConfig, QConfigDynamic, default_observer, default_weight_observer, get_observer_dict,\ quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \ quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \ - default_dynamic_qconfig, HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, RecordingObserver, QuantWrapper + default_dynamic_qconfig, HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver,\ + RecordingObserver, MovingAverageMinMaxObserver, 
MovingAveragePerChannelMinMaxObserver, \ + QuantWrapper from torch.quantization._quantize_script import quantize_script @@ -22,7 +24,8 @@ ModelWithFunctionals, \ test_only_eval_fn, test_only_train_fn, \ prepare_dynamic, convert_dynamic, SingleLayerLinearDynamicModel, \ - TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel + TwoLayerLinearModel, NestedModel, ResNetBase, LSTMDynamicModel, \ + ModelWithNoQconfigPropagation from common_quantization import AnnotatedTwoLayerLinearModel, AnnotatedNestedModel, \ AnnotatedSubNestedModel, AnnotatedCustomConfigNestedModel @@ -294,7 +297,7 @@ def test_resnet_base(self, qconfig): model = convert(model) def checkQuantized(model): - self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d) + self.assertEqual(type(model.module.conv1), nn.intrinsic.quantized.ConvReLU2d) self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) test_only_eval_fn(model, self.img_data) @@ -336,6 +339,11 @@ def checkQuantized(model): quantize_dynamic(model, qconfig_dict, inplace=True) checkQuantized(model) + # Test set qconfig + model = SingleLayerLinearDynamicModel() + quantize_dynamic(model, set([nn.Linear]), inplace=True) + checkQuantized(model) + def test_two_layers(self): r"""TwoLayerLinearModel has two Linear modules but we only quantize the second one `fc2`, and `fc1`is not quantized @@ -359,6 +367,10 @@ def checkQuantized(model): model = quantize_dynamic(TwoLayerLinearModel().eval(), qconfig_dict) checkQuantized(model) + # Test set API + model = quantize_dynamic(TwoLayerLinearModel().eval(), {'fc2'}) + checkQuantized(model) + def test_nested1(self): r"""Test quantization for nested model, top level 'fc3' and 'fc1' of submodule 'sub2', 'sub2.fc2' is not quantized @@ -385,6 +397,9 @@ def checkQuantized(model): model = quantize_dynamic(NestedModel().eval(), qconfig_dict) checkQuantized(model) + model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2.fc1'}) + checkQuantized(model) + def test_nested2(self): r"""Another test case for quantized, we will quantize all submodules of submodule sub2 @@ -412,6 +427,10 @@ def checkQuantized(model): model = quantize_dynamic(NestedModel().eval(), qconfig_dict) checkQuantized(model) + # Test set API + model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2'}) + checkQuantized(model) + def test_nested3(self): r"""More complicated nested test case with child qconfig overrides parent qconfig @@ -443,6 +462,10 @@ def checkQuantized(model): model = quantize_dynamic(NestedModel().eval(), qconfig_dynamic_dict) checkQuantized(model) + # Test set API + model = quantize_dynamic(NestedModel().eval(), {'fc3', 'sub2', 'sub2.fc1'}) + checkQuantized(model) + def test_type_match_rule(self): r"""Test quantization for nested model, top level 'fc3' and 'fc1' of submodule 'sub2', All 'torch.nn.Linear' modules are quantized @@ -911,144 +934,184 @@ class ObserverTest(QuantizationTestCase): @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), reduce_range=st.booleans()) - def test_minmax_observer(self, qdtype, qscheme, reduce_range): + def test_per_tensor_observers(self, qdtype, qscheme, reduce_range): # reduce_range cannot be true for symmetric quantization with uint8 if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric: reduce_range = False - myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) - # Calculate 
Qparams should return with a warning for observers with no data - qparams = myobs.calculate_qparams() - x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) - result = myobs(x) - result = myobs(y) - self.assertEqual(result, y) - self.assertEqual(myobs.min_val, 1.0) - self.assertEqual(myobs.max_val, 8.0) - qparams = myobs.calculate_qparams() - if reduce_range: - if qscheme == torch.per_tensor_symmetric: - ref_scale = 0.062745 * 255 / 127 - ref_zero_point = 0 if qdtype is torch.qint8 else 128 + ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range), + MovingAverageMinMaxObserver(averaging_constant=0.5, + dtype=qdtype, + qscheme=qscheme, + reduce_range=reduce_range)] + for myobs in ObserverList: + # Calculate Qparams should return with a warning for observers with no data + qparams = myobs.calculate_qparams() + if type(myobs) == MinMaxObserver: + x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) else: - ref_scale = 0.0313725 * 255 / 127 - ref_zero_point = -64 if qdtype is torch.qint8 else 0 - else: - if qscheme == torch.per_tensor_symmetric: - ref_scale = 0.062745 - ref_zero_point = 0 if qdtype is torch.qint8 else 128 + # Moving average of min/max for x and y matches that of + # extreme values for x/y used for minmax observer + x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0]) + + result = myobs(x) + result = myobs(y) + self.assertEqual(result, y) + self.assertEqual(myobs.min_val, 1.0) + self.assertEqual(myobs.max_val, 8.0) + qparams = myobs.calculate_qparams() + if reduce_range: + if qscheme == torch.per_tensor_symmetric: + ref_scale = 0.062745 * 255 / 127 + ref_zero_point = 0 if qdtype is torch.qint8 else 128 + else: + ref_scale = 0.0313725 * 255 / 127 + ref_zero_point = -64 if qdtype is torch.qint8 else 0 else: - ref_scale = 0.0313725 - ref_zero_point = -128 if qdtype is torch.qint8 else 0 - self.assertEqual(qparams[1].item(), ref_zero_point) - self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) - - # Test for serializability - state_dict = myobs.state_dict() - b = io.BytesIO() - torch.save(state_dict, b) - b.seek(0) - loaded_dict = torch.load(b) - for key in state_dict: - self.assertEqual(state_dict[key], loaded_dict[key]) - loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) - loaded_obs.load_state_dict(loaded_dict) - loaded_qparams = loaded_obs.calculate_qparams() - self.assertEqual(myobs.min_val, loaded_obs.min_val) - self.assertEqual(myobs.max_val, loaded_obs.max_val) - self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) + if qscheme == torch.per_tensor_symmetric: + ref_scale = 0.062745 + ref_zero_point = 0 if qdtype is torch.qint8 else 128 + else: + ref_scale = 0.0313725 + ref_zero_point = -128 if qdtype is torch.qint8 else 0 + self.assertEqual(qparams[1].item(), ref_zero_point) + self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) + state_dict = myobs.state_dict() + b = io.BytesIO() + torch.save(state_dict, b) + b.seek(0) + loaded_dict = torch.load(b) + for key in state_dict: + self.assertEqual(state_dict[key], loaded_dict[key]) + loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) + loaded_obs.load_state_dict(loaded_dict) + loaded_qparams = loaded_obs.calculate_qparams() + self.assertEqual(myobs.min_val, loaded_obs.min_val) + self.assertEqual(myobs.max_val, 
loaded_obs.max_val) + self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)), ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans()) - def test_per_channel_minmax_observer(self, qdtype, qscheme, ch_axis, reduce_range): + def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range): # reduce_range cannot be true for symmetric quantization with uint8 if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric: reduce_range = False - myobs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme) - # Calculate qparams should work for empty observers - qparams = myobs.calculate_qparams() - x = torch.tensor( - [ - [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]], - [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]], - ] - ) - result = myobs(x) - self.assertEqual(result, x) - qparams = myobs.calculate_qparams() - ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]] - ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]] - per_channel_symmetric_ref_scales = [ - [0.04705882, 0.06274509], - [0.03921569, 0.0627451], - [0.04705882, 0.0627451], - [0.05490196, 0.0627451], - ] - per_channel_affine_ref_scales = [ - [0.02352941, 0.04705882], - [0.03529412, 0.03137255], - [0.03921569, 0.03137255], - [0.04313726, 0.04313726], - ] - per_channel_affine_qint8_zp = [ - [-128, -43], - [-15, -128], - [-26, -128], - [-35, -58], - ] - per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]] - - self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis]) - self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis]) - if qscheme == torch.per_channel_symmetric: - ref_scales = per_channel_symmetric_ref_scales[ch_axis] - ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128] - else: - ref_scales = per_channel_affine_ref_scales[ch_axis] - ref_zero_points = ( - per_channel_affine_qint8_zp[ch_axis] - if qdtype is torch.qint8 - else per_channel_affine_quint8_zp[ch_axis] + ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range, + ch_axis=ch_axis, + dtype=qdtype, + qscheme=qscheme), + MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5, + reduce_range=reduce_range, + ch_axis=ch_axis, + dtype=qdtype, + qscheme=qscheme)] + + for myobs in ObserverList: + # Calculate qparams should work for empty observers + qparams = myobs.calculate_qparams() + x = torch.tensor( + [ + [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]], + [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]], + ] ) + if type(myobs) == MovingAveragePerChannelMinMaxObserver: + # Scaling the input tensor to model change in min/max values + # across batches + result = myobs(0.5 * x) + result = myobs(1.5 * x) + self.assertEqual(result, 1.5 * x) + else: + result = myobs(x) + self.assertEqual(result, x) + + qparams = myobs.calculate_qparams() + ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]] + ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]] + per_channel_symmetric_ref_scales = [ + [0.04705882, 0.06274509], + [0.03921569, 0.0627451], + [0.04705882, 0.0627451], + [0.05490196, 0.0627451], + ] + per_channel_affine_ref_scales = [ + [0.02352941, 0.04705882], + [0.03529412, 0.03137255], + [0.03921569, 0.03137255], + [0.04313726, 0.04313726], + ] + per_channel_affine_qint8_zp = [ + [-128, -43], + [-15, -128], + [-26, -128], + 
[-35, -58], + ] + per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]] - if reduce_range: - ref_scales = [s * 255 / 127 for s in ref_scales] - ref_zero_points = [math.floor(z / 2) for z in ref_zero_points] + self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis]) + self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis]) + if qscheme == torch.per_channel_symmetric: + ref_scales = per_channel_symmetric_ref_scales[ch_axis] + ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128] + else: + ref_scales = per_channel_affine_ref_scales[ch_axis] + ref_zero_points = ( + per_channel_affine_qint8_zp[ch_axis] + if qdtype is torch.qint8 + else per_channel_affine_quint8_zp[ch_axis] + ) + + if reduce_range: + ref_scales = [s * 255 / 127 for s in ref_scales] + ref_zero_points = [math.floor(z / 2) for z in ref_zero_points] + + self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype))) + self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))) + + # Test for serializability + state_dict = myobs.state_dict() + b = io.BytesIO() + torch.save(state_dict, b) + b.seek(0) + loaded_dict = torch.load(b) + for key in state_dict: + self.assertEqual(state_dict[key], loaded_dict[key]) + loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme) + loaded_obs.load_state_dict(loaded_dict) + loaded_qparams = loaded_obs.calculate_qparams() + self.assertEqual(myobs.min_vals, loaded_obs.min_vals) + self.assertEqual(myobs.max_vals, loaded_obs.max_vals) + self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) - self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype))) - self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))) + def test_observer_scriptable(self): + obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver()] + for obs in obs_list: + scripted = torch.jit.script(obs) - # Test for serializability - state_dict = myobs.state_dict() - b = io.BytesIO() - torch.save(state_dict, b) - b.seek(0) - loaded_dict = torch.load(b) - for key in state_dict: - self.assertEqual(state_dict[key], loaded_dict[key]) - loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme) - loaded_obs.load_state_dict(loaded_dict) - loaded_qparams = loaded_obs.calculate_qparams() - self.assertEqual(myobs.min_vals, loaded_obs.min_vals) - self.assertEqual(myobs.max_vals, loaded_obs.max_vals) - self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams()) + x = torch.rand(3, 4) + obs(x) + scripted(x) - def test_observer_scriptable(self): - obs = torch.quantization.default_observer() - scripted = torch.jit.script(obs) + self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams()) - x = torch.rand(3, 4) - obs(x) - scripted(x) + buf = io.BytesIO() + torch.jit.save(scripted, buf) + buf.seek(0) + loaded = torch.jit.load(buf) + self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams()) - self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams()) + def test_no_qconfig_propagation(self): + model = ModelWithNoQconfigPropagation() + model.qconfig = torch.quantization.default_qconfig + + model = prepare(model) + self.assertTrue(hasattr(model.fc1, 'qconfig'), + "QConfig is expected to propagate") + self.assertFalse(hasattr(model.no_quant_module, 'qconfig'), + "QConfig is expected to NOT propagate") - buf = 
io.BytesIO() - torch.jit.save(scripted, buf) - buf.seek(0) - loaded = torch.jit.load(buf) - self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams()) @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" diff --git a/test/test_quantized_models.py b/test/test_quantized_models.py index cf0f41655d550..2fcf2c17f9233 100644 --- a/test/test_quantized_models.py +++ b/test/test_quantized_models.py @@ -75,7 +75,7 @@ def test_fake_quant_true_quant_compare(self): torch.quantization.prepare_qat(fq_model) fq_model.eval() fq_model.apply(torch.quantization.disable_fake_quant) - fq_model.apply(torch.nn._intrinsic.qat.freeze_bn_stats) + fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) fq_model(calib_data) fq_model.apply(torch.quantization.enable_fake_quant) fq_model.apply(torch.quantization.disable_observer) @@ -116,7 +116,7 @@ def test_weight_only_activation_only_fakequant(self): torch.quantization.prepare_qat(fq_model) fq_model.eval() fq_model.apply(torch.quantization.disable_fake_quant) - fq_model.apply(torch.nn._intrinsic.qat.freeze_bn_stats) + fq_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) fq_model(calib_data) fq_model.apply(torch.quantization.enable_fake_quant) fq_model.apply(torch.quantization.disable_observer) diff --git a/test/test_quantized_nn_mods.py b/test/test_quantized_nn_mods.py index e647873afc837..37756a597a834 100644 --- a/test/test_quantized_nn_mods.py +++ b/test/test_quantized_nn_mods.py @@ -1,10 +1,10 @@ import torch import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd -import torch.nn._intrinsic.quantized as nnq_fused +import torch.nn.intrinsic.quantized as nnq_fused import torch.nn.quantized.functional as qF from torch.nn.quantized.modules import Conv2d -from torch.nn._intrinsic.quantized import ConvReLU2d +from torch.nn.intrinsic.quantized import ConvReLU2d import torch.quantization from common_utils import run_tests, IS_PPC, TEST_WITH_UBSAN from common_quantization import QuantizationTestCase, prepare_dynamic @@ -228,7 +228,7 @@ def test_relu(self): qengine=st.sampled_from(("qnnpack", "fbgemm")) ) def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_fused, per_channel, qengine): - """test API functionality for nn.quantized.linear and nn._intrinsic.quantized.linear_relu""" + """test API functionality for nn.quantized.linear and nn.intrinsic.quantized.linear_relu""" if qengine not in torch.backends.quantized.supported_engines: return if qengine == 'qnnpack': @@ -272,8 +272,12 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f # ops directly if use_fused: Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) + + self.assertTrue('QuantizedLinearReLU' in str(qlinear)) else: Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) + + self.assertTrue('QuantizedLinear' in str(qlinear)) self.assertEqual(Z_ref, Z_q) # Test serialization of quantized Linear Module using state_dict @@ -483,6 +487,8 @@ def test_conv_api(self, use_bias, use_fused, per_channel, qengine): groups=g, bias=use_bias, padding_mode='zeros') + + self.assertTrue('QuantizedConvReLU2d' in str(loaded_conv_under_test)) else: loaded_conv_under_test = Conv2d(in_channels=iC, out_channels=oC, @@ -493,6 +499,7 @@ def test_conv_api(self, use_bias, use_fused, per_channel, qengine): groups=g, bias=use_bias, padding_mode='zeros') + self.assertTrue('QuantizedConv2d' in str(loaded_conv_under_test)) 
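# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of this diff): the quantization test
# changes above track the rename of the private torch.nn._intrinsic namespace
# to the public torch.nn.intrinsic. The mapping is mechanical:
#   torch.nn._intrinsic            -> torch.nn.intrinsic
#   torch.nn._intrinsic.quantized  -> torch.nn.intrinsic.quantized
#   torch.nn._intrinsic.qat        -> torch.nn.intrinsic.qat
# so fused-module imports in the tests now read, for example:
import torch.nn.intrinsic.quantized as nniq                  # was torch.nn._intrinsic.quantized
from torch.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d    # was torch.nn._intrinsic.qat
# ---------------------------------------------------------------------------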
loaded_conv_under_test.load_state_dict(loaded_dict) self.assertEqual(loaded_conv_under_test._weight_bias(), conv_under_test._weight_bias()) if use_bias: diff --git a/test/test_rpc_fork.py b/test/test_rpc_fork.py index 5e2432f60183e..8a082b966d338 100644 --- a/test/test_rpc_fork.py +++ b/test/test_rpc_fork.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 from __future__ import absolute_import, division, print_function, unicode_literals +import torch + +# rpc_fork tests use double as the default dtype +torch.set_default_dtype(torch.double) + from rpc_test import RpcTest from common_distributed import MultiProcessTestCase from common_utils import run_tests diff --git a/test/test_sparse.py b/test/test_sparse.py index f7795c68046a9..7eacf9f8d1106 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -3,8 +3,10 @@ import itertools import functools import random +import sys import unittest -from common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, do_test_empty_full, load_tests +from common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \ + do_test_empty_full, load_tests, default_floating_dtype from common_cuda import TEST_CUDA from numbers import Number from torch.autograd.gradcheck import gradcheck @@ -93,6 +95,7 @@ def randn(self, *args, **kwargs): # TODO: Put this in torch.cuda.randn return self.value_empty(*args, **kwargs).normal_() + @default_floating_dtype(torch.double) def test_print(self): shape_sparse_dim_nnz = [ ((), 0, 2), @@ -924,6 +927,7 @@ def test_shape(di, dj, dk, nnz): test_shape(1000, 0, 100, 0) test_shape(1000, 100, 0, 0) + @default_floating_dtype(torch.double) def test_sparse_addmm(self): def test_shape(m, n, p, nnz, broadcast): if broadcast: @@ -943,6 +947,7 @@ def fn(S, D1, D2): test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) + @default_floating_dtype(torch.double) def test_sparse_mm(self): def test_shape(d1, d2, d3, nnz, transposed): if transposed: @@ -962,6 +967,7 @@ def fn(S, D): test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) + @default_floating_dtype(torch.double) def test_dsmm(self): def test_shape(di, dj, dk, nnz): x = self._gen_sparse(2, nnz, [di, dj])[0] @@ -980,6 +986,7 @@ def test_shape(di, dj, dk, nnz): test_shape(1000, 100, 0, 20) @skipIfRocm + @default_floating_dtype(torch.double) def test_hsmm(self): def test_shape(di, dj, dk, nnz): x = self._gen_sparse(2, nnz, [di, dj])[0] @@ -1042,6 +1049,7 @@ def _test_spadd_shape(self, nnz, shape_i, shape_v=None): expected = y + r * self.safeToDense(x_) self.assertEqual(res, expected) + @default_floating_dtype(torch.double) def test_spadd(self): self._test_spadd_shape(10, [5, 6]) self._test_spadd_shape(10, [10, 10, 10]) @@ -1051,6 +1059,7 @@ def test_spadd(self): self._test_spadd_shape(0, [50, 0, 20]) self._test_spadd_shape(0, [50, 30, 0]) + @default_floating_dtype(torch.double) def test_spadd_hybrid(self): self._test_spadd_shape(10, [5, 6], [2, 3]) self._test_spadd_shape(10, [10, 10, 10], [3]) @@ -1072,6 +1081,7 @@ def test_shape(sparse_dims, nnz, with_size): test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0]) @skipIfRocm + @default_floating_dtype(torch.double) def test_sparse_sum(self): def run_tests(S, td=None): @@ -1212,6 +1222,7 @@ def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v=None): # coalesced. 
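# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this diff): the observer
# tests earlier in this patch extend the MinMaxObserver coverage to
# MovingAverageMinMaxObserver / MovingAveragePerChannelMinMaxObserver, which
# smooth the tracked min/max with a moving average instead of keeping the
# absolute extremes. Minimal usage mirroring the constructor arguments used in
# test_quantization.py (the random input here is made up):
import torch
from torch.quantization import MinMaxObserver, MovingAverageMinMaxObserver

minmax = MinMaxObserver(dtype=torch.quint8)
moving = MovingAverageMinMaxObserver(averaging_constant=0.5, dtype=torch.quint8)
for batch in (torch.randn(4, 8), torch.randn(4, 8)):
    minmax(batch)   # records the absolute min/max seen so far
    moving(batch)   # records a moving average of the per-batch min/max
scale, zero_point = moving.calculate_qparams()
# ---------------------------------------------------------------------------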
self.assertEqual(z._values(), y._values()) + @default_floating_dtype(torch.double) def test_basic_ops(self): self._test_basic_ops_shape(9, 12, [5, 6]) self._test_basic_ops_shape(9, 12, [10, 10, 10]) @@ -1222,6 +1233,7 @@ def test_basic_ops(self): self._test_basic_ops_shape(0, 0, [10, 10, 10]) self._test_basic_ops_shape(0, 0, [10, 10, 0]) + @default_floating_dtype(torch.double) def test_basic_ops_hybrid(self): self._test_basic_ops_shape(9, 12, [5, 6], [2, 3]) self._test_basic_ops_shape(9, 12, [10, 10, 10], [3]) @@ -2054,6 +2066,39 @@ def test_change_tensor_metadata(self): self.assertEqual(list(t.coalesce().indices().size()), [2, 1]) self.assertEqual(list(t.coalesce().values().size()), [1, 3]) + def test_pickle(self): + if sys.version_info[0] == 2: + import cPickle as pickle + else: + import pickle + + shape_sparse_dim_nnz = [ + ((), 0, 2), + ((0,), 0, 10), + ((2,), 0, 3), + ((100, 3), 1, 3), + ((100, 20, 3), 2, 0), + ((10, 0, 3), 0, 3), + ((10, 0, 3), 0, 0), + ] + + for shape, sparse_dim, nnz in shape_sparse_dim_nnz: + indices_shape = torch.Size((sparse_dim, nnz)) + values_shape = torch.Size((nnz,) + shape[sparse_dim:]) + indices = torch.arange(indices_shape.numel(), dtype=self.index_tensor(0).dtype, + device=self.device).view(indices_shape) + for d in range(sparse_dim): + indices[d].clamp_(max=(shape[d] - 1)) # make it valid index + if self.is_uncoalesced and indices.numel() > 0: + indices[:, -1] = indices[:, 0] # make it uncoalesced + values_numel = values_shape.numel() + values = torch.arange(values_numel, dtype=self.value_dtype, + device=self.device).view(values_shape).div_(values_numel / 2.) + sp_tensor = self.sparse_tensor(indices, values, shape) + serialized = pickle.dumps(sp_tensor) + sp_tensor_loaded = pickle.loads(serialized) + self.assertEqual(sp_tensor, sp_tensor_loaded) + class TestUncoalescedSparse(TestSparse): def setUp(self): diff --git a/test/test_torch.py b/test/test_torch.py index 0f6dd380354af..be0a16ef04af0 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -31,7 +31,8 @@ TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \ IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \ IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \ - skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf + skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, \ + default_floating_dtype from multiprocessing.reduction import ForkingPickler from common_device_type import instantiate_device_type_tests, \ skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \ @@ -1319,6 +1320,7 @@ def test_cpow(self): @slowTest @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + @default_floating_dtype(torch.double) def test_einsum(self): # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f x = torch.randn(5) @@ -1388,6 +1390,7 @@ def do_einsum(*args): self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) self.assertTrue(A._version == 0) # check that we do not use inplace ops + @default_floating_dtype(torch.double) def test_sum_all(self): def check_sum_all(tensor): pylist = tensor.reshape(-1).tolist() @@ -1487,6 +1490,7 @@ def test_logsumexp_dim(self): lambda n, d: logsumexp(n, d), use_integral=False) + @default_floating_dtype(torch.double) def test_sum_out(self): x = torch.rand(100, 100) res1 = torch.sum(x, 1) @@ -2924,6 +2928,7 @@ def test_ormqr(self): self.assertEqual(res2, out_holder) @skipIfNoLapack + @default_floating_dtype(torch.double) def test_eig(self): a = 
torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00), (-6.49, 3.80, 0.00, 0.00, 0.00), @@ -2967,6 +2972,7 @@ def test_eig(self): self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') @staticmethod + @default_floating_dtype(torch.double) def _test_fft_ifft_rfft_irfft(self, device='cpu'): def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): x = prepro_fn(torch.randn(*sizes, device=device)) @@ -3746,6 +3752,7 @@ def _test_abs_single(data): for num in abs_zeros: self.assertGreater(math.copysign(1.0, num), 0.0) + @default_floating_dtype(torch.double) def test_hardshrink(self): data_original = torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2) float_types = [ @@ -3763,6 +3770,7 @@ def test_hardshrink(self): # test non-contiguous case self.assertEqual(torch.tensor([1, 0, 0.5, 0.6]).view(2, 2), data.t().hardshrink(0.3)) + @default_floating_dtype(torch.double) def test_hardshrink_edge_cases(self): def h(t, values, l_expected): for l, expected in l_expected.items(): @@ -6624,11 +6632,12 @@ def add_neg_dim_tests(): # Device-generic tests. Instantiated below and not run directly. class TestTorchDeviceType(TestCase): - def check_internal_mem_overlap(self, inplace_op, num_inputs, device, + def check_internal_mem_overlap(self, inplace_op, num_inputs, + dtype, device, expected_failure=False): if isinstance(inplace_op, str): inplace_op = getattr(torch.Tensor, inplace_op) - input = torch.randn(1, device=device).expand(3, 3) + input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)] if not expected_failure: @@ -7174,7 +7183,8 @@ def test_inverse_many_batches(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_pinverse(self, device): + @dtypes(torch.double) + def test_pinverse(self, device, dtype): from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank def run_test(M): @@ -7191,15 +7201,15 @@ def run_test(M): (3, 2), (5, 3, 2), (7, 5, 3, 2), # fat matrices (2, 3), (5, 2, 3), (7, 5, 2, 3), # thin matrices (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices - M = torch.randn(*sizes, device=device) + M = torch.randn(*sizes, dtype=dtype, device=device) run_test(M) # Test inverse and pseudo-inverse for invertible matrix for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]: matsize = sizes[-1] batchdims = sizes[:-2] - M = fullrank(matsize, *batchdims).to(device=device) - self.assertEqual(torch.eye(matsize, device=device).expand(sizes), M.pinverse().matmul(M), + M = fullrank(matsize, *batchdims, dtype=dtype, device=device) + self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M), 1e-7, 'pseudo-inverse for invertible matrix') @skipCUDAIfNoMagma @@ -7237,7 +7247,8 @@ def test_matrix_rank(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_matrix_power(self, device): + @dtypes(torch.double) + def test_matrix_power(self, device, dtype): def run_test(M, sign=1): if sign == -1: M = M.inverse() @@ -7257,47 +7268,49 @@ def run_test(M, sign=1): self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M)) # Single matrix - M = torch.randn(5, 5, device=device) + M = torch.randn(5, 5, dtype=dtype, device=device) run_test(M) # Batch matrices - M = torch.randn(3, 3, 3, device=device) + M = torch.randn(3, 3, 3, dtype=dtype, device=device) run_test(M) # Many batch matrices - M = torch.randn(2, 3, 3, 3, device=device) + M = torch.randn(2, 3, 3, 3, dtype=dtype, device=device) run_test(M) # This is for negative powers from common_utils import 
random_fullrank_matrix_distinct_singular_value - M = random_fullrank_matrix_distinct_singular_value(5).to(device) + M = random_fullrank_matrix_distinct_singular_value(5, dtype=dtype, device=device) run_test(M, sign=-1) - M = random_fullrank_matrix_distinct_singular_value(3, 3).to(device) + M = random_fullrank_matrix_distinct_singular_value(3, 3, dtype=dtype, device=device) run_test(M, sign=-1) - M = random_fullrank_matrix_distinct_singular_value(3, 2, 3).to(device) + M = random_fullrank_matrix_distinct_singular_value(3, 2, 3, dtype=dtype, device=device) run_test(M, sign=-1) - def test_chain_matmul(self, device): + @dtypes(torch.double) + def test_chain_matmul(self, device, dtype): def product(matrices): for mat in matrices[1:]: matrices[0] = matrices[0].mm(mat) return matrices[0] - def run_test(p, device): + def run_test(p): matrices = [] for (pi, pi_1) in zip(p[:-1], p[1:]): - matrices.append(torch.randn(pi, pi_1, device=device)) + matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device)) self.assertEqual(torch.chain_matmul(*matrices), product(matrices)) - run_test([10, 20, 30, 5], device) - run_test([15, 5, 10, 20, 25], device) + run_test([10, 20, 30, 5]) + run_test([15, 5, 10, 20, 25]) @slowTest @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_det_logdet_slogdet(self, device): + @dtypes(torch.double) + def test_det_logdet_slogdet(self, device, dtype): def reference_slogdet(M): if TEST_NUMPY: sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy()) @@ -7349,8 +7362,8 @@ def test_single_det(M, target, desc): else: self.assertEqual(logdet.exp(), target_logabsdet.exp(), 1e-7, '{} (logdet non-negative case)'.format(desc)) - eye = torch.eye(5, device=device) - test_single_det(eye, (torch.ones((), device=device), torch.zeros((), device=device)), 'identity') + eye = torch.eye(5, dtype=dtype, device=device) + test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity') def test(M): assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' @@ -7442,23 +7455,23 @@ def get_random_mat_scale(n): for n in [5, 10, 25]: scale = get_random_mat_scale(n) - test(torch.randn(n, n, device=device) * scale) - r = torch.randn(n, n, device=device) * scale + test(torch.randn(n, n, dtype=dtype, device=device) * scale) + r = torch.randn(n, n, dtype=dtype, device=device) * scale # symmetric psd test(r.mm(r.t())) # symmetric pd - r = torch.randn(n, n, device=device) * scale - test(r.mm(r.t()) + torch.eye(n, device=device) * 1e-6) + r = torch.randn(n, n, dtype=dtype, device=device) * scale + test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6) # symmetric - r = torch.randn(n, n, device=device) * scale + r = torch.randn(n, n, dtype=dtype, device=device) * scale for i in range(n): for j in range(i): r[i, j] = r[j, i] test(r) # non-contiguous - test((torch.randn(n, n, n + 1, device=device) * scale)[:, 2, 1:]) + test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:]) # det = 0 - r = torch.randn(n, n, device=device) * scale + r = torch.randn(n, n, dtype=dtype, device=device) * scale u, s, v = r.svd() if reference_slogdet(u)[0] < 0: u = -u @@ -7470,14 +7483,15 @@ def get_random_mat_scale(n): # Small values to test numerical stability. Note that we don't scale # this matrix. - r = torch.randn(512, 512, device=device) + r = torch.randn(512, 512, dtype=dtype, device=device) u, s, v = r.svd() s.fill_(1. 
/ (100 * s.numel())) test(u.mm(s.diag()).mm(v)) @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_det_logdet_slogdet_batched(self, device): + @dtypes(torch.double) + def test_det_logdet_slogdet_batched(self, device, dtype): from common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix, random_symmetric_pd_matrix, random_square_matrix_of_rank) @@ -7490,15 +7504,15 @@ def run_test(matsize, batchdims, mat_chars): for idx in range(num_matrices): mat_type = idx % len(mat_chars) if mat_chars[mat_type] == 'sym': - list_of_matrices.append(random_symmetric_matrix(matsize).to(device=device)) + list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device)) elif mat_chars[mat_type] == 'sym_psd': - list_of_matrices.append(random_symmetric_psd_matrix(matsize).to(device=device)) + list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device)) elif mat_chars[mat_type] == 'sym_pd': - list_of_matrices.append(random_symmetric_pd_matrix(matsize).to(device=device)) + list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device)) elif mat_chars[mat_type] == 'sing': - list_of_matrices.append(torch.ones(matsize, matsize, device=device)) + list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device)) elif mat_chars[mat_type] == 'non_sing': - list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize).to(device=device)) + list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device)) full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize))) @@ -7525,22 +7539,28 @@ def run_test(matsize, batchdims, mat_chars): run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) + def solve_test_helper(self, A_dims, b_dims, device, dtype): + from common_utils import random_fullrank_matrix_distinct_singular_value + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device) + return b, A + @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_solve(self, device): - from common_utils import solve_test_helper + @dtypes(torch.double) + def test_solve(self, device, dtype): for (k, n) in zip([2, 3, 5], [3, 5, 7]): - b, A = solve_test_helper((n,), (n, k), lambda t: t.to(device)) + b, A = self.solve_test_helper((n,), (n, k), device, dtype) x = torch.solve(b, A)[0] self.assertLessEqual(b.dist(A.mm(x)), 1e-12) @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_solve_batched(self, device): - from common_utils import solve_test_helper - - def solve_batch_helper(A_dims, b_dims, device): - b, A = solve_test_helper(A_dims, b_dims, lambda t: t.to(device)) + @dtypes(torch.double) + def test_solve_batched(self, device, dtype): + def solve_batch_helper(A_dims, b_dims): + b, A = self.solve_test_helper(A_dims, b_dims, device, dtype) x_exp_list = [] for i in range(b_dims[0]): x_exp_list.append(torch.solve(b[i], A[i])[0]) @@ -7550,75 +7570,80 @@ def solve_batch_helper(A_dims, b_dims, device): self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check for batchsize in [1, 3, 4]: - solve_batch_helper((5, batchsize), (batchsize, 5, 10), device) + solve_batch_helper((5, batchsize), (batchsize, 5, 10)) @skipCUDAIfNoMagma @skipCPUIfNoLapack 
@unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_solve_batched_non_contiguous(self, device): + @dtypes(torch.double) + def test_solve_batched_non_contiguous(self, device, dtype): from numpy.linalg import solve from common_utils import random_fullrank_matrix_distinct_singular_value - A = random_fullrank_matrix_distinct_singular_value(2, 2).to(device).permute(1, 0, 2) - b = torch.randn(2, 2, 2, device=device).permute(2, 1, 0) + A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, + device=device).permute(1, 0, 2) + b = torch.randn(2, 2, 2, dtype=dtype, device=device).permute(2, 1, 0) x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(device) + x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(dtype=dtype, device=device) self.assertEqual(x.data, x_exp) @slowTest @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_solve_batched_many_batches(self, device): - from common_utils import solve_test_helper - - b, A = solve_test_helper((5, 256, 256), (5, 1), lambda t: t.to(device)) + @dtypes(torch.double) + def test_solve_batched_many_batches(self, device, dtype): + b, A = self.solve_test_helper((5, 256, 256), (5, 1), device, dtype) x, _ = torch.solve(b, A) self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - b, A = solve_test_helper((3,), (512, 512, 3, 1), lambda t: t.to(device)) + b, A = self.solve_test_helper((3,), (512, 512, 3, 1), device, dtype) x, _ = torch.solve(b, A) self.assertEqual(torch.matmul(A, x), b) @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_solve_batched_broadcasting(self, device): + @dtypes(torch.double) + def test_solve_batched_broadcasting(self, device, dtype): from numpy.linalg import solve - from common_utils import solve_test_helper - def cast(t): - return t.to(device) - - def run_test(A_dims, b_dims, cast): + def run_test(A_dims, b_dims): A_matrix_size = A_dims[-1] A_batch_dims = A_dims[:-2] - b, A = solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, cast) + b, A = self.solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, device, dtype) x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x, cast(x_exp)) + x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(dtype=dtype, device=device) + self.assertEqual(x, x_exp) # test against numpy.linalg.solve for upper in [True, False]: - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), cast) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), cast) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast) # broadcasting A & b + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b + + def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype): + from common_utils import random_symmetric_pd_matrix + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_symmetric_pd_matrix(*A_dims, dtype=dtype, device=device) + L = torch.cholesky(A, upper=upper) + return b, A, L @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky_solve(self, device): - from common_utils import cholesky_solve_test_helper + @dtypes(torch.double) + def test_cholesky_solve(self, device, dtype): for (k, n), upper in product(zip([2, 3, 5], [3, 5, 7]), [True, False]): - 
b, A, L = cholesky_solve_test_helper((n,), (n, k), lambda t: t.to(device), upper) + b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype) x = torch.cholesky_solve(b, L, upper=upper) self.assertLessEqual(b.dist(A.mm(x)), 1e-12) @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky_solve_batched(self, device): - from common_utils import cholesky_solve_test_helper - - def cholesky_solve_batch_helper(A_dims, b_dims, cast, upper): - b, A, L = cholesky_solve_test_helper(A_dims, b_dims, cast, upper) + @dtypes(torch.double) + def test_cholesky_solve_batched(self, device, dtype): + def cholesky_solve_batch_helper(A_dims, b_dims, upper): + b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype) x_exp_list = [] for i in range(b_dims[0]): x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper)) @@ -7628,19 +7653,20 @@ def cholesky_solve_batch_helper(A_dims, b_dims, cast, upper): self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 2e-12) # Correctness check for upper, batchsize in product([True, False], [1, 3, 4]): - cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), lambda t: t.to(device), upper) + cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper) @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_cholesky_solve_batched_non_contiguous(self, device): + @dtypes(torch.double) + def test_cholesky_solve_batched_non_contiguous(self, device, dtype): from numpy.linalg import solve from common_utils import random_symmetric_pd_matrix for upper in [True, False]: - A = random_symmetric_pd_matrix(2, 2) - b = torch.randn(2, 2, 2) - x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device) + A = random_symmetric_pd_matrix(2, 2, dtype=dtype, device='cpu') + b = torch.randn(2, 2, 2, dtype=dtype, device='cpu') + x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(dtype=dtype, device=device) A = A.to(device).permute(0, 2, 1) b = b.to(device).permute(2, 1, 0) assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" @@ -7651,51 +7677,50 @@ def test_cholesky_solve_batched_non_contiguous(self, device): @slowTest @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky_solve_batched_many_batches(self, device): - from common_utils import cholesky_solve_test_helper - + @dtypes(torch.double) + def test_cholesky_solve_batched_many_batches(self, device, dtype): for upper in [True, False]: - b, A, L = cholesky_solve_test_helper((5, 256, 256), (5, 10), lambda t: t.to(device), upper) + b, A, L = self.cholesky_solve_test_helper((5, 256, 256), (5, 10), upper, device, dtype) x = torch.cholesky_solve(b, L, upper) self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 10))) - b, A, L = cholesky_solve_test_helper((5,), (512, 512, 5, 10), lambda t: t.to(device), upper) + b, A, L = self.cholesky_solve_test_helper((5,), (512, 512, 5, 10), upper, device, dtype) x = torch.cholesky_solve(b, L, upper) self.assertEqual(torch.matmul(A, x), b) @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_cholesky_solve_batched_broadcasting(self, device): + @dtypes(torch.double) + def test_cholesky_solve_batched_broadcasting(self, device, dtype): from numpy.linalg import solve from common_utils import random_symmetric_pd_matrix - def cast(t): - return t.to(device) - - def run_test(A_dims, b_dims, cast, upper): + def run_test(A_dims, b_dims, upper): A_matrix_size = 
A_dims[-1] A_batch_dims = A_dims[:-2] - A = random_symmetric_pd_matrix(A_matrix_size, *A_batch_dims) - b = torch.randn(*b_dims) - x_exp = torch.Tensor(solve(A.numpy(), b.numpy())) - A, b = cast(A), cast(b) + A = random_symmetric_pd_matrix(A_matrix_size, *A_batch_dims, + dtype=dtype, device='cpu') + b = torch.randn(*b_dims, dtype=dtype, device='cpu') + x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device) + A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device) L = torch.cholesky(A, upper) x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, cast(x_exp)) + self.assertEqual(x, x_exp) # test against numpy.linalg.solve for upper in [True, False]: - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), cast, upper) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), cast, upper) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper) # broadcasting A & b + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), upper) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), upper) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper) # broadcasting A & b @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky_inverse(self, device): + @dtypes(torch.double) + def test_cholesky_inverse(self, device, dtype): from common_utils import random_symmetric_pd_matrix - a = random_symmetric_pd_matrix(5).to(device) + a = random_symmetric_pd_matrix(5, dtype=dtype, device=device) # compute inverse directly inv0 = torch.inverse(a) @@ -7719,11 +7744,12 @@ def test_cholesky_inverse(self, device): @skipCUDAIf(True, "See issue #26789.") @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky_batched_many_batches(self, device): + @dtypes(torch.double) + def test_cholesky_batched_many_batches(self, device, dtype): from common_utils import random_symmetric_pd_matrix def cholesky_test_helper(n, batchsize, device, upper): - A = random_symmetric_pd_matrix(n, batchsize).to(device) + A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device) chol_fact = torch.cholesky(A, upper=upper) if upper: # Correctness check @@ -7741,22 +7767,24 @@ def cholesky_test_helper(n, batchsize, device, upper): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky_batched(self, device): + @dtypes(torch.double) + def test_cholesky_batched(self, device, dtype): from common_utils import random_symmetric_pd_matrix - def cholesky_test_helper(n, batch_dims, device, upper): - A = random_symmetric_pd_matrix(n, *batch_dims).to(device) + def cholesky_test_helper(n, batch_dims, upper): + A = random_symmetric_pd_matrix(n, *batch_dims, dtype=dtype, device=device) cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) cholesky_exp = cholesky_exp.reshape_as(A) self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]): - cholesky_test_helper(3, batchsize, device, upper) + cholesky_test_helper(3, batchsize, upper) @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_cholesky(self, device): - x = torch.rand(10, 10, device=device) + 1e-1 + @dtypes(torch.double) + def test_cholesky(self, device, dtype): + x = torch.rand(10, 10, dtype=dtype, device=device) + 1e-1 A = torch.mm(x, x.t()) # default Case @@ -8326,13 +8354,15 @@ def delitem(): self.assertRaises(TypeError, delitem) @skipCUDANonDefaultStreamIf(True) - def test_advancedindex(self, device): + 
@dtypes(torch.double) + @default_floating_dtype(torch.double) + def test_advancedindex(self, device, dtype): # Tests for Integer Array Indexing, Part I - Purely integer array # indexing def consec(size, start=1): numel = reduce(lambda x, y: x * y, size, 1) - sequence = torch.ones(numel).cumsum(0) + sequence = torch.ones(numel, dtype=dtype).cumsum(0) sequence.add_(start - 1) return sequence.view(*size) @@ -8382,7 +8412,7 @@ def validate_setting(x): # strided is [1, 3, 5, 7] reference = consec((10,)).to(device) - strided = torch.Tensor().to(device) + strided = torch.tensor((), dtype=dtype, device=device) strided.set_(reference.storage(), storage_offset=0, size=torch.Size([4]), stride=[2]) @@ -8395,7 +8425,7 @@ def validate_setting(x): torch.Tensor([[5, 3], [1, 7]])) # stride is [4, 8] - strided = torch.Tensor().to(device) + strided = torch.tensor((), dtype=dtype, device=device) strided.set_(reference.storage(), storage_offset=4, size=torch.Size([2]), stride=[4]) self.assertEqual(strided[[0]], torch.Tensor([5])) @@ -8442,8 +8472,8 @@ def validate_setting(x): reference[ri([0]), ri([1])] = -1 self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1])) reference[ri([0, 1, 2]), ri([0])] = torch.Tensor([-1, 2, -4]).to(device) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1, - 2, -4])) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], + torch.Tensor([-1, 2, -4])) reference[rows, columns] = torch.Tensor([[4, 6], [2, 3]]).to(device) self.assertEqual(reference[rows, columns], torch.Tensor([[4, 6], [2, 3]])) @@ -8452,7 +8482,7 @@ def validate_setting(x): reference = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7], - [8, 9, 10, 11]]).to(device).t_() + [8, 9, 10, 11]]).to(dtype=dtype, device=device).t_() # Transposed: [[0, 4, 8], # [1, 5, 9], @@ -8504,8 +8534,8 @@ def validate_setting(x): # strided is [[1 3 5 7], # [9 11 13 15]] - reference = torch.arange(0., 24).view(3, 8).to(device) - strided = torch.Tensor().to(device) + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), stride=[8, 2]) @@ -8542,26 +8572,26 @@ def validate_setting(x): # strided is [[10, 11], # [17, 18]] - reference = torch.arange(0., 24).view(3, 8).to(device) - strided = torch.Tensor().to(device) + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11])) strided[ri([0]), ri([1])] = -1 self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1])) - reference = torch.arange(0., 24).view(3, 8).to(device) - strided = torch.Tensor().to(device) + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, device=device) strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([11, 17])) - strided[ri([0, 1]), ri([1, 0])] = torch.Tensor([-1, 2]).to(device) + strided[ri([0, 1]), ri([1, 0])] = torch.Tensor([-1, 2]).to(dtype=dtype, device=device) self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1, 2])) - reference = torch.arange(0., 24).view(3, 8).to(device) - strided = torch.Tensor().to(device) + reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8) + strided = torch.tensor((), dtype=dtype, 
device=device) strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), stride=[7, 1]) @@ -8571,7 +8601,7 @@ def validate_setting(x): [0, 1]]) self.assertEqual(strided[rows, columns], torch.Tensor([[10, 11], [17, 18]])) - strided[rows, columns] = torch.Tensor([[4, 6], [2, 3]]).to(device) + strided[rows, columns] = torch.Tensor([[4, 6], [2, 3]]).to(dtype=dtype, device=device) self.assertEqual(strided[rows, columns], torch.Tensor([[4, 6], [2, 3]])) @@ -8590,7 +8620,7 @@ def validate_setting(x): reference[ri([1]), ri([0, 2]), ri([3])] # test invalid index fails - reference = torch.empty(10, device=device) + reference = torch.empty(10, dtype=dtype, device=device) # can't test cuda because it is a device assert if not reference.is_cuda: for err_idx in (10, -11): @@ -8608,8 +8638,7 @@ def validate_setting(x): def tensor_indices_to_np(tensor, indices): # convert the Torch Tensor to a numpy array - if (tensor.is_cuda): - tensor = tensor.cpu() + tensor = tensor.to(device='cpu') npt = tensor.numpy() # convert indices @@ -8636,13 +8665,13 @@ def set_numpy(tensor, indices, value): def assert_get_eq(tensor, indexer): self.assertEqual(tensor[indexer], - get_numpy(tensor, indexer).to(device)) + get_numpy(tensor, indexer).to(dtype=dtype, device=device)) def assert_set_eq(tensor, indexer, val): pyt = tensor.clone() numt = tensor.clone() pyt[indexer] = val - numt = torch.Tensor(set_numpy(numt, indexer, val)).to(device) + numt = torch.Tensor(set_numpy(numt, indexer, val)).to(dtype=dtype, device=device) self.assertEqual(pyt, numt) def assert_backward_eq(tensor, indexer): @@ -8665,7 +8694,7 @@ def get_set_tensor(indexed, indexer): # 5 6 7 8 9 # 10 11 12 13 14 # 15 16 17 18 19 - reference = torch.arange(0., 20).view(4, 5).to(device) + reference = torch.arange(0., 20, dtype=dtype, device=device).view(4, 5) indices_to_test = [ # grab the second, fourth columns @@ -8697,7 +8726,7 @@ def get_set_tensor(indexed, indexer): indexer, get_set_tensor(reference, indexer)) - reference = torch.arange(0., 160).view(4, 8, 5).to(device) + reference = torch.arange(0., 160, dtype=dtype, device=device).view(4, 8, 5) indices_to_test = [ [slice(None), slice(None), [0, 3, 4]], @@ -8752,7 +8781,7 @@ def get_set_tensor(indexed, indexer): if torch.cuda.is_available(): assert_backward_eq(reference, indexer) - reference = torch.arange(0., 1296).view(3, 9, 8, 6).to(device) + reference = torch.arange(0., 1296, dtype=dtype, device=device).view(3, 9, 8, 6) indices_to_test = [ [slice(None), slice(None), slice(None), [0, 3, 4]], @@ -8834,14 +8863,15 @@ def get_set_tensor(indexed, indexer): assert_backward_eq(reference, indexer) def test_advancedindex_big(self, device): - reference = torch.arange(0, 123344).int().to(device) + reference = torch.arange(0, 123344, dtype=torch.int, device=device) self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ], torch.LongTensor([0, 123, 44488, 68807, 123343])) - def test_kthvalue(self, device): + @dtypes(torch.double) + def test_kthvalue(self, device, dtype): SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE, device=device) + x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device) x0 = x.clone() k = random.randint(1, SIZE) @@ -8852,7 +8882,7 @@ def test_kthvalue(self, device): self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) # test use of result tensors k = random.randint(1, SIZE) - res1val = torch.tensor([], device=device) + res1val = torch.tensor([], dtype=dtype, device=device) res1ind = torch.tensor([], dtype=torch.long, device=device) torch.kthvalue(x, k, keepdim=False, out=(res1val, 
res1ind)) res2val, res2ind = torch.sort(x) @@ -8879,13 +8909,13 @@ def test_kthvalue(self, device): self.assertEqual(x, x0, 0) # simple test case (with repetitions) - y = torch.tensor((3., 5, 4, 1, 1, 5), device=device) + y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device) self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0) self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0) # simple test case (with NaN) SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE, device=device) + x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device) x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1] res2val, res2ind = torch.sort(x) @@ -8897,12 +8927,13 @@ def test_kthvalue(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lu_solve_batched_non_contiguous(self, device): + @dtypes(torch.double) + def test_lu_solve_batched_non_contiguous(self, device, dtype): from numpy.linalg import solve from common_utils import random_fullrank_matrix_distinct_singular_value - A = random_fullrank_matrix_distinct_singular_value(2, 2) - b = torch.randn(2, 2, 2) + A = random_fullrank_matrix_distinct_singular_value(2, 2, dtype=dtype, device='cpu') + b = torch.randn(2, 2, 2, dtype=dtype, device='cpu') x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device) A = A.to(device).permute(0, 2, 1) b = b.to(device).permute(2, 1, 0) @@ -8911,47 +8942,95 @@ def test_lu_solve_batched_non_contiguous(self, device): x = torch.lu_solve(b, LU_data, LU_pivots) self.assertEqual(x, x_exp) - @slowTest + def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype): + from common_utils import random_fullrank_matrix_distinct_singular_value + + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = random_fullrank_matrix_distinct_singular_value(*A_dims, dtype=dtype, device=device) + LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot) + self.assertEqual(info, torch.zeros_like(info)) + return b, A, LU_data, LU_pivots + + @skipCPUIfNoLapack + @skipCUDAIfNoMagma + @dtypes(torch.double) + def test_lu_solve(self, device, dtype): + def sub_test(pivot): + for k, n in zip([2, 3, 5], [3, 5, 7]): + b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n,), (n, k), pivot, device, dtype) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertLessEqual(b.dist(A.mm(x)), 1e-12) + + sub_test(True) + if self.device_type == 'cuda': + sub_test(False) + @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_lu_solve_batched_many_batches(self, device): - from common_utils import lu_solve_test_helper + @dtypes(torch.double) + def test_lu_solve_batched(self, device, dtype): + def sub_test(pivot): + def lu_solve_batch_test_helper(A_dims, b_dims, pivot): + b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output + self.assertEqual(x_exp, x_act) # Equality check + self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check + + for batchsize in [1, 3, 4]: + lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), pivot) + + # Tests tensors with 0 elements + b = torch.randn(3, 0, 3, dtype=dtype, device=device) + A = torch.randn(3, 0, 0, dtype=dtype, device=device) + LU_data, LU_pivots = torch.lu(A) + 
self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) - def cast(t): - return t.to(device) + sub_test(True) + if self.device_type == 'cuda': + sub_test(False) - def run_test(A_dims, b_dims, cast): - b, A, LU_data, LU_pivots = lu_solve_test_helper(self, A_dims, b_dims, cast, True) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_lu_solve_batched_many_batches(self, device, dtype): + def run_test(A_dims, b_dims): + b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype) x = torch.lu_solve(b, LU_data, LU_pivots) b_ = torch.matmul(A, x) self.assertEqual(b_, b.expand_as(b_)) - run_test((5, 65536), (65536, 5, 10), cast) - run_test((5, 262144), (262144, 5, 10), cast) + run_test((5, 65536), (65536, 5, 10)) + run_test((5, 262144), (262144, 5, 10)) @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lu_solve_batched_broadcasting(self, device): + @dtypes(torch.double) + def test_lu_solve_batched_broadcasting(self, device, dtype): from numpy.linalg import solve from common_utils import random_fullrank_matrix_distinct_singular_value - def run_test(A_dims, b_dims, device, pivot=True): + def run_test(A_dims, b_dims, pivot=True): A_matrix_size = A_dims[-1] A_batch_dims = A_dims[:-2] - A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims) - b = torch.randn(*b_dims) - x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(device) + A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims, dtype=dtype) + b = torch.randn(*b_dims, dtype=dtype) + x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(dtype=dtype, device=device) A, b = A.to(device), b.to(device) LU_data, LU_pivots = torch.lu(A, pivot=pivot) x = torch.lu_solve(b, LU_data, LU_pivots) self.assertEqual(x, x_exp) # test against numpy.linalg.solve - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), device) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), device) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device) # broadcasting A & b + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6)) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2)) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)) # broadcasting A & b def test_dim_reduction(self, device): example = [[-1, 2, 1], [5, 3, 6]] @@ -9100,13 +9179,14 @@ def test_rpow(self, device): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_symeig(self, device): + @dtypes(torch.double) + def test_symeig(self, device, dtype): from common_utils import random_symmetric_matrix def run_test(dims, eigenvectors, upper): - x = random_symmetric_matrix(*dims).to(device) - oute = torch.empty(dims[1:] + dims[:1], device=device) - outv = torch.empty(dims[1:] + dims[:1] * 2, device=device) + x = random_symmetric_matrix(*dims, dtype=dtype, device=device) + oute = torch.empty(dims[1:] + dims[:1], dtype=dtype, device=device) + outv = torch.empty(dims[1:] + dims[:1] * 2, dtype=dtype, device=device) torch.symeig(x, eigenvectors=eigenvectors, upper=upper, out=(oute, outv)) if eigenvectors: @@ -9122,7 +9202,7 @@ def run_test(dims, eigenvectors, upper): self.assertEqual(resv, outv, "outputs of symeig and symeig with out don't match") # test non-contiguous - x = random_symmetric_matrix(*dims).to(device) + x = random_symmetric_matrix(*dims, dtype=dtype, device=device) n_dim = len(dims) + 1 # Reverse the batch 
dimensions and the matrix dimensions and then concat them x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) @@ -9142,10 +9222,13 @@ def run_test(dims, eigenvectors, upper): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_svd(self, device): + @dtypes(torch.double) + def test_svd(self, device, dtype): def run_test(dims, some, compute_uv): - x = torch.randn(*dims, device=device) - outu, outs, outv = torch.Tensor().to(device), torch.Tensor().to(device), torch.Tensor().to(device) + x = torch.randn(*dims, dtype=dtype, device=device) + outu = torch.tensor((), dtype=dtype, device=device) + outs = torch.tensor((), dtype=dtype, device=device) + outv = torch.tensor((), dtype=dtype, device=device) torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) if compute_uv: @@ -9169,7 +9252,7 @@ def run_test(dims, some, compute_uv): self.assertEqual(resv, outv, 'outputs of svd and svd with out differ') # test non-contiguous - x = torch.randn(*dims, device=device) + x = torch.randn(*dims, dtype=dtype, device=device) n_dim = len(dims) # Reverse the batch dimensions and the matrix dimensions and then concat them x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) @@ -9392,39 +9475,75 @@ def test_geqrf(self, device): self.assertEqual(b, b_placeholder) self.assertEqual(c, c_placeholder) + def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular, + device, dtype): + triangle_function = torch.triu if upper else torch.tril + b = torch.randn(*b_dims, dtype=dtype, device=device) + A = torch.randn(*A_dims, dtype=dtype, device=device) + A_triangular = triangle_function(A) + if unitriangular: + A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) + return b, A_triangular + @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_triangular_solve(self, device): - from common_utils import triangular_solve_test_helper + @dtypes(torch.double) + def test_triangular_solve(self, device, dtype): for (k, n), (upper, unitriangular, transpose) in product(zip([2, 3, 5], [3, 5, 7]), product([True, False], repeat=3)): - b, A = triangular_solve_test_helper((n, n), (n, k), lambda t: t.to(device), upper, unitriangular) + b, A = self.triangular_solve_test_helper((n, n), (n, k), upper, + unitriangular, device, dtype) x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] if transpose: self.assertLessEqual(b.dist(A.t().mm(x)), 4e-12) else: self.assertLessEqual(b.dist(A.mm(x)), 4e-12) - @slowTest - @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_triangular_solve_batched_many_batches(self, device): - from common_utils import triangular_solve_test_helper + @skipCUDAIfNoMagma + @dtypes(torch.double) + def test_triangular_solve_batched(self, device, dtype): + def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose): + b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, + unitriangular, device, dtype) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, + unitriangular=unitriangular, + transpose=transpose)[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.triangular_solve(b, A, upper=upper, + unitriangular=unitriangular, + transpose=transpose)[0] # Actual output + self.assertEqual(x_act, x_exp) # Equality check + if transpose: + self.assertLessEqual(b.dist(torch.matmul(A.transpose(-2, -1), x_act)), 3e-12) # Correctness check + else: + self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 3e-12) # Correctness check + + 
for (upper, unitriangular, transpose), batchsize in product(product([True, False], repeat=3), [1, 3, 4]): + triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), + upper, unitriangular, transpose) - def cast(t): - return t.to(device) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.double) + def test_triangular_solve_batched_many_batches(self, device, dtype): for upper, transpose, unitriangular in product([True, False], repeat=3): - b, A = triangular_solve_test_helper((256, 256, 5, 5), (5, 1), cast, upper, unitriangular) + b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1), + upper, unitriangular, device, dtype) x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular) if transpose: A = A.transpose(-2, -1) self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - b, A = triangular_solve_test_helper((3, 3), (512, 512, 3, 1), cast, upper, unitriangular) - x, _ = torch.triangular_solve(b, A, - upper=upper, transpose=transpose, unitriangular=unitriangular) + b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1), + upper, unitriangular, device, dtype) + x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose, + unitriangular=unitriangular) if transpose: A = A.transpose(-2, -1) self.assertEqual(torch.matmul(A, x), b) @@ -9432,9 +9551,9 @@ def cast(t): @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_SCIPY, "SciPy not found") - def test_triangular_solve_batched_broadcasting(self, device): + @dtypes(torch.double) + def test_triangular_solve_batched_broadcasting(self, device, dtype): from scipy.linalg import solve_triangular as tri_solve - from common_utils import triangular_solve_test_helper def scipy_tri_solve_batched(A, B, upper, trans, diag): batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] @@ -9450,7 +9569,8 @@ def scipy_tri_solve_batched(A, B, upper, trans, diag): return flat_X.reshape(expand_B.shape) def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): - b, A = triangular_solve_test_helper(A_dims, b_dims, lambda t: t.to(device), upper, unitriangular) + b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper, + unitriangular, device, dtype) x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), upper, transpose, unitriangular)) x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] @@ -9466,13 +9586,15 @@ def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): @skipCUDAIfNoMagma @skipCPUIfNoLapack - def test_lstsq(self, device): - def cast_fn(tensor): - return tensor.to(device=device) + @dtypes(torch.double) + @default_floating_dtype(torch.double) + def test_lstsq(self, device, dtype): + def cast_fn(t): + return t.to(dtype=dtype, device=device) def _test_underdetermined(a, b, expectedNorm): - # underdetermined systems are not supported on the GPU - if not torch.device(device).type == 'cpu': + # underdetermined systems are only supported on CPU + if not self.device_type == 'cpu': return m = a.size()[0] @@ -9486,8 +9608,8 @@ def _test_underdetermined(a, b, expectedNorm): self.assertEqual(b, b_copy, 0) self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8) - ta = cast_fn(torch.Tensor()) - tb = cast_fn(torch.Tensor()) + ta = torch.tensor((), dtype=dtype, device=device) + tb = torch.tensor((), dtype=dtype, device=device) res2 = torch.lstsq(b, a, out=(tb, ta))[0] self.assertEqual(a, a_copy, 0) self.assertEqual(b, b_copy, 0) 
@@ -9524,8 +9646,8 @@ def check_norm(a, b, expected_norm, gels_result): self.assertEqual(b, b_copy, 0) check_norm(a, b, expectedNorm, res1) - ta = cast_fn(torch.Tensor()) - tb = cast_fn(torch.Tensor()) + ta = torch.tensor((), dtype=dtype, device=device) + tb = torch.tensor((), dtype=dtype, device=device) res2 = torch.lstsq(b, a, out=(tb, ta))[0] self.assertEqual(a, a_copy, 0) self.assertEqual(b, b_copy, 0) @@ -9577,8 +9699,8 @@ def check_norm(a, b, expected_norm, gels_result): (4.53, 3.83, -6.64, 2.06)))).t() b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28), (9.35, -4.43, -0.70, -0.26)))).t() - ta = cast_fn(torch.Tensor()) - tb = cast_fn(torch.Tensor()) + ta = torch.tensor((), dtype=dtype, device=device) + tb = torch.tensor((), dtype=dtype, device=device) torch.lstsq(b, a, out=(tb, ta)) self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) torch.lstsq(b, a, out=(tb, ta)) @@ -11563,9 +11685,10 @@ def non_zero_rand(size, dtype, device): non_zero_rand((2, 2), dtype=dtype, device=device)) # TODO: run on non-native device types - def test_unary_out_op_mem_overlap(self, device): + @dtypes(torch.double) + def test_unary_out_op_mem_overlap(self, device, dtype): sz = 3 - doubles = torch.randn(2 * sz, device=device) + doubles = torch.randn(2 * sz, dtype=dtype, device=device) positives = torch.randint(1, 100, (2 * sz,), device=device).double() ints = torch.randint(-100, 100, (2 * sz,), device=device) unary_mem_overlap_cases = [ @@ -11603,7 +11726,7 @@ def test_unary_out_op_mem_overlap(self, device): ("log", positives, True, True, 'cpu'), ("log", positives, True, True, 'cuda'), ("log10", positives, True, True, 'cpu'), - ("log10", positives, False, True, 'cuda'), + ("log10", positives, True, True, 'cuda'), ("log1p", positives, True, True, 'cpu'), ("log1p", positives, False, True, 'cuda'), ("log2", positives, True, True, 'cpu'), @@ -11642,10 +11765,11 @@ def test_unary_out_op_mem_overlap(self, device): self.unary_check_input_output_mem_overlap(inputs, sz, out_fn, expected_failure=not has_input_output_mem_overlap_check) - self.check_internal_mem_overlap(in_fn, num_inputs=1, device=dev, + self.check_internal_mem_overlap(in_fn, 1, dtype, dev, expected_failure=not has_internal_mem_overlap_check) - def test_binary_op_mem_overlap(self, device): + @dtypes(torch.double) + def test_binary_op_mem_overlap(self, device, dtype): ops = [ ("add", True, True, 'cpu'), ("add", True, True, 'cuda'), @@ -11666,13 +11790,14 @@ def test_binary_op_mem_overlap(self, device): out_op = getattr(torch, fn) inplace_op = getattr(torch.Tensor, fn + '_') self.check_internal_mem_overlap( - inplace_op, num_inputs=2, device=device, + inplace_op, 2, dtype, device, expected_failure=not has_internal_mem_overlap_check) self.binary_check_input_output_mem_overlap(out_op, device, expected_failure=not has_input_output_mem_overlap_check) - def test_ternary_op_mem_overlap(self, device): + @dtypes(torch.double) + def test_ternary_op_mem_overlap(self, device, dtype): ops = [ ("addcmul", True, True, 'cpu'), ("addcmul", True, True, 'cuda'), @@ -11689,24 +11814,26 @@ def test_ternary_op_mem_overlap(self, device): out_op = getattr(torch, fn) inplace_op = getattr(torch.Tensor, fn + '_') self.check_internal_mem_overlap( - inplace_op, num_inputs=3, device=device, + inplace_op, 3, dtype, device, expected_failure=not has_internal_mem_overlap_check) self.ternary_check_input_output_mem_overlap(out_op, dev, expected_failure=not has_input_output_mem_overlap_check) - def test_copy_mem_overlap(self, device): + @dtypes(torch.double) + def 
test_copy_mem_overlap(self, device, dtype): self.check_internal_mem_overlap( - torch.Tensor.copy_, num_inputs=2, device=device) + torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device) sz = 3 - doubles = torch.randn(2 * sz, device=device) + doubles = torch.randn(2 * sz, dtype=dtype, device=device) self.unary_check_input_output_mem_overlap( doubles, sz, lambda input, out: out.copy_(input)) - def test_pow_scalar_overloads_mem_overlap(self, device): + @dtypes(torch.double) + def test_pow_scalar_overloads_mem_overlap(self, device, dtype): sz = 3 - doubles = torch.randn(2 * sz, device=device) + doubles = torch.randn(2 * sz, dtype=dtype, device=device) self.check_internal_mem_overlap( - lambda t: t.pow_(42), num_inputs=1, device=device) + lambda t: t.pow_(42), 1, dtype, device) self.unary_check_input_output_mem_overlap( doubles, sz, lambda input, out: torch.pow(input, 42, out=out)) self.unary_check_input_output_mem_overlap( @@ -11794,7 +11921,10 @@ def test_var_mean_some_dims(self, device): # passes on ROCm w/ python 2.7, fails w/ python 3.6 @skipCUDAIfRocm - def test_stft(self, device): + # stft -> rfft -> _fft -> _fft_with_size -> _fft_mkl + @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") + @dtypes(torch.double) + def test_stft(self, device, dtype): if not TEST_LIBROSA: raise unittest.SkipTest('librosa not found') @@ -11817,9 +11947,9 @@ def librosa_stft(x, n_fft, hop_length, win_length, window, center): def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, center=True, expected_error=None): - x = torch.randn(*sizes, device=device) + x = torch.randn(*sizes, dtype=dtype, device=device) if win_sizes is not None: - window = torch.randn(*win_sizes, device=device) + window = torch.randn(*win_sizes, dtype=dtype, device=device) else: window = None if expected_error is None: @@ -12402,6 +12532,36 @@ def test_memory_format_preserved_after_permute(self, device): y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2) self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) + def test_memory_format_clone(self, device): + nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last) + # nhwc is not memory dense, but looks like channels last + nhwc = nhwc[:, :, ::2, ::2] + clone = nhwc.clone(memory_format=torch.preserve_format) + self.assertFalse(clone.is_contiguous()) + self.assertTrue(clone.is_contiguous(memory_format=torch.channels_last)) + self.assertFalse(nhwc.is_contiguous()) + self.assertFalse(nhwc.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(nhwc, clone) + + nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last) + clone = nhwc.clone(memory_format=torch.contiguous_format) + self.assertTrue(clone.is_contiguous()) + self.assertFalse(clone.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(nhwc, clone) + + nhwc = torch.randn((10, 3, 32, 32), device=device).contiguous(memory_format=torch.channels_last) + clone = nhwc.clone() + self.assertTrue(clone.is_contiguous()) + self.assertFalse(clone.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(nhwc, clone) + + x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device) + for _ in range(10): + permutation = list(range(len(x.shape))) + random.shuffle(permutation) + x = x.permute(permutation) + self.assertEqual(x.stride(), x.clone(memory_format=torch.preserve_format).stride()) + def test_memory_format_empty_like(self, device): x = torch.randn(10, 3, 32, 32, device=device) nhwc = 
x.contiguous(memory_format=torch.channels_last) @@ -12672,62 +12832,11 @@ def run_subtest(matrix_size, batches, device, pivot): @skipCPUIfNoLapack @skipCUDAIfNoMagma - def test_lu_solve(self, device, pivot=True): - from common_utils import lu_solve_test_helper - - def cast(t): - return t.to(device) - - def sub_test(pivot): - for k, n in zip([2, 3, 5], [3, 5, 7]): - b, A, LU_data, LU_pivots = lu_solve_test_helper(self, (n,), (n, k), cast, pivot) - x = torch.lu_solve(b, LU_data, LU_pivots) - self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - - sub_test(True) - - if self.device_type == 'cuda': - sub_test(False) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - def test_lu_solve_batched(self, device): - from common_utils import lu_solve_test_helper - - def cast(t): - return t.to(device) - - def sub_test(pivot): - def lu_solve_batch_test_helper(A_dims, b_dims, cast, pivot): - b, A, LU_data, LU_pivots = lu_solve_test_helper(self, A_dims, b_dims, cast, pivot) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output - self.assertEqual(x_exp, x_act) # Equality check - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check - - for batchsize in [1, 3, 4]: - lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), cast, pivot) - - # Tests tensors with 0 elements - b = torch.randn(3, 0, 3, device=device) - A = torch.randn(3, 0, 0, device=device) - LU_data, LU_pivots = torch.lu(A) - self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) - - sub_test(True) - - if self.device_type == 'cuda': - sub_test(False) - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - def test_lu_unpack(self, device, pivot=True): + @dtypes(torch.double) + def test_lu_unpack(self, device, dtype): def run_test(pivot): for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)): - a = torch.randn(*shape, device=device) + a = torch.randn(*shape, dtype=dtype, device=device) a_lu, p = torch.lu(a, pivot=pivot) p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p) self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a) @@ -12751,33 +12860,6 @@ def test_min_with_inf(self, device, dtype): self.assertTrue(torch.all(torch.min(a, dim=1)[0] == (-inf)).item()) self.assertTrue(torch.min(a).item() == -inf) - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - def test_triangular_solve_batched(self, device): - from common_utils import triangular_solve_test_helper - - def cast(t): - return t.to(device) - - def triangular_solve_batch_helper(A_dims, b_dims, cast, upper, unitriangular, transpose): - b, A = triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, - unitriangular=unitriangular, transpose=transpose)[0]) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.triangular_solve(b, A, upper=upper, - unitriangular=unitriangular, transpose=transpose)[0] # Actual output - self.assertEqual(x_act, x_exp) # Equality check - if transpose: - self.assertLessEqual(b.dist(torch.matmul(A.transpose(-2, -1), x_act)), 3e-12) # Correctness check - else: - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 3e-12) # Correctness check - - for (upper, unitriangular, transpose), batchsize in product(product([True, False], repeat=3), [1, 3, 4]): - triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), cast, - upper, 
unitriangular, transpose) - def test_bincount(self, device): # negative input throws with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): @@ -12997,21 +13079,23 @@ def test_linspace(self, device): b = torch.linspace(0, 10, 10) self.assertEqual(a, b.to(device)) - def test_logspace(self, device): - a = torch.logspace(1, 10, 10, device=device) - b = torch.logspace(1, 10, 10) + @dtypes(torch.double) + def test_logspace(self, device, dtype): + a = torch.logspace(1, 10, 10, dtype=dtype, device=device) + b = torch.logspace(1, 10, 10, dtype=dtype, device='cpu') self.assertEqual(a, b.to(device)) # Check non-default base=2 - a = torch.logspace(1, 10, 10, 2, device=device) - b = torch.logspace(1, 10, 10, 2) + a = torch.logspace(1, 10, 10, 2, dtype=dtype, device=device) + b = torch.logspace(1, 10, 10, 2, dtype=dtype, device='cpu') self.assertEqual(a, b.to(device)) # Note: ROCm fails when using float tensors - def test_polygamma(self, device): - cpu_tensor = torch.randn(10, 10, 10) + @dtypes(torch.double) + def test_polygamma(self, device, dtype): + cpu_tensor = torch.randn(10, 10, 10, dtype=dtype) device_tensor = cpu_tensor.to(device) - zeros = torch.zeros(10, 10, 10) + zeros = torch.zeros(10, 10, 10, dtype=dtype) for n in [0, 1]: cpu_out = cpu_tensor.polygamma(n) device_out = device_tensor.polygamma(n) @@ -13019,10 +13103,11 @@ def test_polygamma(self, device): self.assertEqual(norm_errors, zeros) # Note: fails when using float tensors - def test_digamma(self, device): - cpu_tensor = torch.randn(10, 10, 10) + @dtypes(torch.double) + def test_digamma(self, device, dtype): + cpu_tensor = torch.randn(10, 10, 10, dtype=dtype) device_tensor = cpu_tensor.to(device) - zeros = torch.zeros(10, 10, 10) + zeros = torch.zeros(10, 10, 10, dtype=dtype) cpu_out = cpu_tensor.digamma() device_out = device_tensor.digamma() norm_errors = (device_out - cpu_out.to(device)) / device_out @@ -13031,8 +13116,8 @@ def test_digamma(self, device): # Tests pole behavior cpu_tensor = torch.tensor([-0.999999994, -1.999999994, -2.0000000111, -100.99999994, -1931.99999994, 0.000000111, - -0.000000111, 0, -1, -2, -931]) - expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan]) + -0.000000111, 0, -1, -2, -931], dtype=dtype) + expected_errors = torch.tensor([0, 0, 0, 0, 0, 0, 0, nan, nan, nan, nan], dtype=dtype) device_tensor = cpu_tensor.to(device) cpu_out = cpu_tensor.digamma() device_out = device_tensor.digamma() @@ -13078,8 +13163,9 @@ def test_arange(self, device, dtype): self.assertEqual(cpu_tensor, device_tensor) @skipCUDAIfRocm - def test_sum_noncontig(self, device): - x = torch.randn(1, 75, 57, 20, device=device).permute(0, 3, 1, 2) + @dtypes(torch.double) + def test_sum_noncontig(self, device, dtype): + x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2) y = x.cpu() self.assertEqual(x.sum().cpu(), y.sum()) self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2))) @@ -13315,12 +13401,542 @@ def test_ones_like_multiple_device(self, devices): self.assertEqual(output, expected) -add_neg_dim_tests() -instantiate_device_type_tests(TestTorchDeviceType, globals()) -instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') +# Below are fixtures and functions that generate tensor op comparison tests +# These tests run a single op on both a CPU and device tensor and compare the +# the results. In-place variants of the ops can also be run. + +# Lists of dtypes to instantiate tensor op test variants. 
+_types = [ + torch.half, torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long, + torch.uint8 +] + +_float_types = [torch.half, torch.float, torch.double] + +_float_types_no_half = [torch.float, torch.double] + +_signed_types = [ + torch.half, torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long +] + +_signed_types_no_half = [ + torch.float, torch.double, + torch.int8, torch.short, torch.int, torch.long +] + +_unsigned_types = [torch.uint8] + +# Helper values and functions for producing tensors and scalars to use in tensor op tests. +# Tensor dimension sizes (Small, Medium, Large, Giant) +_S = 5 +_M = 50 +_L = 1000 +_G = 275000000 + +# Value to clamp divisors to since dividing by small numbers can be unstable +# on devices. +_div_min = 2**-8 + +# Returns floating or integral scalar corresponding to dtype +def _number(floating, integer, dtype): + if dtype in [torch.half, torch.float, torch.double]: + return floating + return integer + +# Converts half dtype to float when device is cpu +def _convert_t(dtype, device): + if device == 'cpu' and dtype == torch.half: + return torch.float + return dtype + +# Returns a tensor of the requested shape, dtype, and device +# Requesting a half CPU tensor returns a float CPU tensor with +# values representable by a half. +# Initialization uses randint for non-float types and randn for float types. +def _make_tensor(shape, dtype, device, fill_ones=False): + # Returns a tensor filled with ones + if fill_ones: + return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) + + # Returns a tensor with random integer values + if dtype not in _float_types: + t = torch.randint(0, 10, shape, device=device) + return t.to(_convert_t(dtype, device)) + + # Populates the CPU tensor with floats representable as halfs + if dtype == torch.half and device == 'cpu': + return torch.randn(*shape, dtype=torch.float, device=device).half().float() + + # Default: returns a tensor with random float values + return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype) + +def _small_0d(dtype, device): + return _make_tensor((1,), dtype, device).squeeze() + +def _small_2d(dtype, device, has_zeros=True, fill_ones=False, oneish=False): + t = _make_tensor((_S, _S), dtype, device, fill_ones=fill_ones) + if oneish: + return t.clamp(min=_number(.99, 1, dtype), max=1.01) + if not has_zeros: + return t.clamp(min=(_number(_div_min, 1, dtype))) + return t + +def _small_3d(dtype, device, has_zeros=True, fill_ones=False, oneish=False): + t = _make_tensor((_S, _S, _S), dtype, device, fill_ones=fill_ones) + if oneish: + return t.clamp(min=_number(.99, 1, dtype), max=1.01) + if not has_zeros: + return t.clamp(min=(_number(_div_min, 1, dtype))) + return t + +def _small_3d_ones(dtype, device): + return _small_3d(dtype, device, fill_ones=True) + +def _small_3d_unique(dtype, device): + return (torch.randperm(_S * _S * _S, + dtype=_convert_t(dtype, device), device=device) + 1).view(_S, _S, _S) + +def _medium_1d(dtype, device): + return _make_tensor((_M,), dtype, device) + +def _medium_2d(dtype, device): + return _make_tensor((_M, _M), dtype, device) + +def _large_2d(dtype, device): + t = _make_tensor((_L, _L), dtype, device) + return t.normal_() + +def _giant_1d(dtype, device): + return _make_tensor((_G), dtype, device) + +# Helper method that returns a function which takes dtype and device and +# instantiates tensors of the given shape. +# Useful for tensor op tests with custom shapes. 
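A quick illustration of how these constructor helpers are meant to be used (a sketch only, assuming the definitions above are in scope; `_new_t`, defined next, is the variant that binds a custom shape):

    t = _small_3d(torch.float, 'cpu')                    # 5x5x5 random float tensor on CPU
    ones = _small_3d(torch.int, 'cpu', fill_ones=True)   # 5x5x5 tensor of ones
    make_4x4 = _new_t((4, 4))                            # constructor bound to a 4x4 shape
    a = make_4x4(torch.double, 'cpu')                    # 4x4 random double tensor on CPU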
+def _new_t(shape): + def tmp(dtype, device): + return _make_tensor(shape, dtype, device) + return tmp + +# TODO: random functions, cat, gather, scatter, index*, masked*, +# resize, resizeAs, storage_offset, storage, stride, unfold +# Each tests is defined in tensor_op_tests as a tuple of: +# - op name (string) +# - (sub)test name (string) +# - tensor constructor, takes dtype and device and constructs the tensor to run the op on +# - arg constructor, takes dtype and device and constructs op arguments +# - torch.half precision (=1e-5) +# - precision (=1e-5), precision to use for all other dtypes +# - make_inplace_variant (=True), if true the inplace version of the op (op_) is also tested +# - dtype_list (=_types), a list of torch dtypes to test the op(s) with +# - decorators (=[]), a list of decorators to apply to the test +tensor_op_tests = [ + ('add', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-2), + ('add', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2), + ('sub', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-2), + ('sub', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2), + ('mul', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-2), + ('mul', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2), + ('mul', 'scalar', _small_0d, lambda t, d: [_small_0d(torch.int32, d)], 1e-2), + ('div', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1), + ('div', 'tensor', _small_3d, + lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-1), + ('pow', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-1, 1e-5, _float_types), + ('pow', '1', _small_3d, lambda t, d: [_number(1., 1, t)], 1e-1), + ('pow', '2', _small_3d, lambda t, d: [_number(2., 2, t)], 1e-1), + ('pow', '3', _small_3d, lambda t, d: [_number(3., 3, t)], 1e-1), + ('pow', '-1', _small_3d, lambda t, d: [_number(-1., -1, t)], 1e-1, 1e-5, _float_types), + ('pow', '-2', _small_3d, lambda t, d: [_number(-2., -2, t)], + 1e-1, 1e-5, _float_types_no_half, False, [skipCUDAIfRocm]), + ('pow', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d).abs()], + 1e-1, 1e-5, _float_types), + ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], + 1e-1, 1e-4, _float_types), + ('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], + 1e-1, 1e-4, _float_types), + ('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], + 1e-1, 1e-4, _float_types), + ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], + 1e-2, 1e-4, _float_types), + ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], + 1e-2, 1e-4, _float_types), + ('baddbmm', 'two_scalars', _small_3d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], + 1e-2, 1e-4, _float_types), + ('bmm', '', _small_3d, lambda t, d: [_small_3d(t, d)], + 1e-5, 1e-5, _float_types_no_half, False), + ('addcdiv', '', _small_2d, + lambda t, d: [_small_2d(t, d), + _small_2d(t, d, has_zeros=False)], 1, 1e-3), + ('addcdiv', 'scalar', _small_2d, + lambda t, d: [_number(2.8, 1, t), _small_2d(t, d), + _small_2d(t, d, has_zeros=False)], 1, 1e-3), + ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-3), + ('addcmul', 'scalar', _small_3d, + lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2), + ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)], + 1e-1, 1e-4, 
_float_types), + ('addmm', 'scalar', _medium_2d, + lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], + 1e-1, 1e-4, _float_types), + ('addmm', 'two_scalars', _medium_2d, + lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], + 1e-1, 1e-4, _float_types), + ('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)], + 1e-2, 1e-4, _float_types), + ('addmv', 'scalar', _medium_1d, + lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], + 1e-2, 1e-4, _float_types), + ('addmv', 'two_scalars', _medium_1d, + lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], + 1e-2, 1e-4, _float_types), + ('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)], + 1e-2, 1e-4, _float_types), + ('addr', 'scalar', _medium_2d, + lambda t, d: [_number(0.4, 2, t), _medium_1d(t, d), _medium_1d(t, d)], + 1e-2, 1e-4, _float_types), + ('addr', 'two_scalars', _medium_2d, + lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_1d(t, d), _medium_1d(t, d)], + 1e-2, 1e-4, _float_types), + ('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, _float_types), + ('fmod', 'value', _small_3d, lambda t, d: [3], 1e-3), + ('fmod', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-3), + ('chunk', '', _medium_2d, lambda t, d: [4], 1e-5, 1e-5, _types, False), + ('chunk', 'dim', _medium_2d, lambda t, d: [4, 1], 1e-5, 1e-5, _types, False), + ('chunk', 'neg_dim', _medium_2d, lambda t, d: [4, -2], 1e-5, 1e-5, _types, False), + ('clamp', 'neg', _medium_2d, lambda t, d: [-1, 5], 1e-5, 1e-5, _signed_types), + ('clamp', 'pos', _medium_2d, lambda t, d: [1, 5], 1e-5, 1e-5, _unsigned_types), + ('clone', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('cross', '', _new_t((_M, 3, _M)), lambda t, d: [_new_t((_M, 3, _M))(t, d)], + 1e-2, 1e-5, _types, False), + ('cumprod', '', _small_3d, lambda t, d: [1], 1e-2, 1e-4, _types, False), + ('cumprod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-4, _types, False), + ('cumsum', '', _small_3d, lambda t, d: [1], 1e-2, 1e-5, _types, False), + ('cumsum', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, _types, False), + ('dim', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('dist', '', _small_2d, lambda t, d: [_small_2d(t, d)], 1e-2, 1e-5, _float_types, False), + ('dist', '3_norm', _small_2d, lambda t, d: [_small_2d(t, d), 3], 1e-2, 1e-5, _float_types, False), + ('dist', '2_5_norm', _small_2d, lambda t, d: [_small_2d(t, d), 2.5], + 1e-2, 1e-5, _float_types, False), + ('dot', '', _medium_1d, lambda t, d: [_medium_1d(t, d)], + 1e-2, 1e-5, _float_types, False, [skipCUDAIfRocm]), + ('element_size', '', _medium_1d, lambda t, d: [], 1e-5, 1e-5, _float_types_no_half, False), + ('eq', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)],), + ('eq', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)]), + ('ne', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)],), + ('ne', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)]), + ('equal', 'equal', _small_3d_ones, lambda t, d: [_small_3d_ones(t, d)], + 1e-5, 1e-5, _types, False), + ('equal', '', _small_3d_ones, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, _types, False), + ('expand', '', _new_t((_M, 1, _M)), lambda t, d: [_M, 4, _M], 1e-5, 1e-5, _types, False), + ('expand_as', '', _new_t((_M, 1, _M)), lambda t, d: [_new_t((_M, 4, _M))(t, d)], + 1e-5, 
1e-5, _types, False), + ('fill_', '', _medium_2d, lambda t, d: [_number(3.14, 3, t)], 1e-3, 1e-5, _types, False), + ('ge', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],), + ('le', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],), + ('gt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],), + ('lt', '', _medium_2d, lambda t, d: [_medium_2d(t, d)],), + ('is_contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, _types, False), + # TODO: can't check negative case - cross-device copy is contiguous + ('is_same_size', 'negative', _medium_2d, lambda t, d: [_small_3d(t, d)], + 1e-5, 1e-5, _types, False), + ('is_same_size', 'positive', _medium_2d, lambda t, d: [_medium_2d(t, d)], + 1e-5, 1e-5, _types, False), + ('is_set_to', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-5, 1e-5, _types, False), + # TODO: positive case + ('kthvalue', '', _small_3d_unique, lambda t, d: [3], 1e-5, 1e-5, _types, False), + ('kthvalue', 'dim', _small_3d_unique, lambda t, d: [3, 1], 1e-5, 1e-5, _types, False), + ('kthvalue', 'neg_dim', _small_3d_unique, lambda t, d: [3, -1], 1e-5, 1e-5, _types, False), + ('lerp', '', _small_3d, lambda t, d: [_small_3d(t, d), 0.3], + 1e-2, 1e-5, _float_types_no_half), + ('max', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('max', 'dim', _small_3d_unique, lambda t, d: [1], 1e-5, 1e-5, _types, False), + ('max', 'neg_dim', _small_3d_unique, lambda t, d: [-1], 1e-5, 1e-5, _types, False), + ('max', 'elementwise', _medium_2d, lambda t, d: [_medium_2d(t, d)], + 1e-5, 1e-5, _types, False), + ('min', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('min', 'dim', _small_3d_unique, lambda t, d: [1], 1e-5, 1e-5, _types, False), + ('min', 'neg_dim', _small_3d_unique, lambda t, d: [-1], 1e-5, 1e-5, _types, False), + ('min', 'elementwise', _medium_2d, lambda t, d: [_medium_2d(t, d)], + 1e-5, 1e-5, _types, False), + ('mean', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types, False), + ('mean', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, _float_types, False), + ('mean', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, _float_types, False), + # Double here because the CPU result will be wrong otherwise + ('mean', '64bit_indexing', _giant_1d, lambda t, d: [], + 1e-3, 1e-5, [torch.double], False, [slowTest]), + ('mode', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('mode', 'dim', _small_3d, lambda t, d: [1], 1e-5, 1e-5, _types, False), + ('mode', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-5, 1e-5, _types, False), + ('mvlgamma', '2d_p=1', lambda t, d: _small_2d(t, d).clamp(0.1, 10), lambda t, d: [1], + 1e-5, 1e-5, _float_types_no_half), + ('mvlgamma', '2d_p=2', lambda t, d: _small_2d(t, d).clamp(0.6, 10), lambda t, d: [2], + 1e-5, 1e-5, _float_types_no_half), + ('remainder', 'value', _small_3d, lambda t, d: [3], 1e-1, 1e-5, _signed_types), + ('remainder', 'negative_value', _small_3d, lambda t, d: [-3], 1e-1, 1e-5, _signed_types), + ('remainder', 'tensor', _small_3d, + lambda t, d: [_small_3d(t, d, has_zeros=False)], + 1e-1, 1e-5, _signed_types), + ('remainder', 'negative_tensor', _small_3d, + lambda t, d: [0 - _small_3d(t, d, has_zeros=False)], + 1e-1, 1e-5, _signed_types), + ('std', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types, False), + ('std', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, _float_types, False), + ('std', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, _float_types, False), + ('var', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types, False), + ('var', 'dim', _small_3d, lambda t, d: [1], 1e-3, 
1e-5, _float_types, False), + ('var', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, _float_types, False), + ('ndimension', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('nelement', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('numel', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('narrow', '', _small_3d, lambda t, d: [1, 3, 2], 1e-5, 1e-5, _types, False), + ('narrow', 'neg_dim', _small_3d, lambda t, d: [-1, 3, 2], 1e-5, 1e-5, _types, False), + ('nonzero', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('norm', '', _small_3d, lambda t, d: [], 1e-1, 1e-5, _float_types, False), + ('norm', '3_norm', _small_3d, lambda t, d: [3], 1e-1, 1e-5, _float_types, False), + ('norm', '3_norm_dim', _small_3d, lambda t, d: [3, 0], 1e-1, 1e-5, _float_types, False), + ('norm', '3_norm_neg_dim', _small_3d, lambda t, d: [3, -2], 1e-1, 1e-5, _float_types, False), + ('new_ones', '', _small_3d, lambda t, d: [1, 2, 3, 4, 5], 1e-5, 1e-5, _types, False), + ('permute', '', _new_t((1, 2, 3, 4)), lambda t, d: [2, 1, 3, 0], 1e-5, 1e-5, _types, False), + ('put_', '', _new_t((2, 5, 3)), + lambda t, d: [torch.LongTensor([[0], [-2]]).to(device=d), + torch.LongTensor([[3], [4]]).to(dtype=_convert_t(t, d), device=d)], + 1e-5, 1e-5, _types, False), + ('put_', 'empty', _new_t((2, 3)), + lambda t, d: [torch.LongTensor([]).to(device=d), torch.LongTensor([]).to(dtype=_convert_t(t, d), device=d)], + 1e-5, 1e-5, _types, False), + ('put_', 'accumulate', _new_t((2, 2)), + lambda t, d: [torch.LongTensor([[1], [-3]]).to(device=d), + torch.LongTensor([[1], [2]]).to(dtype=_convert_t(t, d), device=d), + True], + 1e-5, 1e-5, _types, False), + ('prod', '', lambda t, d: _small_2d(t, d, oneish=True), + lambda t, d: [], 1e-2, 1e-5, _types, False), + ('prod', 'dim', _small_3d, lambda t, d: [1], 1e-3, 1e-5, _types, False), + ('prod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-3, 1e-5, _types, False), + ('sum', '', _small_2d, lambda t, d: [], 1e-2, 1e-5, _types, False), + ('sum', 'dim', _small_3d, lambda t, d: [1], 1e-2, 1e-5, _types, False), + ('sum', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, _types, False), + ('renorm', '2_norm', _small_3d, lambda t, d: [2, 1, 1], 1e-3, 1e-5, _float_types), + ('renorm', '2_norm_neg_dim', _small_3d, lambda t, d: [2, -1, 1], 1e-3, 1e-5, _float_types), + ('renorm', '1_5_norm', _small_3d, lambda t, d: [1.5, 1, 1], 1e-3, 1e-5, _float_types), + ('repeat', '', _small_2d, lambda t, d: [2, 2, 2], 1e-5, 1e-5, _types, False), + ('size', '', _new_t((1, 2, 3, 4)), lambda t, d: [], 1e-5, 1e-5, _types, False), + ('size', 'dim', _new_t((1, 2, 3, 4)), lambda t, d: [1], 1e-5, 1e-5, _types, False), + ('size', 'neg_dim', _new_t((1, 2, 3, 4)), lambda t, d: [-2], 1e-5, 1e-5, _types, False), + ('sort', '', _small_3d_unique, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('sort', 'dim', _small_3d_unique, lambda t, d: [1], 1e-5, 1e-5, _types, False), + ('sort', 'neg_dim', _small_3d_unique, lambda t, d: [-1], 1e-5, 1e-5, _types, False), + ('sort', 'dim_descending', _small_3d_unique, lambda t, d: [1, True], 1e-5, 1e-5, _types, False), + ('sort', 'neg_dim_descending', _small_3d_unique, lambda t, d: [-1, True], 1e-5, 1e-5, _types, False), + ('split', '', _small_3d, lambda t, d: [2], 1e-5, 1e-5, _types, False), + ('split', 'dim', _small_3d, lambda t, d: [2, 1], 1e-5, 1e-5, _types, False), + ('split', 'neg_dim', _small_3d, lambda t, d: [2, -3], 1e-5, 1e-5, _types, False), + ('squeeze', '', _new_t((1, 2, 1, 4)), lambda t, d: [],), + ('squeeze', 'dim', 
_new_t((1, 2, 1, 4)), lambda t, d: [2], ), + ('squeeze', 'neg_dim', _new_t((1, 2, 1, 4)), lambda t, d: [-2], ), + ('t', '', _new_t((1, 2)), lambda t, d: [],), + ('take', '', _new_t((3, 4)), + lambda t, d: [torch.LongTensor([[0], [-2]]).to(device=d)], + 1e-5, 1e-5, _types, False), + ('transpose', '', _new_t((1, 2, 3, 4)), lambda t, d: [1, 2],), + ('transpose', 'neg_dim', _new_t((1, 2, 3, 4)), lambda t, d: [-1, -2], ), + ('tolist', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('topk', 'dim_sort', _small_3d_unique, lambda t, d: [2, 1, False, True], + 1e-5, 1e-5, _types, False), + ('topk', 'neg_dim_sort', _small_3d_unique, lambda t, d: [2, -1, False, True], + 1e-5, 1e-5, _types, False), + ('topk', 'dim_desc_sort', _small_3d_unique, lambda t, d: [2, 1, True, True], + 1e-5, 1e-5, _types, False), + ('trace', '', _medium_2d, lambda t, d: [], 1e-3, 1e-5, _types, False), + ('tril', '', _medium_2d, lambda t, d: [],), + ('tril', 'zero_stride', _medium_2d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('tril', 'positive', _medium_2d, lambda t, d: [2], ), + ('tril', 'negative', _medium_2d, lambda t, d: [-2], ), + ('triu', '', _medium_2d, lambda t, d: [],), + ('triu', 'zero_stride', _medium_2d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('triu', 'positive', _medium_2d, lambda t, d: [2], ), + ('triu', 'negative', _medium_2d, lambda t, d: [-2], ), + ('unsqueeze', '', _new_t((2, 3, 4)), lambda t, d: [2],), + ('unsqueeze', 'neg_dim', _new_t((2, 3, 4)), lambda t, d: [-2], ), + ('view', 'contiguous', _small_3d, lambda t, d: [25, 5], 1e-5, 1e-5, _types, False), + ('view_as', '', _small_3d, lambda t, d: [_make_tensor((25, 5), t, d)], + 1e-5, 1e-5, _types, False), + ('zero_', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('new_zeros', '', _small_3d, lambda t, d: [1, 2, 3, 4], 1e-5, 1e-5, _types, False), + ('flip', 'd0', _small_3d, lambda t, d: [0], 1e-5, 1e-5, _types, False), + ('flip', 'd012', _small_3d, lambda t, d: [0, 1, 2], 1e-5, 1e-5, _types, False), + ('flip', 'd02', _small_3d, lambda t, d: [0, 2], 1e-5, 1e-5, _types, False), + ('flip', 'd20', _small_3d, lambda t, d: [2, 0], 1e-5, 1e-5, _types, False), + ('flip', 'neg_d', _small_3d, lambda t, d: [-1], 1e-5, 1e-5, _types, False), + ('rot90', 'k1_d01', _small_2d, lambda t, d: [1, [0, 1]], 1e-5, 1e-5, _types, False), + ('rot90', 'k1_d12', _small_3d, lambda t, d: [1, [1, 2]], 1e-5, 1e-5, _types, False), + ('rot90', 'k1_neg_d', _small_3d, lambda t, d: [1, [1, -1]], 1e-5, 1e-5, _types, False), + ('rot90', 'default', _small_3d, lambda t, d: [], 1e-5, 1e-5, _types, False), + ('rsqrt', '', lambda t, d: _small_3d(t, d) + 1, lambda t, d: [], 1e-2, 1e-4, _float_types_no_half), + ('sinh', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, _float_types), + ('tan', '', lambda t, d: _small_3d(t, d).clamp(-1, 1), lambda t, d: [], 1e-3, 1e-5, _float_types), + ('__lshift__', '', + lambda t, d: torch.pow(2, torch.arange(1, 5).to(dtype=_convert_t(t, d), device=d)), + lambda t, d: [2], + 1e-3, 1e-3, _signed_types_no_half, False), + ('__rshift__', '', + lambda t, d: torch.pow(2, torch.arange(3, 7).to(dtype=_convert_t(t, d), device=d)), + lambda t, d: [2], + 1e-3, 1e-3, _signed_types_no_half, False), + # lapack tests + ('qr', 'square', _small_2d, lambda t, d: [], + 1e-5, 3e-4, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('qr', 'skinny', _new_t((3, 4)), lambda t, d: [], + 1e-5, 3e-4, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('qr', 'fat', _new_t((4, 3)), lambda t, d: [], + 1e-5, 3e-4, 
_float_types_no_half, False, [skipCUDAIfNoMagma]), + ('qr', 'big', _large_2d, lambda t, d: [], + 1e-5, 3e-4, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('geqrf', '', _new_t((20, 20)), lambda t, d: [], + 1e-5, 3e-4, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('svd', 'square', _new_t((10, 10)), lambda t, d: [], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('svd', 'square_col_maj', lambda t, d: _new_t((10, 10))(t, d).t(), lambda t, d: [True], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('svd', 'tall_some', _new_t((20, 5)), lambda t, d: [True], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('svd', 'tall_all', _new_t((20, 5)), lambda t, d: [False], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('svd', 'tall_some_col_maj', lambda t, d: _new_t((5, 20))(t, d).t(), lambda t, d: [True], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('svd', 'tall_all_col_maj', lambda t, d: _new_t((5, 20))(t, d).t(), lambda t, d: [False], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('eig', 'with_eigvec', _new_t((10, 10)), lambda t, d: [True], + 1e-5, 1e-5, _float_types_no_half, False, [skipCUDAIfNoMagma]), + ('abs', '', _small_3d, lambda t, d: []), + ('sign', '', _small_3d, lambda t, d: []), + ('log', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types), + ('log10', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types), + ('log1p', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types_no_half), + ('log2', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types), + ('sigmoid', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('sin', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('sqrt', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('tanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('acos', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('asin', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('atan', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('cos', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('cosh', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types), + ('erf', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('erfc', '', _small_3d, lambda t, d: [], 1e-3, 1e-5, _float_types), + ('exp', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types), + ('expm1', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types), + ('reciprocal', '', _small_3d, lambda t, d: [], 1e-1, 1e-5, _float_types), + ('floor', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _float_types), + ('frac', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _float_types), + ('neg', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _float_types), + ('round', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _float_types), + ('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _float_types), + ('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, _float_types), + ('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-5, _float_types_no_half), + ('digamma', 'op', _small_3d, lambda t, d: [], 1e-5, 1e0, _float_types_no_half), +] + +# Creates and decorates a generic test and adds it to the class. 
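Concretely, for the ('add', 'tensor', ...) entry above, the generator defined next attaches to the target class a method roughly equivalent to the following hand-written approximation (the real body also checks that the converted inputs still match and special-cases torch.half arguments):

    @dtypes(*_types)
    def test_add_tensor(self, device, dtype):
        cpu_tensor = _small_3d(dtype, 'cpu')
        cpu_args = [_small_3d(dtype, 'cpu')]
        device_tensor = cpu_tensor.to(dtype=dtype, device=device)
        device_args = [a.to(device=device) for a in cpu_args]
        precision = 1e-2 if dtype == torch.half else 1e-5   # half vs. default precision
        self.assertEqual(cpu_tensor.add(*cpu_args),
                         device_tensor.add(*device_args), prec=precision)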
+def generate_test_function(cls, + op_str, + subtest_str, + tensor_ctor, + arg_ctor, + half_precision, + float_precision, + dtype_list, + decorators): + def fn(self, device, dtype): + # Generates the CPU inputs + # Note: CPU tensors are never torch.half + cpu_tensor = tensor_ctor(dtype, 'cpu') + cpu_args = arg_ctor(dtype, 'cpu') + + # Converts CPU tensors to device tensors + device_tensor = cpu_tensor.to(dtype=dtype, device=device) + device_args = [arg.to(device=device) if torch.is_tensor(arg) else arg for arg in cpu_args] + + # Converts float device tensors to half when the dtype is half + # Note: CPU half tensors don't support many operations. + if dtype == torch.half: + device_args = [arg.to(dtype=dtype) if + (torch.is_tensor(arg) and arg.dtype == torch.float) else arg + for arg in device_args] + + # Runs the tensor op on CPU and device + cpu_result = getattr(cpu_tensor, op_str)(*cpu_args) + device_result = getattr(device_tensor, op_str)(*device_args) + + # Compares CPU and device inputs and outputs + precision = half_precision if dtype == torch.half else float_precision + + self.assertEqual(cpu_tensor, device_tensor, prec=precision) + self.assertEqual(cpu_args, device_args, prec=precision) + self.assertEqual(cpu_result, device_result, prec=precision) + + test_name = "test_" + op_str + subtest_str + assert not hasattr(cls, test_name), "{0} already in TestDevicePrecision".format(test_name) + + # Constructs decorator list and applies decorators + if decorators is None: + decorators = [dtypes(*dtype_list)] + else: + decorators = decorators + [dtypes(*dtype_list)] + + for dec in decorators: + fn = dec(fn) + + setattr(cls, test_name, fn) + +# Instantiates variants of tensor_op_tests and adds them to the given class. +def generate_tensor_op_tests(cls): + + def caller(cls, + op_str, + subtest_str, + tensor_ctor, + arg_ctor, + half_precision=1e-5, + float_precision=1e-5, + dtype_list=_types, + make_inplace_variant=True, + decorators=None): + if subtest_str: + subtest_str = '_' + subtest_str + + generate_test_function(cls, op_str, subtest_str, tensor_ctor, arg_ctor, + half_precision, float_precision, dtype_list, decorators) + + if make_inplace_variant: + op_str = op_str + '_' + subtest_str = 'inplace' + subtest_str + generate_test_function(cls, op_str, subtest_str, tensor_ctor, arg_ctor, + half_precision, float_precision, dtype_list, decorators) + + for test in tensor_op_tests: + caller(cls, *test) + + +class TestTensorDeviceOps(TestCase): + pass + class TestTorch(TestCase, _TestTorchMixin): pass + +# Generates tests +# Note: test generation must be done at file scope, not within main, or +# pytest will fail. 
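To make the file-scope requirement above concrete: once the calls below have run at import time, the instantiated per-device classes live in globals() and are collected by unittest/pytest like any other TestCase. A rough sketch of what a CUDA build would end up exposing (the exact names come from the device-type test framework):

    # Hypothetical listing on a CUDA build; generated names roughly follow the
    # <GenericClass><DEVICE>.test_<op>_<subtest>_<device>_<dtype> scheme.
    generated = sorted(n for n in globals() if n.startswith('TestTensorDeviceOps'))
    # e.g. includes 'TestTensorDeviceOpsCUDA', with members such as
    # test_add_tensor_cuda_float32, which can be run individually:
    #   python test_torch.py TestTensorDeviceOpsCUDA.test_add_tensor_cuda_float32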
+add_neg_dim_tests() +generate_tensor_op_tests(TestTensorDeviceOps) +instantiate_device_type_tests(TestTorchDeviceType, globals()) +instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') +instantiate_device_type_tests(TestTensorDeviceOps, globals(), except_for='cpu') + if __name__ == '__main__': run_tests() diff --git a/test/test_utils.py b/test/test_utils.py index 619cb626214f9..df032d2f7a5ac 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -570,5 +570,10 @@ def test_load_zip_checkpoint(self): SUM_OF_HUB_EXAMPLE) +class TestHipify(TestCase): + def test_import_hipify(self): + from torch.utils.hipify import hipify_python # noqa + + if __name__ == '__main__': run_tests() diff --git a/third_party/fbgemm b/third_party/fbgemm index c8b854042b364..82d259dade58e 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit c8b854042b364080af3c45da0db27cccbbe9b219 +Subproject commit 82d259dade58e53775a534f88b7b48e760f09a64 diff --git a/third_party/miniz-2.0.8/miniz.h b/third_party/miniz-2.0.8/miniz.h index 751c0f3754aa2..67b533079f255 100755 --- a/third_party/miniz-2.0.8/miniz.h +++ b/third_party/miniz-2.0.8/miniz.h @@ -125,7 +125,7 @@ /* If MINIZ_NO_TIME is specified then the ZIP archive functions will not be able to get the current time, or */ /* get/set file times, and the C run-time funcs that get/set times won't be called. */ /* The current downside is the times written to your archives will be from 1979. */ -/*#define MINIZ_NO_TIME */ +#define MINIZ_NO_TIME /* Define MINIZ_NO_ARCHIVE_APIS to disable all ZIP archive API's. */ /*#define MINIZ_NO_ARCHIVE_APIS */ diff --git a/third_party/onnx b/third_party/onnx index 034921bd574cc..2891e14597459 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 034921bd574cc84906b7996c07873454b7dd4135 +Subproject commit 2891e1459745933f4bba9a8cb3371cf3c9eb1d16 diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 3643ab931eef8..60f85268bad8f 100644 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -4,10 +4,16 @@ import os import subprocess import argparse -from functools import reduce -from itertools import chain +import sys +sys.path.append(os.path.realpath(os.path.join( + __file__, + os.path.pardir, + os.path.pardir, + os.path.pardir, + 'torch', + 'utils'))) -from pyHIPIFY import hipify_python +from hipify import hipify_python parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters') parser.add_argument( @@ -97,7 +103,6 @@ '*/hip/*', # These files are compatible with both cuda and hip "aten/src/ATen/core/*", - "torch/csrc/autograd/engine.cpp", # generated files we shouldn't frob "torch/lib/tmp_install/*", "torch/include/*", @@ -109,33 +114,6 @@ for filename in os.listdir(os.path.join(amd_build_dir, "patches")): subprocess.Popen(["git", "apply", os.path.join(patch_folder, filename)], cwd=proj_dir) - # Make various replacements inside AMD_BUILD/torch directory - ignore_files = [ - # These files use nvrtc, hip doesn't have equivalent - "csrc/autograd/profiler.h", - "csrc/autograd/profiler.cpp", - # These files are compatible with both cuda and hip - "csrc/autograd/engine.cpp" - ] - paths = ("torch", "tools") - for root, _directories, files in chain.from_iterable(os.walk(path) for path in paths): - for filename in files: - if filename.endswith(".cpp") or filename.endswith(".h") or filename.endswith(".hpp"): - source = os.path.join(root, filename) - # Disabled files - 
if reduce(lambda result, exclude: source.endswith(exclude) or result, ignore_files, False): - continue - # Update contents. - with open(source, "r+") as f: - contents = f.read() - contents = contents.replace("USE_CUDA", "USE_ROCM") - contents = contents.replace("CUDA_VERSION", "0") - f.seek(0) - f.write(contents) - f.truncate() - f.flush() - os.fsync(f) - # Check if the compiler is hip-clang. def is_hip_clang(): try: diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py deleted file mode 100644 index f34d5a0f38d3a..0000000000000 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ /dev/null @@ -1,2407 +0,0 @@ -#!/usr/bin/env python3 -import collections - -from pyHIPIFY.constants import * - -""" Mapping of CUDA functions, include files, constants, and types to ROCm/HIP equivalents -This closely follows the implementation in hipify-clang -https://github.com/ROCm-Developer-Tools/HIP/blob/master/hipify-clang/src/CUDA2HipMap.cpp -and its structure. -There are different maps for fundamental names, include files, identifies, sparse, and -PyTorch specific translations. -Each of the entries in these maps translates a CUDA string to a tuple containing the -ROCm/HIP string, a type and API annotation and - optionally - an annotation if it is not -supported in ROCm/HIP yet. -""" - -# List of math functions that should be replaced inside device code only. -MATH_TRANSPILATIONS = collections.OrderedDict([ - ("std::max", ("::max")), - ("std::min", ("::min")), - ("std::ceil", ("::ceil")), - ("std::floor", ("::floor")), - ("std::exp", ("::exp")), - ("std::log", ("::log")), - ("std::pow", ("::pow")), - ("std::fabs", ("::fabs")), - ("std::fmod", ("::fmod")), - ("std::remainder", ("::remainder")), -]) - -CUDA_TYPE_NAME_MAP = collections.OrderedDict([ - ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), - ("cudaError_t", ("hipError_t", CONV_TYPE, API_RUNTIME)), - ("CUDA_ARRAY3D_DESCRIPTOR", ("HIP_ARRAY3D_DESCRIPTOR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUDA_ARRAY_DESCRIPTOR", ("HIP_ARRAY_DESCRIPTOR", CONV_TYPE, API_DRIVER)), - ("CUDA_MEMCPY2D", ("hip_Memcpy2D", CONV_TYPE, API_DRIVER)), - ("CUDA_MEMCPY3D", ("HIP_MEMCPY3D", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUDA_MEMCPY3D_PEER", ("HIP_MEMCPY3D_PEER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUDA_POINTER_ATTRIBUTE_P2P_TOKENS", ("HIP_POINTER_ATTRIBUTE_P2P_TOKENS", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUDA_RESOURCE_DESC", ("HIP_RESOURCE_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUDA_RESOURCE_VIEW_DESC", ("HIP_RESOURCE_VIEW_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUipcEventHandle", ("hipIpcEventHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUipcMemHandle", ("hipIpcMemHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUaddress_mode", ("hipAddress_mode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUarray_cubemap_face", ("hipArray_cubemap_face", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUarray_format", ("hipArray_format", CONV_TYPE, API_DRIVER)), - ("CUcomputemode", ("hipComputemode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUmem_advise", ("hipMemAdvise", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUmem_range_attribute", ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUctx_flags", ("hipCctx_flags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUdevice", ("hipDevice_t", CONV_TYPE, API_DRIVER)), - ("CUdevice_attribute_enum", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), - 
("CUdevice_attribute", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), - ("CUdeviceptr", ("hipDeviceptr_t", CONV_TYPE, API_DRIVER)), - ("CUarray_st", ("hipArray", CONV_TYPE, API_DRIVER)), - ("CUarray", ("hipArray *", CONV_TYPE, API_DRIVER)), - ("CUdevprop_st", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), - ("CUdevprop", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), - ("CUfunction", ("hipFunction_t", CONV_TYPE, API_DRIVER)), - ("CUgraphicsResource", ("hipGraphicsResource_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUmipmappedArray", ("hipMipmappedArray_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUfunction_attribute", ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUfunction_attribute_enum", ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUgraphicsMapResourceFlags", ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUgraphicsMapResourceFlags_enum", ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUgraphicsRegisterFlags", ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUgraphicsRegisterFlags_enum", ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUoccupancy_flags", ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUoccupancy_flags_enum", ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUfunc_cache_enum", ("hipFuncCache", CONV_TYPE, API_DRIVER)), - ("CUfunc_cache", ("hipFuncCache", CONV_TYPE, API_DRIVER)), - ("CUipcMem_flags", ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUipcMem_flags_enum", ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjit_cacheMode", ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjit_cacheMode_enum", ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjit_fallback", ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjit_fallback_enum", ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjit_option", ("hipJitOption", CONV_JIT, API_DRIVER)), - ("CUjit_option_enum", ("hipJitOption", CONV_JIT, API_DRIVER)), - ("CUjit_target", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjit_target_enum", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjitInputType", ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUjitInputType_enum", ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), - ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), - ("CUmemAttach_flags", ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUmemAttach_flags_enum", ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), - ("CUresourcetype_enum", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), - ("CUresourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), - ("CUresourceViewFormat_enum", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), - ("CUsharedconfig", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), - ("CUsharedconfig_enum", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), - ("CUcontext", ("hipCtx_t", CONV_TYPE, API_DRIVER)), - ("CUmodule", ("hipModule_t", CONV_TYPE, API_DRIVER)), - ("CUstream", ("hipStream_t", 
CONV_TYPE, API_DRIVER)), - ("CUstream_st", ("ihipStream_t", CONV_TYPE, API_DRIVER)), - ("CUstreamCallback", ("hipStreamCallback_t", CONV_TYPE, API_DRIVER)), - ("CUsurfObject", ("hipSurfaceObject", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUsurfref", ("hipSurfaceReference_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUtexObject", ("hipTextureObject_t", CONV_TYPE, API_DRIVER)), - ("CUtexref", ("textureReference", CONV_TYPE, API_DRIVER)), - ("CUstream_flags", ("hipStreamFlags", CONV_TYPE, API_DRIVER)), - ("CUstreamWaitValue_flags", ("hipStreamWaitValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUstreamWriteValue_flags", ("hipStreamWriteValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUstreamBatchMemOpType", ("hipStreamBatchMemOpType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUdevice_P2PAttribute", ("hipDeviceP2PAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), - ("CUevent", ("hipEvent_t", CONV_TYPE, API_DRIVER)), - ("CUevent_st", ("ihipEvent_t", CONV_TYPE, API_DRIVER)), - ("CUevent_flags", ("hipEventFlags", CONV_EVENT, API_DRIVER, HIP_UNSUPPORTED)), - ("CUfilter_mode", ("hipTextureFilterMode", CONV_TEX, API_DRIVER)), - ("CUGLDeviceList", ("hipGLDeviceList", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), - ("CUGLmap_flags", ("hipGLMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d9DeviceList", ("hipD3D9DeviceList", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d9map_flags", ("hipD3D9MapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d9register_flags", ("hipD3D9RegisterFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d10DeviceList", ("hipd3d10DeviceList", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d10map_flags", ("hipD3D10MapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d10register_flags", ("hipD3D10RegisterFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED)), - ("CUd3d11DeviceList", ("hipd3d11DeviceList", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED)), - ("CUeglStreamConnection_st", ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED)), - ("CUeglStreamConnection", ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED)), - ("libraryPropertyType_t", ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), - ("libraryPropertyType", ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaStreamCallback_t", ("hipStreamCallback_t", CONV_TYPE, API_RUNTIME)), - ("cudaArray", ("hipArray", CONV_MEM, API_RUNTIME)), - ("cudaArray_t", ("hipArray_t", CONV_MEM, API_RUNTIME)), - ("cudaArray_const_t", ("hipArray_const_t", CONV_MEM, API_RUNTIME)), - ("cudaMipmappedArray_t", ("hipMipmappedArray_t", CONV_MEM, API_RUNTIME)), - ("cudaMipmappedArray_const_t", ("hipMipmappedArray_const_t", CONV_MEM, API_RUNTIME)), - ("cudaArrayDefault", ("hipArrayDefault", CONV_MEM, API_RUNTIME)), - ("cudaArrayLayered", ("hipArrayLayered", CONV_MEM, API_RUNTIME)), - ("cudaArraySurfaceLoadStore", ("hipArraySurfaceLoadStore", CONV_MEM, API_RUNTIME)), - ("cudaArrayCubemap", ("hipArrayCubemap", CONV_MEM, API_RUNTIME)), - ("cudaArrayTextureGather", ("hipArrayTextureGather", CONV_MEM, API_RUNTIME)), - ("cudaMemoryAdvise", ("hipMemAdvise", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaMemRangeAttribute", ("hipMemRangeAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaMemcpyKind", ("hipMemcpyKind", CONV_MEM, API_RUNTIME)), - ("cudaMemoryType", ("hipMemoryType", CONV_MEM, API_RUNTIME)), - ("cudaExtent", ("hipExtent", CONV_MEM, API_RUNTIME)), - ("cudaPitchedPtr", ("hipPitchedPtr", 
CONV_MEM, API_RUNTIME)), - ("cudaPos", ("hipPos", CONV_MEM, API_RUNTIME)), - ("cudaEvent_t", ("hipEvent_t", CONV_TYPE, API_RUNTIME)), - ("cudaStream_t", ("hipStream_t", CONV_TYPE, API_RUNTIME)), - ("cudaPointerAttributes", ("hipPointerAttribute_t", CONV_TYPE, API_RUNTIME)), - ("cudaDeviceAttr", ("hipDeviceAttribute_t", CONV_TYPE, API_RUNTIME)), - ("cudaDeviceProp", ("hipDeviceProp_t", CONV_TYPE, API_RUNTIME)), - ("cudaDeviceP2PAttr", ("hipDeviceP2PAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaComputeMode", ("hipComputeMode", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaFuncCache", ("hipFuncCache_t", CONV_CACHE, API_RUNTIME)), - ("cudaFuncAttributes", ("hipFuncAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaSharedMemConfig", ("hipSharedMemConfig", CONV_TYPE, API_RUNTIME)), - ("cudaLimit", ("hipLimit_t", CONV_TYPE, API_RUNTIME)), - ("cudaOutputMode", ("hipOutputMode", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaTextureReadMode", ("hipTextureReadMode", CONV_TEX, API_RUNTIME)), - ("cudaTextureFilterMode", ("hipTextureFilterMode", CONV_TEX, API_RUNTIME)), - ("cudaChannelFormatKind", ("hipChannelFormatKind", CONV_TEX, API_RUNTIME)), - ("cudaChannelFormatDesc", ("hipChannelFormatDesc", CONV_TEX, API_RUNTIME)), - ("cudaResourceDesc", ("hipResourceDesc", CONV_TEX, API_RUNTIME)), - ("cudaResourceViewDesc", ("hipResourceViewDesc", CONV_TEX, API_RUNTIME)), - ("cudaTextureDesc", ("hipTextureDesc", CONV_TEX, API_RUNTIME)), - ("surfaceReference", ("hipSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaTextureObject_t", ("hipTextureObject_t", CONV_TEX, API_RUNTIME)), - ("cudaResourceType", ("hipResourceType", CONV_TEX, API_RUNTIME)), - ("cudaResourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_RUNTIME)), - ("cudaTextureAddressMode", ("hipTextureAddressMode", CONV_TEX, API_RUNTIME)), - ("cudaSurfaceBoundaryMode", ("hipSurfaceBoundaryMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaSurfaceFormatMode", ("hipSurfaceFormatMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaTextureType1D", ("hipTextureType1D", CONV_TEX, API_RUNTIME)), - ("cudaTextureType2D", ("hipTextureType2D", CONV_TEX, API_RUNTIME)), - ("cudaTextureType3D", ("hipTextureType3D", CONV_TEX, API_RUNTIME)), - ("cudaTextureTypeCubemap", ("hipTextureTypeCubemap", CONV_TEX, API_RUNTIME)), - ("cudaTextureType1DLayered", ("hipTextureType1DLayered", CONV_TEX, API_RUNTIME)), - ("cudaTextureType2DLayered", ("hipTextureType2DLayered", CONV_TEX, API_RUNTIME)), - ("cudaTextureTypeCubemapLayered", ("hipTextureTypeCubemapLayered", CONV_TEX, API_RUNTIME)), - ("cudaIpcEventHandle_t", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), - ("cudaIpcEventHandle_st", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), - ("cudaIpcMemHandle_t", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), - ("cudaIpcMemHandle_st", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), - ("cudaGraphicsCubeFace", ("hipGraphicsCubeFace", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaGraphicsMapFlags", ("hipGraphicsMapFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaGraphicsRegisterFlags", ("hipGraphicsRegisterFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaGLDeviceList", ("hipGLDeviceList", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaGLMapFlags", ("hipGLMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D9DeviceList", ("hipD3D9DeviceList", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D9MapFlags", ("hipD3D9MapFlags", CONV_D3D9, 
API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D9RegisterFlags", ("hipD3D9RegisterFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D10DeviceList", ("hipd3d10DeviceList", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D10MapFlags", ("hipD3D10MapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D10RegisterFlags", ("hipD3D10RegisterFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaD3D11DeviceList", ("hipd3d11DeviceList", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED)), - ("cudaEglStreamConnection", ("hipEglStreamConnection", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED)), - ("cublasHandle_t", ("rocblas_handle", CONV_TYPE, API_BLAS)), - ("cublasOperation_t", ("rocblas_operation", CONV_TYPE, API_BLAS)), - ("cublasStatus_t", ("rocblas_status", CONV_TYPE, API_BLAS)), - ("cublasFillMode_t", ("rocblas_fill", CONV_TYPE, API_BLAS)), - ("cublasDiagType_t", ("rocblas_diagonal", CONV_TYPE, API_BLAS)), - ("cublasSideMode_t", ("rocblas_side", CONV_TYPE, API_BLAS)), - ("cublasPointerMode_t", ("rocblas_pointer_mode", CONV_TYPE, API_BLAS)), - ("cublasAtomicsMode_t", ("rocblas_atomics_mode", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED)), - ("cublasDataType_t", ("rocblas_data_type", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED)), - ("curandStatus", ("hiprandStatus_t", CONV_TYPE, API_RAND)), - ("curandStatus_t", ("hiprandStatus_t", CONV_TYPE, API_RAND)), - ("curandRngType", ("hiprandRngType_t", CONV_TYPE, API_RAND)), - ("curandRngType_t", ("hiprandRngType_t", CONV_TYPE, API_RAND)), - ("curandGenerator_st", ("hiprandGenerator_st", CONV_TYPE, API_RAND)), - ("curandGenerator_t", ("hiprandGenerator_t", CONV_TYPE, API_RAND)), - ("curandDirectionVectorSet", ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDirectionVectorSet_t", ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandOrdering", ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandOrdering_t", ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDistribution_st", ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandHistogramM2V_st", ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDistribution_t", ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandHistogramM2V_t", ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDistributionShift_st", ("hiprandDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDistributionShift_t", ("hiprandDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDistributionM2Shift_st", ("hiprandDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDistributionM2Shift_t", ("hiprandDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandHistogramM2_st", ("hiprandHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandHistogramM2_t", ("hiprandHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandHistogramM2K_st", ("hiprandHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandHistogramM2K_t", ("hiprandHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandDiscreteDistribution_st", ("hiprandDiscreteDistribution_st", CONV_TYPE, API_RAND)), - ("curandDiscreteDistribution_t", ("hiprandDiscreteDistribution_t", CONV_TYPE, API_RAND)), - ("curandMethod", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandMethod_t", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - 
("curandDirectionVectors32_t", ("hiprandDirectionVectors32_t", CONV_TYPE, API_RAND)), - ("curandDirectionVectors64_t", ("hiprandDirectionVectors64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandStateMtgp32_t", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), - ("curandStateMtgp32", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), - ("curandStateScrambledSobol64_t", ("hiprandStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandStateSobol64_t", ("hiprandStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandStateScrambledSobol32_t", ("hiprandStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), - ("curandStateSobol32_t", ("hiprandStateSobol32_t", CONV_TYPE, API_RAND)), - ("curandStateMRG32k3a_t", ("hiprandStateMRG32k3a_t", CONV_TYPE, API_RAND)), - ("curandStatePhilox4_32_10_t", ("hiprandStatePhilox4_32_10_t", CONV_TYPE, API_RAND)), - ("curandStateXORWOW_t", ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND)), - ("curandState_t", ("hiprandState_t", CONV_TYPE, API_RAND)), - ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), -]) - -CUDA_INCLUDE_MAP = collections.OrderedDict([ - # since pytorch uses "\b{pattern}\b" as the actual re pattern, - # patterns listed here have to begin and end with alnum chars - ("include Tensor self: grad * (self <= max).to(grad.dtype()) -- name: clone(Tensor self) -> Tensor +- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor self: grad - name: coalesce(Tensor self) -> Tensor @@ -953,6 +953,10 @@ - name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor self: binary_cross_entropy_backward(grad, self, target, weight, reduction) +- name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction) + grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction) + - name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? 
pos_weight=None, int reduction=Mean) -> Tensor self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction) target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 391aafd106fc0..b8991c746139b 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -902,6 +902,42 @@ Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_outp return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad); } +Tensor binary_cross_entropy_double_backward(const Tensor & grad_output, const Tensor & grad, const Tensor & input, const Tensor & target, const Tensor& weight, int64_t reduction) { + auto eps = 1e-12; + auto inp_pl_eps = input + eps; + auto one_m_inp_pl_eps = 1 - input + eps; + // gradient wrt input + auto gI = (input * input - 2 * input * target + target) / (inp_pl_eps.pow(2) * one_m_inp_pl_eps.pow(2)); + gI *= (grad * grad_output); + + if (weight.defined()) { + gI *= weight; + } + if (reduction == Reduction::Mean) { + return gI / input.numel(); + } else if (reduction == Reduction::Sum) { + return gI.sum(); + } + return gI; +} + +Tensor binary_cross_entropy_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, const Tensor& weight, int64_t reduction) { + auto eps = 1e-12; + // gradient wrt grad_output + auto ggO = (input - target) / ((input + eps) * (1 - input + eps)); + ggO *= grad; + + if (weight.defined()) { + ggO *= weight; + } + if (reduction == Reduction::Mean) { + return ggO / input.numel(); + } else if (reduction == Reduction::Sum) { + return ggO.sum(); + } + return ggO; +} + Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) { auto output = l1_loss_backward(grad, input, target, Reduction::None); if (reduction == Reduction::Mean) { diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index a41ff6cb68956..d3ee543b27a6c 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -604,15 +604,6 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_new_zeros(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - auto& self_ = reinterpret_cast(self)->cdata; - OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::new_zeros(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); - END_HANDLE_TH_ERRORS -} - static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS @@ -789,7 +780,6 @@ PyMethodDef variable_methods[] = { {"new", (PyCFunction)(void(*)(void))THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL}, {"new_ones", (PyCFunction)(void(*)(void))THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL}, {"new_tensor", (PyCFunction)(void(*)(void))THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, - {"new_zeros", (PyCFunction)(void(*)(void))THPVariable_new_zeros, METH_VARARGS | METH_KEYWORDS, NULL}, {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS, NULL}, {"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL}, {"record_stream", 
(PyCFunction)THPVariable_record_stream, METH_O, NULL}, diff --git a/tools/build_variables.py b/tools/build_variables.py index 88ffafe3a8e12..c5d47b34ef820 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -52,13 +52,21 @@ "torch/csrc/distributed/autograd/utils.cpp", "torch/csrc/distributed/autograd/context/dist_autograd_container.cpp", "torch/csrc/distributed/autograd/context/dist_autograd_context.cpp", + "torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp", "torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp", "torch/csrc/distributed/rpc/future_message.cpp", "torch/csrc/distributed/rpc/message.cpp", + "torch/csrc/distributed/rpc/python_remote_call.cpp", + "torch/csrc/distributed/rpc/python_udf_call.cpp", + "torch/csrc/distributed/rpc/python_udf_resp.cpp", + "torch/csrc/distributed/rpc/request_callback.cpp", + "torch/csrc/distributed/rpc/rpc_with_autograd.cpp", + "torch/csrc/distributed/rpc/rref_proto.cpp", "torch/csrc/distributed/rpc/script_call.cpp", "torch/csrc/distributed/rpc/script_remote_call.cpp", - "torch/csrc/distributed/rpc/script_rref_proto.cpp", - "torch/csrc/distributed/rpc/script_ret.cpp", + "torch/csrc/distributed/rpc/script_resp.cpp", + "torch/csrc/distributed/rpc/types.cpp", + "torch/csrc/distributed/rpc/utils.cpp", "torch/csrc/Exceptions.cpp", "torch/csrc/jit/autodiff.cpp", "torch/csrc/jit/attributes.cpp", @@ -146,8 +154,6 @@ "torch/csrc/jit/script/builtin_functions.cpp", "torch/csrc/jit/script/module.cpp", "torch/csrc/jit/tracer.cpp", - "torch/csrc/utils/tensor_flatten.cpp", - "torch/csrc/utils/variadic.cpp", "torch/csrc/jit/fuser/kernel_cache.cpp", "torch/csrc/jit/fuser/compiler.cpp", "torch/csrc/jit/fuser/executor.cpp", @@ -162,6 +168,9 @@ "torch/csrc/jit/mobile/module.cpp", "torch/csrc/jit/mobile/register_mobile_ops.cpp", "torch/csrc/jit/mobile/interpreter.cpp", + "torch/csrc/utils/byte_order.cpp", + "torch/csrc/utils/tensor_flatten.cpp", + "torch/csrc/utils/variadic.cpp", ] libtorch_cuda_sources = [ @@ -256,21 +265,19 @@ def add_torch_libs(): "torch/csrc/autograd/python_legacy_variable.cpp", "torch/csrc/autograd/python_variable.cpp", "torch/csrc/autograd/python_variable_indexing.cpp", - "torch/csrc/byte_order.cpp", "torch/csrc/distributed/autograd/init.cpp", "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/reducer.cpp", - "torch/csrc/distributed/autograd/init.cpp", - "torch/csrc/distributed/rpc/functions.cpp", "torch/csrc/distributed/rpc/init.cpp", "torch/csrc/distributed/rpc/process_group_agent.cpp", + "torch/csrc/distributed/rpc/py_rref.cpp", "torch/csrc/distributed/rpc/python_functions.cpp", "torch/csrc/distributed/rpc/python_rpc_handler.cpp", + "torch/csrc/distributed/rpc/request_callback_impl.cpp", "torch/csrc/distributed/rpc/rpc_agent.cpp", "torch/csrc/distributed/rpc/rref.cpp", "torch/csrc/distributed/rpc/rref_context.cpp", - "torch/csrc/distributed/rpc/types.cpp", "torch/csrc/jit/init.cpp", "torch/csrc/jit/passes/inline_fork_wait.cpp", "torch/csrc/jit/passes/onnx.cpp", diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index f3bd37cf1b069..563c8ec8379ee 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -160,9 +160,15 @@ def from_ivalue(arg, value): auto result_ = (${first}).${name}(${args_with_tensor_options}); """) +# Adding `AutoNonVariableTypeMode` guard for `USE_STATIC_DISPATCH` case is kinda +# hack to address issue #26764. 
TODO: remove this hack after Variable/Tensor +# unification (#23032) is done. CONSTRUCTOR = CodeTemplate("""\ [](Stack & stack) { ${lvalues} +#ifdef USE_STATIC_DISPATCH + at::AutoNonVariableTypeMode non_var_type_mode(true); +#endif ${call} drop(stack, ${num_inputs}); pack(stack, std::move(result_)); diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index f7fd5a939d453..794f9202f141c 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -489,8 +489,6 @@ def gen_pyi(declarations_path, out): 'def stride(self, _int) -> _int: ...'], 'new_ones': ['def new_ones(self, size: {}, {}) -> Tensor: ...'. format(type_to_python('IntArrayRef'), FACTORY_PARAMS)], - 'new_zeros': ['def new_zeros(self, size: {}, {}) -> Tensor: ...'. - format(type_to_python('IntArrayRef'), FACTORY_PARAMS)], 'new_tensor': ["def new_tensor(self, data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)], # clamp has no default values in the Declarations 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf," diff --git a/tools/setup_helpers/cuda.py b/tools/setup_helpers/cuda.py index 25d5e3e46b339..06f549d3d00b2 100644 --- a/tools/setup_helpers/cuda.py +++ b/tools/setup_helpers/cuda.py @@ -1,11 +1,9 @@ import os import glob -import re import ctypes.util -from subprocess import Popen, PIPE from . import which -from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_env_flag, check_negative_env_flag +from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_negative_env_flag LINUX_HOME = '/usr/local/cuda' WINDOWS_HOME = glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') @@ -19,44 +17,9 @@ def find_nvcc(): return None -def find_cuda_version(cuda_home): - if cuda_home is None: - return None - if IS_WINDOWS: - candidate_names = [os.path.basename(cuda_home)] - else: - # get CUDA lib folder - cuda_lib_dirs = ['lib64', 'lib'] - for lib_dir in cuda_lib_dirs: - cuda_lib_path = os.path.join(cuda_home, lib_dir) - if os.path.exists(cuda_lib_path): - break - # get a list of candidates for the version number - # which are files containing cudart - candidate_names = list(glob.glob(os.path.join(cuda_lib_path, '*cudart*'))) - candidate_names = [os.path.basename(c) for c in candidate_names] - # if we didn't find any cudart, ask nvcc - if len(candidate_names) == 0: - proc = Popen(['nvcc', '--version'], stdout=PIPE, stderr=PIPE) - out, err = proc.communicate() - candidate_names = [out.decode().rsplit('V')[-1]] - - # suppose version is MAJOR.MINOR.PATCH, all numbers - version_regex = re.compile(r'[0-9]+\.[0-9]+\.[0-9]+') - candidates = [c.group() for c in map(version_regex.search, candidate_names) if c] - if len(candidates) > 0: - # normally only one will be retrieved, take the first result - return candidates[0] - # if no candidates were found, try MAJOR.MINOR - version_regex = re.compile(r'[0-9]+\.[0-9]+') - candidates = [c.group() for c in map(version_regex.search, candidate_names) if c] - if len(candidates) > 0: - return candidates[0] - -if check_negative_env_flag('USE_CUDA') or check_env_flag('USE_ROCM'): +if check_negative_env_flag('USE_CUDA'): USE_CUDA = False CUDA_HOME = None - CUDA_VERSION = None else: if IS_LINUX or IS_DARWIN: CUDA_HOME = os.getenv('CUDA_HOME', LINUX_HOME) @@ -78,5 +41,4 @@ def find_cuda_version(cuda_home): CUDA_HOME = os.path.dirname(cuda_path) else: CUDA_HOME = None - CUDA_VERSION = find_cuda_version(CUDA_HOME) USE_CUDA = CUDA_HOME is not None diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 59002e7203df6..7372180d24d80 100644 --- a/torch/CMakeLists.txt +++ 
b/torch/CMakeLists.txt @@ -68,7 +68,6 @@ set(TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/autograd/python_legacy_variable.cpp ${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp ${TORCH_SRC_DIR}/csrc/autograd/python_variable_indexing.cpp - ${TORCH_SRC_DIR}/csrc/byte_order.cpp ${TORCH_SRC_DIR}/csrc/jit/init.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp @@ -225,20 +224,18 @@ if (USE_DISTRIBUTED) if (NOT MSVC) list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp - ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp - ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_context.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/comm.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/init.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/reducer.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/init.cpp - ${TORCH_SRC_DIR}/csrc/distributed/rpc/functions.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/process_group_agent.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/py_rref.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref.cpp - ${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp ) list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) @@ -250,12 +247,8 @@ endif() if (USE_NCCL) list(APPEND TORCH_PYTHON_SRCS - ${TORCH_SRC_DIR}/csrc/cuda/nccl.cpp ${TORCH_SRC_DIR}/csrc/cuda/python_nccl.cpp) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_NCCL) - list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) - if (USE_SYSTEM_NCCL) - endif() endif() # In the most recent CMake versions, a new 'TRANSFORM' subcommand of 'list' allows much of the boilerplate of defining the lists diff --git a/torch/__init__.py b/torch/__init__.py index 88130b4cc3573..a36171d6d2e04 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -303,7 +303,7 @@ def manager_path(): import torch.autograd from torch.autograd import no_grad, enable_grad, set_grad_enabled # noqa: F401 import torch.nn -import torch.nn._intrinsic +import torch.nn.intrinsic import torch.nn.quantized import torch.optim import torch.multiprocessing diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 8a8f6fed5f403..90359cef759e8 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -630,7 +630,7 @@ def __getitem__(self, types): # Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. def _qualified_name(obj): # short-circuit in cases where the object already has a known qualified name - if isinstance(obj, torch._C.Function): + if isinstance(obj, torch.jit.ScriptFunction): return obj.qualified_name name = obj.__name__ diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 9d50a4112a049..e58eb37725b3c 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2253,10 +2253,6 @@ def callable(a, b) -> number dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``. -Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be -between ``0`` and ``self.size(dim) - 1`` inclusive, and all values in a row along -the specified dimension :attr:`dim` must be unique. - .. 
include:: cuda_deterministic.rst Args: diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index c1e99bd22d6f1..6cf93b949c8c0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -899,12 +899,12 @@ def merge_dicts(*dicts): Computes the p-norm distance between each pair of the two collections of row vectors. -If x1 has shape :math:`P \times M` and x2 has shape :math:`R \times M` then the +If x1 has shape :math:`P \times M` and x2 has shape :math:`R \times M` then the output will have shape :math:`P \times R`. -This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)` -if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to -`scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest +This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)` +if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to +`scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`. Args: @@ -1055,11 +1055,6 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `c` -.. note:: - - The :attr:`out` keyword only supports 2D matrix inputs, that is, - `b, u` must be 2D matrices. - Args: input (Tensor): input matrix :math:`b` of size :math:`(*, m, k)`, where :math:`*` is zero or more batch dimensions @@ -4569,7 +4564,7 @@ def merge_dicts(*dicts): r""" set_num_threads(int) -Sets the number of threads used for parallelizing CPU operations. +Sets the number of threads used for intraop parallelism on CPU. WARNING: To ensure that the correct number of threads is used, set_num_threads must be called before running eager, JIT or autograd code. @@ -5486,11 +5481,6 @@ def merge_dicts(*dicts): batches of 2D matrices. If the inputs are batches, then returns batched outputs `X` -.. note:: - - The :attr:`out` keyword only supports 2D matrix inputs, that is, - `b, A` must be 2D matrices. 
- Args: input (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where :math:`*` is zero of more batch dimensions (:math:`b`) diff --git a/torch/_utils.py b/torch/_utils.py index d3748842aefbe..d8c40b140075b 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -140,11 +140,20 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac tensor._backward_hooks = backward_hooks return tensor + +def _rebuild_sparse_tensor(layout, data): + if layout == torch.sparse_coo: + indices, values, size = data + return torch.sparse_coo_tensor(indices, values, size) + raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout)) + + def _rebuild_xla_tensor(data, dtype, device, requires_grad): tensor = torch.from_numpy(data).to(dtype=dtype, device=device) tensor.requires_grad = requires_grad return tensor + def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks): qscheme = quantizer_params[0] if qscheme == torch.per_tensor_affine: diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 0f4ae714380ec..ea1a797c12c8b 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -613,6 +613,9 @@ def add(self, other, group_by_input_shapes=False): self.count += 1 return self + def __iadd__(self, other): + return self.add(other) + def __repr__(self): return ( ' #include #include -#include -#include #include +#include +#include +#include #include #include // This requires defined Storage and Tensor types -#include +#include #ifdef _THP_CORE #include diff --git a/torch/csrc/api/include/torch/nn.h b/torch/csrc/api/include/torch/nn.h index 10474512a2580..b93220b5d62a0 100644 --- a/torch/csrc/api/include/torch/nn.h +++ b/torch/csrc/api/include/torch/nn.h @@ -7,3 +7,4 @@ #include #include #include +#include diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h index 062db965fb01e..39b3eed642736 100644 --- a/torch/csrc/api/include/torch/nn/functional.h +++ b/torch/csrc/api/include/torch/nn/functional.h @@ -3,3 +3,4 @@ #include #include #include +#include diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index d9688b5b3a88d..4fbebed2bef49 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -20,6 +20,26 @@ inline Tensor hardshrink(const Tensor& input, return torch::hardshrink(input, options.lambda()); } +inline Tensor hardtanh(Tensor& input, const HardtanhOptions& options) { + if (options.inplace()) { + return torch::hardtanh_(input, options.min_val(), options.max_val()); + } else { + return torch::hardtanh(input, options.min_val(), options.max_val()); + } +} + +inline Tensor leaky_relu(Tensor& input, const LeakyReLUOptions& options) { + if (options.inplace()) { + return torch::leaky_relu_(input, options.negative_slope()); + } else { + return torch::leaky_relu(input, options.negative_slope()); + } +} + +inline Tensor logsigmoid(const Tensor& input) { + return torch::log_sigmoid(input); +} + } // namespace functional } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/functional/embedding.h b/torch/csrc/api/include/torch/nn/functional/embedding.h new file mode 100644 index 0000000000000..7003239f1ac4c --- /dev/null +++ b/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -0,0 +1,12 @@ +#pragma once + +namespace torch { +namespace nn { 
+namespace functional { + +inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) { + return torch::one_hot(tensor, num_classes); +} +} // namespace functional +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h index bc7c1c3b28a5b..748980cbb46da 100644 --- a/torch/csrc/api/include/torch/nn/modules/activation.h +++ b/torch/csrc/api/include/torch/nn/modules/activation.h @@ -55,5 +55,70 @@ class TORCH_API HardshrinkImpl : public torch::nn::Cloneable { TORCH_MODULE(Hardshrink); +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the HardTanh function element-wise. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.Hardtanh to learn +/// about the exact behavior of this module. +class TORCH_API HardtanhImpl : public torch::nn::Cloneable { + public: + HardtanhImpl() : HardtanhImpl(HardtanhOptions()) {} + explicit HardtanhImpl(const HardtanhOptions& options_); + + Tensor forward(Tensor& input); + + void reset() override; + + /// Pretty prints the `Hardtanh` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + HardtanhOptions options; +}; + +TORCH_MODULE(Hardtanh); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LeakyReLU function element-wise. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.LeakyReLU to learn +/// about the exact behavior of this module. +class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable { + public: + LeakyReLUImpl() : LeakyReLUImpl(LeakyReLUOptions()) {} + explicit LeakyReLUImpl(const LeakyReLUOptions& options_); + + Tensor forward(Tensor& input); + + void reset() override; + + /// Pretty prints the `LeakyReLU` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + LeakyReLUOptions options; +}; + +TORCH_MODULE(LeakyReLU); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies the LogSigmoid function element-wise. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.LogSigmoid to learn +/// about the exact behavior of this module. +class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable { + public: + LogSigmoidImpl() {} + + Tensor forward(const Tensor& input); + + void reset() override; + + /// Pretty prints the `LogSigmoid` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; +}; + +TORCH_MODULE(LogSigmoid); + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h index 8ec2eafe9cf93..592f3c61e1c43 100644 --- a/torch/csrc/api/include/torch/nn/options/activation.h +++ b/torch/csrc/api/include/torch/nn/options/activation.h @@ -27,5 +27,34 @@ struct TORCH_API HardshrinkOptions { TORCH_ARG(double, lambda); }; +// ============================================================================ + +/// Options for Hardtanh functional and module. +struct HardtanhOptions { + HardtanhOptions() {} + + /// minimum value of the linear region range. Default: -1 + TORCH_ARG(double, min_val) = -1.0; + + /// maximum value of the linear region range. Default: 1 + TORCH_ARG(double, max_val) = 1.0; + + /// can optionally do the operation in-place. 
Default: False + TORCH_ARG(bool, inplace) = false; +}; + +// ============================================================================ + +/// Options for LeakyReLU functional and module. +struct LeakyReLUOptions { + LeakyReLUOptions() {} + + /// Controls the angle of the negative slope. Default: 1e-2 + TORCH_ARG(double, negative_slope) = 1e-2; + + /// can optionally do the operation in-place. Default: False + TORCH_ARG(bool, inplace) = false; +}; + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/utils.h b/torch/csrc/api/include/torch/nn/utils.h new file mode 100644 index 0000000000000..550949ef08666 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/utils.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/torch/csrc/api/include/torch/nn/utils/clip_grad.h new file mode 100644 index 0000000000000..fea0fe0a9c9f4 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -0,0 +1,62 @@ +#pragma once + +#include + +namespace torch { +namespace nn { +namespace utils { + +// Clips gradient norm of a vector of Tensors. +// See +// https://pytorch.org/docs/stable/nn.html?highlight=clip_grad_norm#torch.nn.utils.clip_grad_norm_ +// for more details about this module. +inline float clip_grad_norm_( + std::vector& parameters, + float max_norm, + float norm_type = 2.0) { + std::vector params_with_grad; + + for (const auto& param : parameters) { + auto& grad = param.grad(); + if (grad.defined()) { + params_with_grad.push_back(param); + } + } + float total_norm = 0.0; + if (norm_type == std::numeric_limits::infinity()) { + for (const auto& param : params_with_grad) { + auto param_max = param.grad().data().abs().max().item().toFloat(); + if (param_max > total_norm) { + total_norm = param_max; + } + } + } else { + for (const auto& param : params_with_grad) { + auto param_norm = param.grad().data().norm(norm_type); + total_norm += std::pow(param_norm.item().toFloat(), norm_type); + } + total_norm = std::pow(total_norm, 1.0 / norm_type); + } + + auto clip_coef = max_norm / (total_norm + 1e-6); + if (clip_coef < 1) { + for (auto& param : params_with_grad) { + param.grad().data().mul_(clip_coef); + } + } + return total_norm; +} + +// A wrapper around clip_grad_norm_ that allows us to call the function with a +// single Tensor. 
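A minimal sketch of how the new C++ activation modules and the clip_grad_norm_ helper declared above might be used together. The layer sizes, batch shape, and clipping threshold are illustrative only; the module, option, and namespace names are the ones introduced in this patch.

#include <iostream>
#include <torch/torch.h>

int main() {
  // Modules added in this patch; the options use the TORCH_ARG-generated setters.
  torch::nn::Linear linear(4, 2);
  torch::nn::Hardtanh hardtanh(torch::nn::HardtanhOptions().min_val(-2.0).max_val(2.0));
  torch::nn::LeakyReLU leaky(torch::nn::LeakyReLUOptions().negative_slope(0.1));

  auto x = torch::randn({8, 4});
  auto h1 = linear->forward(x);
  auto h2 = hardtanh->forward(h1);   // forward takes a non-const Tensor&
  auto h3 = leaky->forward(h2);
  auto loss = h3.sum();
  loss.backward();

  // clip_grad_norm_ mirrors torch.nn.utils.clip_grad_norm_: it returns the
  // pre-clipping total norm and rescales the gradients in place if needed.
  auto params = linear->parameters();
  float total_norm = torch::nn::utils::clip_grad_norm_(params, /*max_norm=*/1.0f);
  std::cout << "grad norm before clipping: " << total_norm << std::endl;
}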
+inline float clip_grad_norm_( + Tensor& parameters, + float max_norm, + float norm_type = 2.0) { + std::vector params = {parameters}; + return clip_grad_norm_(params, max_norm, norm_type); +} + +} // namespace utils +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index e06bfad248ede..ddbb493896d95 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -38,5 +38,63 @@ void HardshrinkImpl::pretty_print(std::ostream& stream) const { << "torch::nn::Hardshrink(" << options.lambda() << ")"; } +// ============================================================================ + +HardtanhImpl::HardtanhImpl(const HardtanhOptions& options_) + : options(options_) { + reset(); +} + +Tensor HardtanhImpl::forward(Tensor& input) { + return F::hardtanh(input, options); +} + +void HardtanhImpl::reset() { + TORCH_CHECK(options.max_val() > options.min_val(), + "max_val must be greater than min_val"); +} + +void HardtanhImpl::pretty_print(std::ostream& stream) const { + stream << std::boolalpha + << "torch::nn::Hardtanh(min_val=" << options.min_val() + << ", max_val=" << options.max_val(); + if (options.inplace()) { + stream << std::boolalpha << ", inplace=" << options.inplace(); + } + stream << ")"; +} + +// ============================================================================ + +LeakyReLUImpl::LeakyReLUImpl(const LeakyReLUOptions& options_) + : options(options_) {} + +Tensor LeakyReLUImpl::forward(Tensor& input) { + return F::leaky_relu(input, options); +} + +void LeakyReLUImpl::reset() {} + +void LeakyReLUImpl::pretty_print(std::ostream& stream) const { + stream << std::boolalpha + << "torch::nn::LeakyReLU(negative_slope=" << options.negative_slope(); + if (options.inplace()) { + stream << std::boolalpha << ", inplace=" << options.inplace(); + } + stream << ")"; +} + +// ============================================================================ + +Tensor LogSigmoidImpl::forward(const Tensor& input) { + return F::logsigmoid(input); +} + +void LogSigmoidImpl::reset() {} + +void LogSigmoidImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::LogSigmoid()"; +} + } // namespace nn } // namespace torch diff --git a/torch/csrc/autograd/profiler_cuda.cpp b/torch/csrc/autograd/profiler_cuda.cpp index ae3b1bbc86a56..b03d574431877 100644 --- a/torch/csrc/autograd/profiler_cuda.cpp +++ b/torch/csrc/autograd/profiler_cuda.cpp @@ -1,8 +1,6 @@ #include #include -#ifndef __HIP_PLATFORM_HCC__ #include -#endif #include @@ -35,19 +33,13 @@ struct CUDAMethods : public CUDAStubs { return ms*1000.0; } void nvtxMarkA(const char* name) override { -#ifndef __HIP_PLATFORM_HCC__ ::nvtxMark(name); -#endif } void nvtxRangePushA(const char* name) override { -#ifndef __HIP_PLATFORM_HCC__ ::nvtxRangePushA(name); -#endif } void nvtxRangePop() override { -#ifndef __HIP_PLATFORM_HCC__ ::nvtxRangePop(); -#endif } void onEachDevice(std::function op) override { at::cuda::OptionalCUDAGuard device_guard; diff --git a/torch/csrc/byte_order.h b/torch/csrc/byte_order.h deleted file mode 100644 index b8b0f7a22c9a7..0000000000000 --- a/torch/csrc/byte_order.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef THP_BYTE_ORDER_H -#define THP_BYTE_ORDER_H - -#include -#include -#include -#include - -enum THPByteOrder { - THP_LITTLE_ENDIAN = 0, - THP_BIG_ENDIAN = 1 -}; - -THPByteOrder THP_nativeByteOrder(); - -void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder 
order, size_t len); -void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len); -void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len); -void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, size_t len); -void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len); -void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len); -void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len); -void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len); - -void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len); -void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len); -void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len); -void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, size_t len); -void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, size_t len); - -#endif diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 93c0aa4344dd0..f421d1c6fe73d 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -70,7 +70,7 @@ using device_list = std::vector; static std::unordered_map> _communicators; -ArrayRef _get_communicators(TensorList inputs) { +ArrayRef get_communicators(TensorList inputs) { static auto get_device = [](const at::Tensor& t) -> int { return t.get_device(); }; @@ -81,7 +81,7 @@ ArrayRef _get_communicators(TensorList inputs) { return it->second.ref(); } -ncclDataType_t _get_data_type(const Tensor& t) { +ncclDataType_t get_data_type(const Tensor& t) { if (t.type().backend() != Backend::CUDA) { throw std::runtime_error("Unconvertible NCCL type"); } @@ -105,7 +105,7 @@ ncclDataType_t _get_data_type(const Tensor& t) { } } -void _check_inputs( +void check_inputs( TensorList inputs, TensorList outputs, int input_multiplier, @@ -232,12 +232,12 @@ void broadcast( const comm_list& user_comms) { #ifdef USE_NCCL using namespace torch::cuda::nccl::detail; - _check_inputs(tensors, tensors, 1, 1); - ncclDataType_t data_type = _get_data_type(tensors[0]); + check_inputs(tensors, tensors, 1, 1); + ncclDataType_t data_type = get_data_type(tensors[0]); int64_t numel = tensors[0].numel(); AutoNcclGroup nccl_group_guard; - const auto comms = user_comms.empty() ? _get_communicators(tensors) + const auto comms = user_comms.empty() ? get_communicators(tensors) : ArrayRef(user_comms); at::cuda::OptionalCUDAGuard device_guard; @@ -276,14 +276,14 @@ void reduce( TORCH_CHECK( root >= 0 && static_cast(root) < inputs.size(), "invalid root"); - _check_inputs(inputs, outputs, 1, 1); + check_inputs(inputs, outputs, 1, 1); const auto len = inputs.size(); - ncclDataType_t data_type = _get_data_type(inputs[0]); + ncclDataType_t data_type = get_data_type(inputs[0]); const auto count = inputs[0].numel(); AutoNcclGroup nccl_group_guard; - auto comms_ref = user_comms.empty() ? _get_communicators(inputs) + auto comms_ref = user_comms.empty() ? 
get_communicators(inputs) : ArrayRef(user_comms); at::cuda::OptionalCUDAGuard device_guard; diff --git a/torch/csrc/cuda/nccl.h b/torch/csrc/cuda/nccl.h index 66aa4aecec6f5..1e92850043dcc 100644 --- a/torch/csrc/cuda/nccl.h +++ b/torch/csrc/cuda/nccl.h @@ -42,31 +42,31 @@ struct AutoNcclGroup { } }; -at::ArrayRef _get_communicators(at::TensorList inputs); -void _check_inputs( +TORCH_API at::ArrayRef get_communicators(at::TensorList inputs); +TORCH_API void check_inputs( at::TensorList inputs, at::TensorList outputs, int input_multiplier, int output_multiplier); -ncclDataType_t _get_data_type(const at::Tensor& t); +TORCH_API ncclDataType_t get_data_type(const at::Tensor& t); } // namespace detail using comm_list = std::vector; using stream_list = std::vector>; -std::uint64_t version(); +TORCH_API std::uint64_t version(); bool is_available(at::TensorList tensors); -void broadcast( +TORCH_API void broadcast( at::TensorList tensors, const stream_list& streams = {}, const comm_list& user_comms = {}); size_t get_max_count(); -void reduce( +TORCH_API void reduce( const std::vector& inputs, std::vector& outputs, int32_t root = 0, @@ -74,7 +74,7 @@ void reduce( const stream_list& streams = {}, const comm_list& user_comms = {}); -void reduce( +TORCH_API void reduce( std::vector& inputs, int32_t root = 0, int32_t op = ncclSum, diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index 092f3c0afe157..3601134322f39 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -185,14 +185,14 @@ PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) { auto user_comms = unpack_comms(_comms, inputs.size()); with_no_gil([&] { - _check_inputs(inputs, outputs, 1, 1); + check_inputs(inputs, outputs, 1, 1); size_t len = inputs.size(); - ncclDataType_t data_type = _get_data_type(inputs[0]); + ncclDataType_t data_type = get_data_type(inputs[0]); int64_t count = inputs[0].numel(); AutoNcclGroup nccl_group_guard; - auto comms = user_comms.empty() ? _get_communicators(inputs) + auto comms = user_comms.empty() ? get_communicators(inputs) : ArrayRef(user_comms); at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < len; i++) { @@ -265,13 +265,13 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) { with_no_gil([&] { size_t len = inputs.size(); - _check_inputs(inputs, outputs, len, 1); + check_inputs(inputs, outputs, len, 1); - ncclDataType_t data_type = _get_data_type(inputs[0]); + ncclDataType_t data_type = get_data_type(inputs[0]); int64_t count = inputs[0].numel(); AutoNcclGroup nccl_group_guard; - auto comms = user_comms.empty() ? _get_communicators(inputs) + auto comms = user_comms.empty() ? get_communicators(inputs) : ArrayRef(user_comms); at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < len; i++) { @@ -327,13 +327,13 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) { with_no_gil([&] { size_t len = inputs.size(); - _check_inputs(inputs, outputs, 1, len); + check_inputs(inputs, outputs, 1, len); - ncclDataType_t data_type = _get_data_type(inputs[0]); + ncclDataType_t data_type = get_data_type(inputs[0]); int64_t count = inputs[0].numel() / len; AutoNcclGroup nccl_group_guard; - auto comms = user_comms.empty() ? _get_communicators(inputs) + auto comms = user_comms.empty() ? 
get_communicators(inputs) : ArrayRef(user_comms); at::cuda::OptionalCUDAGuard device_guard; for (size_t i = 0; i < len; i++) { diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp b/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp index be7759c963ddd..3302e1d7b81f9 100644 --- a/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp +++ b/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp @@ -5,56 +5,98 @@ namespace torch { namespace distributed { namespace autograd { -constexpr int kContextIdBits = 48; -constexpr int64_t kContextIdMask = (1LL << kContextIdBits) - 1; +constexpr int kAutoIncrementBits = 48; +constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1; constexpr int kMaxWorkerId = 65535; -constexpr int64_t kMaxContextId = kContextIdMask; -thread_local int64_t DistAutogradContainer::current_context_id_ = -1; +// Each thread has a single autograd_context_id valid at any point in time. +static thread_local int64_t current_context_id_ = -1; + +// Lock to ensure DistAutogradContainer is initialized only once. +static std::mutex dist_container_init_lock_; DistAutogradContainer::DistAutogradContainer() - : next_context_id_(0), worker_id_(0), initialized_(false) {} + : next_context_id_(0), + worker_id_(0), + initialized_(false), + next_autograd_message_id_(0), + max_id_(0) {} DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) { + std::lock_guard guard(dist_container_init_lock_); + TORCH_CHECK( worker_id >= 0 && worker_id <= kMaxWorkerId, "worker_id needs to be in the range [0, 65535]") - auto& container = getInstance(); + auto& container = getInstanceInternal(); + TORCH_CHECK( + !container.initialized_, + "Container is already initialized! Cannot initialize it twice!"); + container.worker_id_ = worker_id; container.next_context_id_ = static_cast(worker_id) - << kContextIdBits; + << kAutoIncrementBits; + container.next_autograd_message_id_ = static_cast(worker_id) + << kAutoIncrementBits; + container.max_id_ = + (kAutoIncrementMask | + (static_cast(worker_id) << kAutoIncrementBits)); container.initialized_ = true; return container; } DistAutogradContainer& DistAutogradContainer::getInstance() { + auto& instance = getInstanceInternal(); + TORCH_CHECK( + instance.initialized_, + "Need to initialize distributed autograd using " + "torch.distributed.autograd.init()"); + return instance; +} + +DistAutogradContainer& DistAutogradContainer::getInstanceInternal() { static DistAutogradContainer container; return container; } -const DistAutogradContext& DistAutogradContainer::newContext() { - if (!initialized_) { - throw std::runtime_error( - "Need to initialize distributed autograd using " - "torch.distributed.autograd.init()"); +int64_t DistAutogradContainer::newAutogradMessageId() { + // Check for overflow into workerId_ section. 
+ TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_); + return next_autograd_message_id_++; +} + +DistAutogradContext& DistAutogradContainer::getOrCreateContext( + int64_t context_id) { + std::lock_guard guard(autograd_context_lock_); + auto it = autograd_context_.find(context_id); + if (it != autograd_context_.end()) { + return it->second; } + auto& context = autograd_context_ + .emplace( + std::piecewise_construct, + std::forward_as_tuple(context_id), + std::forward_as_tuple(context_id)) + .first->second; + return context; +} + +const DistAutogradContext& DistAutogradContainer::newContext() { std::lock_guard guard(autograd_context_lock_); - TORCH_INTERNAL_ASSERT( - next_context_id_ < std::numeric_limits::max() && - next_context_id_ < - (kMaxContextId | - (static_cast(worker_id_) << kContextIdBits)), - "We have run out of autograd context ids!!!"); - - autograd_context_.emplace( - std::piecewise_construct, - std::forward_as_tuple(next_context_id_), - std::forward_as_tuple(next_context_id_)); - - current_context_id_ = next_context_id_; - return autograd_context_.at(next_context_id_++); + // Check for overflow into workerId_ section. + TORCH_INTERNAL_ASSERT(next_context_id_ < max_id_); + + auto& context = autograd_context_ + .emplace( + std::piecewise_construct, + std::forward_as_tuple(next_context_id_), + std::forward_as_tuple(next_context_id_)) + .first->second; + + current_context_id_ = next_context_id_++; + return context; } bool DistAutogradContainer::hasValidContext() const { @@ -90,8 +132,8 @@ void DistAutogradContainer::releaseContext(int64_t context_id) { } } -const DistAutogradContext& DistAutogradContainer::retrieveContext( - int64_t context_id) const { +DistAutogradContext& DistAutogradContainer::retrieveContext( + int64_t context_id) { std::lock_guard guard(autograd_context_lock_); TORCH_CHECK( autograd_context_.find(context_id) != autograd_context_.end(), @@ -100,6 +142,10 @@ const DistAutogradContext& DistAutogradContainer::retrieveContext( return autograd_context_.at(context_id); } +int64_t DistAutogradContainer::getMaxId() { + return max_id_; +} + } // namespace autograd } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_container.h b/torch/csrc/distributed/autograd/context/dist_autograd_container.h index 22ddff5592172..3f0aa846b0fab 100644 --- a/torch/csrc/distributed/autograd/context/dist_autograd_container.h +++ b/torch/csrc/distributed/autograd/context/dist_autograd_container.h @@ -18,12 +18,18 @@ namespace autograd { // autograd_context_id. The autograd_context_id itself is a 64 bit globally // unique id. The first 16 bits is the worker_id and the next 48 bits is an // auto-incrementing id for each worker. -class DistAutogradContainer { +// +// This container is also responsible for maintaining a globally unique message +// id, which is used to associate send/recv autograd function pairs. The format +// is similar to the autograd_context_id where we have a 64 bit integer with +// first 16 bits being the worker id and next 48 bits are auto-incrementing. +class TORCH_API DistAutogradContainer { public: // One time initialization of the container. static DistAutogradContainer& init(int64_t worker_id); - // Retrieve the singleton instance of the container. + // Retrieve the singleton instance of the container, ensures we have + // initialized the container. static DistAutogradContainer& getInstance(); // Create a new context for a distributed autograd pass. 
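The worker-id/auto-increment split described in the container comment above (16 high bits for the worker, 48 low bits for the counter, shared by context ids and the new autograd message ids) can be seen with a few lines of plain arithmetic. The constants mirror kAutoIncrementBits and kAutoIncrementMask from the patch; the worker id and loop are illustrative.

#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  constexpr int kAutoIncrementBits = 48;
  constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1;

  const int64_t worker_id = 3;  // any value in [0, 65535]
  int64_t next_id = worker_id << kAutoIncrementBits;  // first id this worker hands out
  const int64_t max_id = kAutoIncrementMask | (worker_id << kAutoIncrementBits);

  for (int i = 0; i < 3; ++i) {
    const int64_t id = next_id++;
    assert(id < max_id);  // the overflow check the container asserts before incrementing
    std::cout << "worker=" << (id >> kAutoIncrementBits)
              << " local=" << (id & kAutoIncrementMask) << "\n";
  }
}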
@@ -33,7 +39,7 @@ class DistAutogradContainer { void releaseContext(int64_t context_id); // Retrieve the autograd context for a given context_id. - const DistAutogradContext& retrieveContext(int64_t context_id) const; + DistAutogradContext& retrieveContext(int64_t context_id); // Retrieves the currently active autograd context for the current thread. DistAutogradContext& currentContext(); @@ -41,6 +47,18 @@ class DistAutogradContainer { // Checks whether or not the current thread has a valid autograd context. bool hasValidContext() const; + // Generate a new autograd_message_id for send/recv autograd functions. + int64_t newAutogradMessageId(); + + // Creates a new autograd context with the provided context_id. If a context + // already exists with the provided context_id, we just return it. + // This does not set the current context for the current thread. + DistAutogradContext& getOrCreateContext(int64_t context_id); + + // Retrieves the maximum possible autograd_context_id/autograd_message_id that + // can be generated by this worker. + int64_t getMaxId(); + private: DistAutogradContainer(); ~DistAutogradContainer() = default; @@ -50,6 +68,8 @@ class DistAutogradContainer { DistAutogradContainer(DistAutogradContainer&&) = delete; DistAutogradContainer& operator=(DistAutogradContainer&&) = delete; + static DistAutogradContainer& getInstanceInternal(); + // Auto incrementing context id used to identify unique autograd passes. // Initialized with the first 16 bits being the worker_id. int64_t next_context_id_; @@ -67,8 +87,11 @@ class DistAutogradContainer { // and worker_id_ are immutable. mutable std::mutex autograd_context_lock_; - // Each thread has a single autograd_context_id valid at any point in time. - static thread_local int64_t current_context_id_; + // Autograd message id to identify unique send/recv autograd function pairs. + std::atomic next_autograd_message_id_; + + // Maximum allowed value for autograd_context_id or autograd_message_id. 
+ int64_t max_id_; }; } // namespace autograd diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp b/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp index 8ad60c7b59ed3..0323ad849aa2e 100644 --- a/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp +++ b/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp @@ -13,16 +13,41 @@ int64_t DistAutogradContext::context_id() const { } void DistAutogradContext::addSendFunction( - const std::shared_ptr& func) { + const std::shared_ptr& func, + int64_t autograd_message_id) { + TORCH_INTERNAL_ASSERT(func != nullptr); + + std::lock_guard guard(lock_); + TORCH_INTERNAL_ASSERT( + sendAutogradFunctions_.find(autograd_message_id) == + sendAutogradFunctions_.end()); + sendAutogradFunctions_.emplace(autograd_message_id, func); +} + +void DistAutogradContext::addRecvFunction( + std::shared_ptr& func, + int64_t autograd_message_id) { + TORCH_INTERNAL_ASSERT(func != nullptr); + std::lock_guard guard(lock_); - sendAutogradFunctions_.push_back(func); + TORCH_INTERNAL_ASSERT( + recvAutogradFunctions_.find(autograd_message_id) == + recvAutogradFunctions_.end()); + recvAutogradFunctions_.emplace(autograd_message_id, func); } -std::vector> DistAutogradContext:: - sendFunctions() const { +std::unordered_map> +DistAutogradContext::sendFunctions() const { + std::lock_guard guard(lock_); return sendAutogradFunctions_; } +std::unordered_map> +DistAutogradContext::recvFunctions() const { + std::lock_guard guard(lock_); + return recvAutogradFunctions_; +} + } // namespace autograd } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_context.h b/torch/csrc/distributed/autograd/context/dist_autograd_context.h index c4eec7e96d38c..1ebd91e9f57ef 100644 --- a/torch/csrc/distributed/autograd/context/dist_autograd_context.h +++ b/torch/csrc/distributed/autograd/context/dist_autograd_context.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -9,17 +10,30 @@ namespace autograd { // DistAutogradContext which stores information for a single distributed // autograd pass on a worker. -class DistAutogradContext { +class TORCH_API DistAutogradContext { public: explicit DistAutogradContext(int64_t context_id); // Retrieves the autograd context id for this context. int64_t context_id() const; - // Records a 'send' autograd function for this context. - void addSendFunction(const std::shared_ptr& func); + // Records a 'send' autograd function for this context with the provided + // message id. + void addSendFunction( + const std::shared_ptr& func, + int64_t autograd_message_id); - std::vector> sendFunctions() const; + // Records a 'recv' autograd function for this context with the provided + // message id. + void addRecvFunction( + std::shared_ptr& func, + int64_t autograd_message_id); + + std::unordered_map> sendFunctions() + const; + + std::unordered_map> recvFunctions() + const; DistAutogradContext(const DistAutogradContext&) = delete; DistAutogradContext& operator=(const DistAutogradContext&) = delete; @@ -29,7 +43,13 @@ class DistAutogradContext { private: const int64_t context_id_; - std::vector> sendAutogradFunctions_; + // Map from autograd_message_id to appropriate 'send' autograd function. + std::unordered_map> + sendAutogradFunctions_; + + // Map from autograd_message_id to appropriate 'recv' autograd function. + std::unordered_map> + recvAutogradFunctions_; // Lock to protect concurrent modification of the context. 
mutable std::mutex lock_; diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp new file mode 100644 index 0000000000000..aeecbb4a8906e --- /dev/null +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp @@ -0,0 +1,26 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace autograd { + +using torch::autograd::Variable; + +torch::autograd::variable_list RecvRpcBackward::apply( + torch::autograd::variable_list&& grads) { + std::vector outputGrads; + for (const auto& grad : grads) { + if (grad.defined()) { + outputGrads.emplace_back(grad); + } else { + outputGrads.emplace_back(at::zeros_like(grad)); + } + } + + return outputGrads; +} + +} // namespace autograd +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h new file mode 100644 index 0000000000000..fbb2e9ff3ee3c --- /dev/null +++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace torch { +namespace distributed { +namespace autograd { + +// As part of our distributed autograd implementation, whenever we receive an +// RPC from a node, we add a 'RecvRpcBackward' autograd function to the +// autograd graph. This is more or less a placeholder function that is used to +// pass gradients to the remote host during the backward pass. The inputs to the +// RPC function are the inputs to this autograd function. +struct TORCH_API RecvRpcBackward : public torch::autograd::Node { + torch::autograd::variable_list apply( + torch::autograd::variable_list&& grads) override; +}; + +} // namespace autograd +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/autograd/init.cpp b/torch/csrc/distributed/autograd/init.cpp index 021c402943119..fa214df06a4b0 100644 --- a/torch/csrc/distributed/autograd/init.cpp +++ b/torch/csrc/distributed/autograd/init.cpp @@ -30,11 +30,26 @@ PyObject* dist_autograd_init(PyObject* /* unused */) { "_context_id", &DistAutogradContext::context_id, py::call_guard()) + .def( + "_recv_functions", + [](const DistAutogradContext& ctx) { + std::map funcs; + for (const auto& map_entry : ctx.recvFunctions()) { + funcs.emplace( + map_entry.first, + py::reinterpret_steal( + torch::autograd::functionToPyObject( + map_entry.second))); + } + return funcs; + }) .def("_send_functions", [](const DistAutogradContext& ctx) { - std::vector funcs; - for (const auto& sendFunction : ctx.sendFunctions()) { - funcs.push_back(py::reinterpret_steal( - torch::autograd::functionToPyObject(sendFunction))); + std::map funcs; + for (const auto& map_entry : ctx.sendFunctions()) { + funcs.emplace( + map_entry.first, + py::reinterpret_steal( + torch::autograd::functionToPyObject(map_entry.second))); } return funcs; }); @@ -50,6 +65,10 @@ PyObject* dist_autograd_init(PyObject* /* unused */) { return DistAutogradContainer::getInstance().releaseContext(context_id); }); + module.def("_get_max_id", []() { + return DistAutogradContainer::getInstance().getMaxId(); + }); + module.def( "_retrieve_context", [](int64_t context_id) -> const DistAutogradContext& { diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp index e85bad03f6a3f..3eb19d45d4006 100644 --- a/torch/csrc/distributed/autograd/utils.cpp +++ b/torch/csrc/distributed/autograd/utils.cpp @@ -1,24 +1,55 @@ #include 
+#include +#include +#include #include namespace torch { namespace distributed { namespace autograd { -std::shared_ptr addSendRpcBackward( - const std::vector& tensors) { +using torch::distributed::rpc::Message; + +void addSendRpcBackward( + DistAutogradContext& autogradContext, + const torch::distributed::rpc::AutogradMetadata& autogradMetadata, + std::vector& tensors) { // Attach the appropriate autograd edges. - std::shared_ptr grad_fn; if (torch::autograd::compute_requires_grad(tensors)) { - grad_fn = std::make_shared(); + auto grad_fn = std::make_shared(); grad_fn->set_next_edges(torch::autograd::collect_next_edges(tensors)); // Add the appropriate input metadata for the grad_fn. for (const auto& tensor : tensors) { grad_fn->add_input_metadata(tensor); } + + // Record the send autograd function in our current context. + autogradContext.addSendFunction( + grad_fn, autogradMetadata.autogradMessageId); + } +} + +DistAutogradContext* addRecvRpcBackward( + const torch::distributed::rpc::AutogradMetadata& autogradMetadata, + std::vector& tensors) { + if (torch::autograd::compute_requires_grad(tensors)) { + // Attach the tensors as inputs to the autograd function. + auto grad_fn = std::make_shared(); + for (auto& tensor : tensors) { + torch::autograd::set_history(tensor, grad_fn); + } + + // Now update the autograd context with the necessary information. + auto& autogradContainer = DistAutogradContainer::getInstance(); + // Initialize autograd context if necessary. + DistAutogradContext& autogradContext = autogradContainer.getOrCreateContext( + autogradMetadata.autogradContextId); + autogradContext.addRecvFunction( + grad_fn, autogradMetadata.autogradMessageId); + return &autogradContext; } - return grad_fn; + return nullptr; } } // namespace autograd diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h index 6e04605711b86..3126f6392b07f 100644 --- a/torch/csrc/distributed/autograd/utils.h +++ b/torch/csrc/distributed/autograd/utils.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include namespace torch { namespace distributed { @@ -9,12 +9,26 @@ namespace autograd { // This method is used to attach the 'send' autograd function to the autograd // graph when we use RPC. This method creates a new 'send' autograd function -// and attaches the provided tensors as next_edges to the 'send' function. +// and attaches the provided tensors as next_edges to the 'send' function. In +// addition to this, it also registers the send function in the provided +// autograd context. Finally, the RPC message is updated with appropriate +// autograd information for the recipient. +TORCH_API void addSendRpcBackward( + DistAutogradContext& autogradContext, + const torch::distributed::rpc::AutogradMetadata& autogradMetadata, + std::vector& tensors); + +// This method is used to attach the 'recv' autograd function to the autograd +// graph when we use RPC. This method creates a new 'recv' autograd function +// and attaches the provided tensors as inputs to the 'recv' function. It +// creates a new autograd context if needed and registers the 'recv' function +// with this context. // -// Returns a shared_ptr to the autograd function, so that we can hold a -// reference to it. -TORCH_API std::shared_ptr addSendRpcBackward( - const std::vector& tensors); +// Returns a pointer to the autograd context created (nullptr in case of no +// autograd information was needed.) 
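The send/recv bookkeeping added above pairs a SendRpcBackward node on the caller with a RecvRpcBackward node on the callee through the autograd_message_id carried in the AutogradMetadata. The pairing idea, reduced to plain standard-library types (none of the real rpc or autograd classes appear here):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

struct ToyContext {
  // Stand-ins for the unordered_maps keyed by autograd_message_id in
  // DistAutogradContext; the strings stand in for the autograd nodes.
  std::unordered_map<int64_t, std::string> send_functions;
  std::unordered_map<int64_t, std::string> recv_functions;
};

int main() {
  ToyContext caller_ctx;  // context on the worker issuing the RPC
  ToyContext callee_ctx;  // context on the worker serving the RPC

  // Caller: addSendRpcBackward records a send node under a fresh message id
  // (worker id in the high 16 bits, counter in the low 48) and ships the id.
  const int64_t autograd_message_id = (int64_t{1} << 48) | 7;
  caller_ctx.send_functions.emplace(autograd_message_id, "SendRpcBackward");

  // Callee: addRecvRpcBackward reads the same id from the autograd metadata
  // and records the matching recv node in its own (possibly new) context.
  callee_ctx.recv_functions.emplace(autograd_message_id, "RecvRpcBackward");

  // During the distributed backward pass, the shared id lets each side look
  // up its half of the pair and route gradients to the other worker.
  std::cout << caller_ctx.send_functions.at(autograd_message_id) << " <-> "
            << callee_ctx.recv_functions.at(autograd_message_id) << "\n";
}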
+TORCH_API DistAutogradContext* addRecvRpcBackward( + const torch::distributed::rpc::AutogradMetadata& autogradMetadata, + std::vector& tensors); } // namespace autograd } // namespace distributed diff --git a/torch/csrc/distributed/rpc/functions.cpp b/torch/csrc/distributed/rpc/functions.cpp deleted file mode 100644 index 5d5d054888759..0000000000000 --- a/torch/csrc/distributed/rpc/functions.cpp +++ /dev/null @@ -1,118 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace torch { -namespace distributed { -namespace rpc { - -Message createException(const Message& request, const std::exception& e) { - const char* err = e.what(); - std::vector payload(err, err + strlen(err)); - return Message( - std::move(payload), - std::vector(), - MessageType::EXCEPTION, - request.id()); -} - -Message processRequestBlocking(Message&& request) { - switch (request.type()) { - case MessageType::SCRIPT_CALL: { - try { - ScriptCall sc = ScriptCall::fromMessage(request); - - // sc is only alive within this block, use reference to avoid copy - auto& stack = sc.stackRef(); - sc.op()->getOperation()(stack); - - AT_ASSERT( - stack.size() == 1, - "Return value of a builtin operator or a " - "TorchScript function should be a single IValue, got a vector of " - "size ", - stack.size()); - auto response = ScriptRet(std::move(stack.front())).toMessage(); - - response.setId(request.id()); - return response; - } catch (std::exception& e) { - return createException(request, e); - } - break; - } - case MessageType::PYTHON_CALL: { - try { - std::vector tensorTable; - auto payload = PythonRpcHandler::getInstance().generatePythonUDFResult( - request, tensorTable); - return Message( - std::move(payload), - std::move(tensorTable), - MessageType::PYTHON_RET, - request.id()); - } catch (std::exception& e) { - return createException(request, e); - } - break; - } - case MessageType::REMOTE_CALL: { - ScriptRemoteCall src = ScriptRemoteCall::fromMessage(request); - - auto rrefId = RRefId::fromIValue(src.retRRefId()); - auto forkId = ForkId::fromIValue(src.retForkId()); - TORCH_CHECK(rrefId != forkId, "Does not support remote call to self."); - - auto& ctx = RRefContext::getInstance(); - auto ownerRRef = ctx->getOrCreateOwnerRRef(rrefId); - - // TODO: make this asynchronous - // src is only alive within this block, use reference to avoid copy - auto& stack = src.stackRef(); - src.op()->getOperation()(stack); - AT_ASSERT( - stack.size() == 1, - "Return value of a builtin operator or a " - "TorchScript function should be a single IValue, got a vector of " - "size ", - stack.size()); - - ownerRRef->setValue(std::move(stack.front())); - return Message(); - } - case MessageType::RREF_FETCH_CALL: { - ScriptRRefFetchCall srf = ScriptRRefFetchCall::fromMessage(request); - // TODO: make this asynchronous - std::shared_ptr> rref = - RRefContext::getInstance()->getOrCreateOwnerRRef( - RRefId::fromIValue(srf.value())); - auto response = ScriptRRefFetchRet(rref->getValue()).toMessage(); - response.setId(request.id()); - return response; - } - case MessageType::RREF_USER_CREATE: { - ScriptRRefCreate sra = ScriptRRefCreate::fromMessage(request); - RRefContext::getInstance()->addFork(sra.valueRef()); - return Message(); - } - case MessageType::RREF_USER_DELETE: { - ScriptRRefDelete srd = ScriptRRefDelete::fromMessage(request); - RRefContext::getInstance()->delFork(srd.valueRef()); - return Message(); - } - default: { - AT_ERROR("Request type ", request.type(), " not supported."); - } - } -} - 
-} // namespace rpc -} // namespace distributed -} // namespace torch diff --git a/torch/csrc/distributed/rpc/functions.h b/torch/csrc/distributed/rpc/functions.h deleted file mode 100644 index d5f82885712f7..0000000000000 --- a/torch/csrc/distributed/rpc/functions.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace distributed { -namespace rpc { - -Message processRequestBlocking(Message&& message); - -Message createException(const Message& request, const std::exception& e); - -} // namespace rpc -} // namespace distributed -} // namespace torch diff --git a/torch/csrc/distributed/rpc/future_message.cpp b/torch/csrc/distributed/rpc/future_message.cpp index bd0c725986671..92aef393a5b99 100644 --- a/torch/csrc/distributed/rpc/future_message.cpp +++ b/torch/csrc/distributed/rpc/future_message.cpp @@ -26,13 +26,6 @@ void FutureMessage::markCompleted() { markCompleted(Message()); } -const Message& FutureMessage::message() { - std::unique_lock lock(mutex_); - TORCH_CHECK(completed(), "Cannot retrieve message before completed."); - - return message_; -} - bool FutureMessage::completed() const { return completed_; } @@ -44,17 +37,17 @@ void FutureMessage::addCallback(const FutureMessage::Callback& callback) { callback(message_); return; } - callbacks.push_back(callback); + callbacks_.push_back(callback); } void FutureMessage::fireCallbacks() { TORCH_CHECK(completed(), "Firing callbacks on incomplete FutureMessage."); - // There is no need to protect callbacks with the lock. + // There is no need to protect callbacks_ with the lock. // Once completed_ is set to true, no one can add new callback to the list. - for (auto& callback : callbacks) { + for (auto& callback : callbacks_) { callback(message_); } - callbacks.clear(); + callbacks_.clear(); } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/future_message.h b/torch/csrc/distributed/rpc/future_message.h index 675582ef87bc8..0009f310937d4 100644 --- a/torch/csrc/distributed/rpc/future_message.h +++ b/torch/csrc/distributed/rpc/future_message.h @@ -18,7 +18,6 @@ struct TORCH_API FutureMessage final { const Message& wait(); void markCompleted(Message message); void markCompleted(); - const Message& message(); bool completed() const; // If completed() the callback will be invoked in-place. 
@@ -27,10 +26,10 @@ struct TORCH_API FutureMessage final { private: void fireCallbacks(); - std::mutex mutex_; + mutable std::mutex mutex_; std::atomic_bool completed_{false}; // is this future complete std::condition_variable finished_cv_; - std::vector callbacks; + std::vector callbacks_; // TODO: make message_ an optional field, and get rid of UNKNOWN message type Message message_; }; diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 9d732e55f3bd6..8448ad3436470 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -1,8 +1,8 @@ #include -#include #include #include +#include #include #include #include @@ -30,9 +30,9 @@ PyObject* rpc_init(PyObject* /* unused */) { auto module = py::handle(dist_module).cast(); - auto workerId = shared_ptr_class_(module, "WorkerId") - .def_readonly("name", &WorkerId::name_) - .def_readonly("id", &WorkerId::id_); + auto workerInfo = shared_ptr_class_(module, "WorkerInfo") + .def_readonly("name", &WorkerInfo::name_) + .def_readonly("id", &WorkerInfo::id_); auto rpcAgent = shared_ptr_class_(module, "RpcAgent") @@ -43,13 +43,33 @@ PyObject* rpc_init(PyObject* /* unused */) { &RpcAgent::sync, py::call_guard()); - auto rref = - shared_ptr_class_(module, "RRef") - .def("owner", &RRef::owner, py::call_guard()) + auto pyRRef = + shared_ptr_class_(module, "RRef") + .def( + // not releasing GIL here to avoid context switch on getters + "is_owner", + &PyRRef::isOwner) + .def( + // not releasing GIL here to avoid context switch on getters + "owner", + &PyRRef::owner) .def( "to_here", - [&](RRef& rref) { return torch::jit::toPyObject(rref.toHere()); }, - py::call_guard()); + &PyRRef::toHere, + py::call_guard()) + .def( + "local_value", + &PyRRef::localValue, + py::call_guard()) + .def(py::pickle( + [](const PyRRef& self) { + // __getstate__ + return self.pickle(); + }, + [](py::tuple t) { // NOLINT + // __setstate__ + return PyRRef::unpickle(t); + })); auto futureMessage = shared_ptr_class_(module, "FutureMessage") @@ -65,14 +85,14 @@ PyObject* rpc_init(PyObject* /* unused */) { py::arg("process_group"), py::arg("num_send_recv_threads") = 4) .def( - "get_worker_id", - (const WorkerId& (ProcessGroupAgent::*)(void)const) & - RpcAgent::getWorkerId, + "get_worker_info", + (const WorkerInfo& (ProcessGroupAgent::*)(void)const) & + RpcAgent::getWorkerInfo, py::call_guard()) .def( - "get_worker_id", - (const WorkerId& (ProcessGroupAgent::*)(const std::string&)const) & - ProcessGroupAgent::getWorkerId, + "get_worker_info", + (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) & + ProcessGroupAgent::getWorkerInfo, py::call_guard()) .def( "join", @@ -83,14 +103,18 @@ PyObject* rpc_init(PyObject* /* unused */) { &ProcessGroupAgent::sync, py::call_guard()); - module.def("init_rref_context", [](std::shared_ptr agent) { + module.def("_init_rref_context", [](std::shared_ptr agent) { RRefContext::initInstance(std::move(agent)); }); + module.def("_destroy_rref_context", []() { + RRefContext::getInstance()->destroyInstance(); + }); + module.def( "invoke_rpc_builtin", [](RpcAgent& agent, - const WorkerId& dst, + const WorkerInfo& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs) { @@ -100,8 +124,8 @@ PyObject* rpc_init(PyObject* /* unused */) { module.def( "invoke_rpc_python_udf", [](RpcAgent& agent, - const WorkerId& dst, - const std::string& pickledPythonUDF, + const WorkerInfo& dst, + std::string& pickledPythonUDF, std::vector& tensors) { return 
pyRpcPythonUdf(agent, dst, pickledPythonUDF, tensors); }); @@ -109,13 +133,22 @@ PyObject* rpc_init(PyObject* /* unused */) { module.def( "invoke_remote_builtin", [](RpcAgent& agent, - const WorkerId& dst, + const WorkerInfo& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs) { return pyRemoteBuiltin(agent, dst, opName, args, kwargs); }); + module.def( + "invoke_remote_python_udf", + [](RpcAgent& agent, + const WorkerInfo& dst, + std::string& pickledPythonUDF, + std::vector& tensors) { + return pyRemotePythonUdf(agent, dst, pickledPythonUDF, tensors); + }); + Py_RETURN_TRUE; } diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp index cbfa567c868b6..18a68b1eb9043 100644 --- a/torch/csrc/distributed/rpc/message.cpp +++ b/torch/csrc/distributed/rpc/message.cpp @@ -44,10 +44,18 @@ void Message::swap(Message& rhs) noexcept { std::swap(id_, rhs.id_); } +std::vector&& Message::movePayload() && { + return std::move(payload_); +} + const std::vector& Message::payload() const { return payload_; } +std::vector& Message::tensors() { + return tensors_; +} + const std::vector& Message::tensors() const { return tensors_; } @@ -57,22 +65,29 @@ const MessageType& Message::type() const { } bool Message::isRequest() const { - return MessageType::SCRIPT_CALL == type_ || - MessageType::PYTHON_CALL == type_ || MessageType::REMOTE_CALL == type_ || - MessageType::RREF_FETCH_CALL == type_ || - MessageType::RREF_USER_CREATE == type_ || - MessageType::RREF_USER_DELETE == type_; -} - -bool Message::requiresResponse() const { - return MessageType::SCRIPT_CALL == type_ || - MessageType::PYTHON_CALL == type_ || - MessageType::RREF_FETCH_CALL == type_; + return MessageType::SCRIPT_CALL == type_ || // dist.rpc on builtin ops + MessageType::PYTHON_CALL == type_ || // dist.rpc on Python UDFs + MessageType::SCRIPT_REMOTE_CALL == type_ || // dist.remote on builtin ops + MessageType::PYTHON_REMOTE_CALL == type_ || // dist.remote on Python UDFs + // RRef related internal messages + MessageType::SCRIPT_RREF_FETCH_CALL == type_ || + MessageType::PYTHON_RREF_FETCH_CALL == type_ || + MessageType::RREF_USER_DELETE == type_ || + MessageType::RREF_CHILD_ACCEPT == type_ || + MessageType::RREF_FORK_REQUEST == type_ || + // Autograd message + MessageType::MESSAGE_WITH_AUTOGRAD_REQ == type_; } bool Message::isResponse() const { - return MessageType::SCRIPT_RET == type_ || MessageType::PYTHON_RET == type_ || - MessageType::RREF_FETCH_RET == type_; + return MessageType::SCRIPT_RET == type_ || // ret of dist.rpc on builtin ops + MessageType::PYTHON_RET == type_ || // ret of dist.rpc on Python UDFs + MessageType::REMOTE_RET == type_ || // ret of dist.remote + MessageType::RREF_FETCH_RET == type_ || // ret on RRef::toHere() + MessageType::EXCEPTION == type_ || // propagate back exceptions + MessageType::RREF_ACK == type_ || // ret of other types + // Autograd response + MessageType::MESSAGE_WITH_AUTOGRAD_RESP == type_; } bool Message::isShutdown() const { diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index eea216f7354a7..f9d1e7c866b74 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -8,18 +8,36 @@ namespace distributed { namespace rpc { enum MessageType { + // messages for dist.rpc on builtin operators SCRIPT_CALL = 0, - SCRIPT_RET, - PYTHON_CALL, - PYTHON_RET, - REMOTE_CALL, - RREF_FETCH_CALL, - RREF_FETCH_RET, - RREF_USER_CREATE, - RREF_USER_DELETE, - SHUTDOWN, - EXCEPTION, - UNKNOWN 
+  SCRIPT_RET = 1,
+
+  // messages for dist.rpc on Python UDF
+  PYTHON_CALL = 2,
+  PYTHON_RET = 3,
+
+  // messages for dist.remote on builtin operators and Python UDF
+  SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator
+  PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF
+  REMOTE_RET = 6, // Response to a dist.remote call
+
+  // RRef related internal messages
+  SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef fetches value from owner
+  PYTHON_RREF_FETCH_CALL = 8, // A UserRRef fetches value from owner
+  RREF_FETCH_RET = 9, // An OwnerRRef sends value to user
+  RREF_USER_DELETE = 10, // A UserRRef tells the owner to deref
+  RREF_FORK_REQUEST = 11, // A child UserRRef tells the owner about itself
+  RREF_CHILD_ACCEPT = 12, // A child UserRRef tells parent that owner knows it
+  RREF_ACK = 13, // ACK to internal RRef messages
+
+  // Messages with autograd info
+  MESSAGE_WITH_AUTOGRAD_REQ = 14,
+  MESSAGE_WITH_AUTOGRAD_RESP = 15,
+
+  // Other internal message types
+  SHUTDOWN = 16,
+  EXCEPTION = 17,
+  UNKNOWN = 18
 };

 // A message to be sent/received by an RpcAgent.
@@ -38,8 +56,8 @@ enum MessageType {
 // request and response. Other implementation can ignore it
 // if they have their own ways to do matching.
 //
-// Layers above ``RpcAgent`` only converts ScriptCall, ScriptRet, PythonCall,
-// and PythonRet into a Message, and it is up to the RpcAgent
+// Layers above ``RpcAgent`` only convert ScriptCall, ScriptResp, PythonCall,
+// and PythonUDFResp into a Message, and it is up to the RpcAgent
 // implementation to determine how to serialize a message.
 class TORCH_API Message final {
  public:
@@ -62,12 +80,15 @@ class TORCH_API Message final {
   Message& operator=(Message&& rhs) &;
   void swap(Message& rhs) noexcept;
+  // Destructively retrieves the payload.
+ std::vector&& movePayload() &&; + const std::vector& payload() const; + std::vector& tensors(); const std::vector& tensors() const; const MessageType& type() const; bool isRequest() const; - bool requiresResponse() const; bool isResponse() const; bool isShutdown() const; diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index a8c96f0a52f4f..37b1651536c69 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -1,5 +1,7 @@ #include +#include #include +#include #include @@ -22,7 +24,7 @@ void serialize(const Message& message, std::ostream& os) { std::vector tensors = message.tensors(); // append payload as a tensor tensors.push_back(torch::from_blob(payload, payload_size, {torch::kChar})); - // append id as a tensor + // append id and autograd metadata as a tensor tensors.push_back(torch::tensor({message.id()}, {torch::kInt64})); torch::save(tensors, os); @@ -39,6 +41,7 @@ Message deserialize(MessageType type, std::istream& is) { auto payloadTensor = std::move(tensors.back()); tensors.pop_back(); + TORCH_INTERNAL_ASSERT(1, idTensor.numel()); int64_t id = idTensor.storage().data()[0]; std::vector payload(payloadTensor.numel()); @@ -71,18 +74,18 @@ std::vector ProcessGroupAgent::MessageCounter::snapshot() { //////////////////////// ProcessGroupAgent ///////////////////////////////// void ProcessGroupAgent::collectNames() { - const std::string& workerName = workerId_.name_; + const std::string& workerName = workerInfo_.name_; const auto worldSize = pg_->getSize(); // use c10d allgather to collect names torch::Tensor nameTensor = - torch::zeros({WorkerId::MAX_NAME_LEN}, torch::kChar); + torch::zeros({WorkerInfo::MAX_NAME_LEN}, torch::kChar); memcpy(nameTensor.storage().data(), workerName.c_str(), workerName.length()); std::vector inputName = {nameTensor}; std::vector> outputNames(1); for (int i = 0; i < worldSize; ++i) { outputNames[0].emplace_back( - torch::empty({WorkerId::MAX_NAME_LEN}, {torch::kChar})); + torch::empty({WorkerInfo::MAX_NAME_LEN}, {torch::kChar})); } pg_->allgather(outputNames, inputName)->wait(); @@ -106,8 +109,8 @@ ProcessGroupAgent::ProcessGroupAgent( std::shared_ptr pg, int numSendRecvThreads) : RpcAgent( - WorkerId(std::move(workerName), pg->getRank()), - processRequestBlocking), + WorkerInfo(std::move(workerName), pg->getRank()), + c10::guts::make_unique()), pg_(std::move(pg)), sendCounts_(pg_->getSize()), recvCounts_(pg_->getSize()), @@ -120,12 +123,12 @@ ProcessGroupAgent::ProcessGroupAgent( "ProcessGroupAgent requires world_size to " "be at least 2, but got ", nameMap_.size()); - auto workerRankIter = nameMap_.find(workerId_.name_); + auto workerRankIter = nameMap_.find(workerInfo_.name_); TORCH_CHECK( workerRankIter != nameMap_.end(), "Failed to resolve worker " "name ", - workerId_.name_, + workerInfo_.name_, " to a ProcessGroup rank."); TORCH_CHECK( pg_->getRank() == workerRankIter->second, @@ -140,9 +143,9 @@ ProcessGroupAgent::ProcessGroupAgent( tmpWorkerIds[entry.second] = entry.first; } - workerIds_.reserve(pg_->getSize()); + allWorkerInfo_.reserve(pg_->getSize()); for (int rank = 0; rank < (int)tmpWorkerIds.size(); ++rank) { - workerIds_.emplace_back(std::move(tmpWorkerIds[rank]), rank); + allWorkerInfo_.emplace_back(std::move(tmpWorkerIds[rank]), rank); } // construct PythonRpcHandler singleton here @@ -150,17 +153,17 @@ ProcessGroupAgent::ProcessGroupAgent( listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, 
this); } -const WorkerId& ProcessGroupAgent::getWorkerId( +const WorkerInfo& ProcessGroupAgent::getWorkerInfo( const std::string& workerName) const { const auto idIter = nameMap_.find(workerName); TORCH_CHECK( idIter != nameMap_.end(), "Unknown destination worker ", workerName); - return workerIds_[idIter->second]; + return allWorkerInfo_[idIter->second]; } -const WorkerId& ProcessGroupAgent::getWorkerId(worker_id_t id) const { - return workerIds_[id]; +const WorkerInfo& ProcessGroupAgent::getWorkerInfo(worker_id_t id) const { + return allWorkerInfo_[id]; } void ProcessGroupAgent::join() { @@ -171,9 +174,13 @@ void ProcessGroupAgent::join() { // 2. A GLOO process cannot send message to itself. (there is an ongoing // effort to fix this problem). sync(); + std::unique_lock lock(futureMutex_); + futureCV_.wait(lock, [this] { return futures_.empty(); }); + lock.unlock(); + pg_->barrier()->wait(); int dst = (pg_->getRank() + 1) % pg_->getSize(); enqueueSend( - SendWork(workerIds_[dst], Message({}, {}, MessageType::SHUTDOWN))); + SendWork(allWorkerInfo_[dst], Message({}, {}, MessageType::SHUTDOWN))); threadPool_.waitWorkComplete(); listenerThread_.join(); } @@ -195,7 +202,6 @@ bool ProcessGroupAgent::hasPendingMessage() { std::vector inputSnapshot = { torch::from_blob(snapshot.data(), {2, worldSize}, {torch::kInt64})}; - // allgather both send and recv messages in one shot std::vector> outputSnapshots(1); @@ -241,8 +247,8 @@ void ProcessGroupAgent::sync() { } while (hasPendingMessage()); } -std::shared_ptr ProcessGroupAgent::sendImpl( - const WorkerId& to, +std::shared_ptr ProcessGroupAgent::send( + const WorkerInfo& to, Message&& message) { TORCH_CHECK( to.id_ != (worker_id_t)pg_->getRank(), @@ -256,7 +262,7 @@ std::shared_ptr ProcessGroupAgent::sendImpl( auto requestId = nextId(); auto future = std::make_shared(); - if (message.requiresResponse()) { + if (message.isRequest()) { { std::lock_guard lock{futureMutex_}; futures_[requestId] = future; @@ -268,14 +274,14 @@ std::shared_ptr ProcessGroupAgent::sendImpl( // NB: cannot directly pass ``to`` to the ``SendWork``, because it might no // longer be alive when the ``SendWork`` is executed. For example, the - // application could query the ``WorkerId`` using name through the - // ``RpcAgent::getWorkerId`` API, and pass the ``WorkerId`` back here, so we - // have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerId`` + // application could query the ``WorkerInfo`` using name through the + // ``RpcAgent::getWorkerInfo`` API, and pass the ``WorkerInfo`` back here, so + // we have C++ -> Python -> C++. For an asynchronous RPC, the ``WorkerInfo`` // reference on Python side could die before ``SendWork`` uses it, and Pybind // will not keep the Python reference alive even if it originally comes from - // the C++ land. Hence, we have to explicitly use the ``workerId`` in the C++ - // land. - enqueueSend(SendWork(workerIds_[to.id_], std::move(message))); + // the C++ land. Hence, we have to explicitly use the ``WorkerInfo`` in the + // C++ land. 
+ enqueueSend(SendWork(allWorkerInfo_[to.id_], std::move(message))); return future; } @@ -337,21 +343,27 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) { (char*)payload.storage().data(), payload.numel())); Message message = deserialize(work.type_, ss); - - if (message.requiresResponse()) { - send(work.from_, cb_(std::move(message))); - } else if (message.isRequest()) { - cb_(std::move(message)); + if (message.isRequest()) { + send(work.from_, cb_->operator()(message)); } else if (message.isResponse()) { auto id = message.id(); + std::shared_ptr fm = nullptr; + { + std::lock_guard lock{futureMutex_}; + fm = futures_[id]; + } + // Not holding lock on markCompleted as this could run callbacks that + // call agent_->send + fm->markCompleted(std::move(message)); { std::lock_guard lock{futureMutex_}; - futures_[id]->markCompleted(std::move(message)); futures_.erase(id); } + futureCV_.notify_all(); } else { // TODO: pass the error back to the caller instead of crashing here. - AT_ERROR("unrecognized message type ", message.type()); + TORCH_INTERNAL_ASSERT( + false, "unrecognized message type ", message.type()); } recvCounts_.increment(work.from_.id_); @@ -374,7 +386,7 @@ void ProcessGroupAgent::listenLoop() { // FIXME: This LOG also prints warnings no InitGoogleLogging() was invoked // before logging, but it is not appropriate to call InitGoogleLogging() // here either. - LOG(INFO) << "Shutting down ProcessGroupAgent " << workerId_.name_ + LOG(INFO) << "Shutting down ProcessGroupAgent " << workerInfo_.name_ << std::endl; return; } @@ -382,7 +394,7 @@ void ProcessGroupAgent::listenLoop() { std::vector tensors = {torch::empty({size}, {torch::kChar})}; pg_->recv(tensors, srcRank, pg_->getRank())->wait(); - enqueueRecv(RecvWork(workerIds_[srcRank], type, std::move(tensors[0]))); + enqueueRecv(RecvWork(allWorkerInfo_[srcRank], type, std::move(tensors[0]))); } } diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index db78950b61115..a83cb5eaa8339 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -16,20 +15,20 @@ namespace rpc { // SendWork and RecvWork will be put into a task queue, and later picked up by // worker threads from the same ThreadPool. struct SendWork { - SendWork(const WorkerId& to, Message&& message) + SendWork(const WorkerInfo& to, Message&& message) : to_(to), message_(message) {} - const WorkerId& to_; + const WorkerInfo& to_; Message message_; }; // SendWork wraps a Message and RecvWork wraps a Tensor. The difference here is // to allow us to run serialization/deserialization in the worker threads. 
struct RecvWork { - RecvWork(const WorkerId& from, MessageType type, torch::Tensor&& payload) + RecvWork(const WorkerInfo& from, MessageType type, torch::Tensor&& payload) : from_(from), type_(type), payload_(payload) {} - const WorkerId& from_; + const WorkerInfo& from_; const MessageType type_; torch::Tensor payload_; }; @@ -41,9 +40,9 @@ class ProcessGroupAgent : public RpcAgent { std::shared_ptr pg, int numSendRecvThreads = 4); - const WorkerId& getWorkerId(const std::string& workerName) const override; + const WorkerInfo& getWorkerInfo(const std::string& workerName) const override; - const WorkerId& getWorkerId(worker_id_t id) const override; + const WorkerInfo& getWorkerInfo(worker_id_t id) const override; void join() override; @@ -53,7 +52,7 @@ class ProcessGroupAgent : public RpcAgent { // This method wraps the destination information and the message into a // SendWork object, and put the SendWork into a queue. Another thread will // consume SendWork from the queue and send it out. - std::shared_ptr sendImpl(const WorkerId& to, Message&& message) + std::shared_ptr send(const WorkerInfo& to, Message&& message) override; private: @@ -100,13 +99,13 @@ class ProcessGroupAgent : public RpcAgent { bool hasPendingMessage(); int64_t nextId() { - return nextId_++; + return ++nextId_; } std::shared_ptr pg_; // worker name -> rank std::unordered_map nameMap_; - std::vector workerIds_; + std::vector allWorkerInfo_; // record the number of messages sent to and received from each peer. The recv // counter is only marked after the message is processed. Join uses allgather // to collect all counts from all peers, uses these counters to detect global @@ -130,7 +129,8 @@ class ProcessGroupAgent : public RpcAgent { // This is just a temporary solution for (2). ThreadPool threadPool_; std::unordered_map> futures_; - std::mutex futureMutex_; + mutable std::mutex futureMutex_; + mutable std::condition_variable futureCV_; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp new file mode 100644 index 0000000000000..54168bf73f81a --- /dev/null +++ b/torch/csrc/distributed/rpc/py_rref.cpp @@ -0,0 +1,135 @@ +#include + +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +namespace { + +// Constants below are used in PyRRef pickling and unpickling. PyRRef is +// converted into a py::tuple in pickling, and reconstructed from the py::tuple +// in unpickling. +constexpr int RFD_IDX = 0; // index of RRefForkData +constexpr int TYPE_IDX = 1; // index of type (py::object or IValue) + +// number of data fields in the py::tuple. +// NB: if more fields are added, make sure this field is also bumped +constexpr int RREF_TUPLE_SIZE = 2; + +} // namespace + +PyRRef::PyRRef(std::shared_ptr rref) : rref_(std::move(rref)) { + TORCH_CHECK(rref_, "PyRRef must not wrap nullptr"); +} + +bool PyRRef::isOwner() const { + return rref_->isOwner(); +} + +worker_id_t PyRRef::owner() const { + return rref_->owner(); +} + +py::object PyRRef::toHere() { + if (rref_->isOwner()) { + if (rref_->isPyObj()) { + const py::object& value = + std::static_pointer_cast>(rref_)->getValue(); + + { + // acquiring GIL as the return statement construct a new py::object from + // a const reference. + AutoGIL ag; + return value; + } + } else { + IValue value = + std::static_pointer_cast>(rref_)->getValue(); + + { + // acquiring GIL as torch::jit::toPyObject creates new py::object + // without grabbing the GIL. 
+        AutoGIL ag;
+        return torch::jit::toPyObject(std::move(value));
+      }
+    }
+  } else {
+    if (rref_->isPyObj()) {
+      // UserRRef::toHere() calls python_rpc_handler which acquires
+      // GIL.
+      return std::static_pointer_cast<UserRRef<py::object>>(rref_)->toHere();
+    } else {
+      IValue value =
+          std::static_pointer_cast<UserRRef<IValue>>(rref_)->toHere();
+
+      {
+        // acquiring GIL as torch::jit::toPyObject creates new py::object
+        // without grabbing the GIL.
+        AutoGIL ag;
+        return torch::jit::toPyObject(std::move(value));
+      }
+    }
+  }
+}
+
+py::object PyRRef::localValue() {
+  TORCH_CHECK(
+      rref_->isOwner(),
+      "Cannot call localValue() on a non-local reference. Call it on ",
+      RRefContext::getInstance()->getWorkerName());
+
+  if (rref_->isPyObj()) {
+    const py::object& value =
+        std::dynamic_pointer_cast<OwnerRRef<py::object>>(rref_)->getValue();
+
+    {
+      // acquiring GIL as the return statement constructs a new py::object from
+      // a const reference.
+      AutoGIL ag;
+      return value;
+    }
+  } else {
+    auto value =
+        std::dynamic_pointer_cast<OwnerRRef<IValue>>(rref_)->getValue();
+    {
+      // acquiring GIL as torch::jit::toPyObject creates new py::object without
+      // grabbing the GIL.
+      AutoGIL ag;
+      return torch::jit::toPyObject(std::move(value));
+    }
+  }
+}
+
+py::tuple PyRRef::pickle() const {
+  auto& ctx = RRefContext::getInstance();
+  // TODO: use a dispatch table to pickle/unpickle an RRef, and only install
+  // the dispatch table when there are indeed RPC activities. As a
+  // counterexample, checkpointing a model with RRefs should not trigger
+  // forks to be added as a fork or a child.
+  auto rfd = ctx->prepareChildFork(rref_);
+  return py::make_tuple(rfd.toPyTuple(), rref_->isPyObj());
+}
+
+PyRRef PyRRef::unpickle(const py::tuple& t) {
+  TORCH_INTERNAL_ASSERT(
+      t.size() == RREF_TUPLE_SIZE, "Pickled RRef must contain 2 numbers.");
+  auto& ctx = RRefContext::getInstance();
+  auto rfd = RRefForkData::fromPyTuple(t[RFD_IDX].cast<py::tuple>());
+  std::shared_ptr<RRef> rref = nullptr;
+  bool isPyObj = t[TYPE_IDX].cast<bool>();
+  if (isPyObj) {
+    rref = ctx->getOrCreateRRef<py::object>(rfd);
+  } else {
+    rref = ctx->getOrCreateRRef<IValue>(rfd);
+  }
+
+  ctx->notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
+  return PyRRef(std::move(rref));
+}
+
+} // namespace rpc
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/rpc/py_rref.h b/torch/csrc/distributed/rpc/py_rref.h
new file mode 100644
index 0000000000000..10c5841490765
--- /dev/null
+++ b/torch/csrc/distributed/rpc/py_rref.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace torch {
+namespace distributed {
+namespace rpc {
+
+// Python wrapper of an RRef shared_ptr that supports Python
+// pickle and unpickle.
+class PyRRef { + public: + explicit PyRRef(std::shared_ptr rref); + + bool isOwner() const; + worker_id_t owner() const; + py::object toHere(); + py::object localValue(); + py::tuple pickle() const; + static PyRRef unpickle(const py::tuple& t); + + private: + std::shared_ptr rref_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 7314d2d1c831a..52cae25206d91 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -1,4 +1,22 @@ #include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include namespace torch { namespace distributed { @@ -41,16 +59,23 @@ std::shared_ptr matchBuiltinOp( ", kwargs: ", kwargs, ") to a builtin operator"); +} - // builtin operators. +void finishAcceptUserRRef(const Message& message) { + RRefContext::handleException(message); + auto rr = RemoteRet::fromMessage(message); + auto& ctx = RRefContext::getInstance(); + ctx->delPendingUser(rr->forkId()); } } // namespace -py::object toPyObj(const Message& message) { - switch (message.type()) { +using namespace torch::distributed::autograd; + +py::object toPyObjInternal(RpcCommandBase& rpc, MessageType messageType) { + switch (messageType) { case MessageType::SCRIPT_RET: { - ScriptRet ret = ScriptRet::fromMessage(message); + auto& ret = static_cast(rpc); Stack stack; stack.push_back(ret.value()); { @@ -61,32 +86,68 @@ py::object toPyObj(const Message& message) { } } case MessageType::PYTHON_RET: { - return PythonRpcHandler::getInstance().loadPythonUDFResult(message); + // TODO: Try to avoid a copy here. + auto& resp = static_cast(rpc); + + return PythonRpcHandler::getInstance().loadPythonUDFResult( + resp.pickledPayload(), resp.tensors()); } - case MessageType::EXCEPTION: { - std::string err(message.payload().begin(), message.payload().end()); - throw std::runtime_error(err); + case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: { + auto& rpcWithAutograd = static_cast(rpc); + + // Attach 'recv' autograd function. + addRecvRpcBackward( + rpcWithAutograd.autogradMetadata(), rpcWithAutograd.tensors()); + + // Handle the original RPC. + auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); + return toPyObjInternal(rpcWithAutograd.wrappedRpc(), wrappedMessageType); } default: { - AT_ERROR("Unrecognized response message type ", message.type()); + AT_ERROR("Unrecognized response message type ", messageType); } } } +py::object toPyObj(const Message& message) { + return toPyObjInternal(*deserializeResponse(message), message.type()); +} + std::shared_ptr pyRpcBuiltin( RpcAgent& agent, - const WorkerId& dst, + const WorkerInfo& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs) { Stack stack; auto op = matchBuiltinOp(opName, args, kwargs, stack); - return agent.send(dst, ScriptCall(op, std::move(stack)).toMessage()); + auto scriptCall = c10::guts::make_unique(op, std::move(stack)); + auto& autogradContainer = DistAutogradContainer::getInstance(); + if (autogradContainer.hasValidContext()) { + // Retrieve the appropriate context to modify. + auto& autogradContext = autogradContainer.currentContext(); + + // Wrap the original rpc with autograd information. 
+ AutogradMetadata autogradMetadata( + autogradContext.context_id(), autogradContainer.newAutogradMessageId()); + RpcWithAutograd rpcWithAutograd( + MessageType::MESSAGE_WITH_AUTOGRAD_REQ, + autogradMetadata, + std::move(scriptCall)); + + // Record autograd information for 'send'. + addSendRpcBackward( + autogradContext, autogradMetadata, rpcWithAutograd.tensors()); + + return agent.send(dst, std::move(rpcWithAutograd).toMessage()); + } else { + return agent.send(dst, std::move(*scriptCall).toMessage()); + } } -std::shared_ptr pyRemoteBuiltin( +PyRRef pyRemoteBuiltin( RpcAgent& agent, - const WorkerId& dst, + const WorkerInfo& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs) { @@ -94,28 +155,57 @@ std::shared_ptr pyRemoteBuiltin( auto op = matchBuiltinOp(opName, args, kwargs, stack); auto& ctx = RRefContext::getInstance(); - auto userRRef = ctx->createUserRRef(dst.id_); - agent.send( + // TODO: support creating RRefs on a local object. + TORCH_INTERNAL_ASSERT( + ctx->getWorkerId() != dst.id_, + "Does not support creating RRef on self yet."); + auto userRRef = ctx->createUserRRef(dst.id_); + auto fm = agent.send( dst, ScriptRemoteCall( - op, - std::move(stack), - userRRef->id().toIValue(), - userRRef->forkId().toIValue()) + op, std::move(stack), userRRef->rrefId(), userRRef->forkId()) .toMessage()); - return userRRef; + + ctx->addPendingUser(userRRef->forkId(), userRRef); + fm->addCallback(finishAcceptUserRRef); + return PyRRef(userRRef); } std::shared_ptr pyRpcPythonUdf( RpcAgent& agent, - const WorkerId& dst, - const std::string& pickledPythonUDF, + const WorkerInfo& dst, + std::string& pickledPythonUDF, std::vector& tensors) { - std::vector data(pickledPythonUDF.begin(), pickledPythonUDF.end()); - return agent.send( dst, - Message(std::move(data), std::move(tensors), MessageType::PYTHON_CALL)); + PythonUDFCall( + std::vector(pickledPythonUDF.begin(), pickledPythonUDF.end()), + tensors) + .toMessage()); +} + +PyRRef pyRemotePythonUdf( + RpcAgent& agent, + const WorkerInfo& dst, + std::string& pickledPythonUDF, + std::vector& tensors) { + auto& ctx = RRefContext::getInstance(); + // TODO: support creating RRefs on a local object. 
+ TORCH_INTERNAL_ASSERT( + ctx->getWorkerId() != dst.id_, + "Does not support creating RRef on self yet."); + auto userRRef = ctx->createUserRRef(dst.id_); + auto fm = agent.send( + dst, + PythonRemoteCall( + SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors)), + userRRef->rrefId().toIValue(), + userRRef->forkId().toIValue()) + .toMessage()); + + ctx->addPendingUser(userRRef->forkId(), userRRef); + fm->addCallback(finishAcceptUserRRef); + return PyRRef(userRRef); } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/python_functions.h b/torch/csrc/distributed/rpc/python_functions.h index 37d1fb7743431..c8f55e35d2cb2 100644 --- a/torch/csrc/distributed/rpc/python_functions.h +++ b/torch/csrc/distributed/rpc/python_functions.h @@ -1,15 +1,8 @@ #pragma once #include -#include -#include +#include #include -#include -#include -#include -#include -#include -#include #include namespace torch { @@ -20,24 +13,30 @@ py::object toPyObj(const Message& message); std::shared_ptr pyRpcBuiltin( RpcAgent& agent, - const WorkerId& dst, + const WorkerInfo& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs); std::shared_ptr pyRpcPythonUdf( RpcAgent& agent, - const WorkerId& dst, - const std::string& pickledPythonUDF, + const WorkerInfo& dst, + std::string& pickledPythonUDF, std::vector& tensors); -std::shared_ptr pyRemoteBuiltin( +PyRRef pyRemoteBuiltin( RpcAgent& agent, - const WorkerId& dst, + const WorkerInfo& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs); +PyRRef pyRemotePythonUdf( + RpcAgent& agent, + const WorkerInfo& dst, + std::string& pickledPythonUDF, + std::vector& tensors); + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/python_remote_call.cpp b/torch/csrc/distributed/rpc/python_remote_call.cpp new file mode 100644 index 0000000000000..cdd307c3c4e08 --- /dev/null +++ b/torch/csrc/distributed/rpc/python_remote_call.cpp @@ -0,0 +1,53 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +PythonRemoteCall::PythonRemoteCall( + SerializedPyObj&& serializedPyObj, + at::IValue retRRefId, + at::IValue retForkId) + : serializedPyObj_(std::move(serializedPyObj)), + retRRefId_(std::move(retRRefId)), + retForkId_(std::move(retForkId)) {} + +Message PythonRemoteCall::toMessage() && { + std::vector ivalues = serializedPyObj_.toIValues(); + ivalues.emplace_back(retRRefId_); + ivalues.emplace_back(retForkId_); + + std::vector tensor_table; + auto payload = + jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table); + + return Message( + std::move(payload), + std::move(tensor_table), + MessageType::PYTHON_REMOTE_CALL); +} + +std::unique_ptr PythonRemoteCall::fromMessage( + const Message& message) { + auto payload = static_cast(message.payload().data()); + auto payload_size = message.payload().size(); + + auto value = + jit::unpickle(payload, payload_size, nullptr, &message.tensors()); + auto values = value.toTuple()->elements(); + + // remove the last element from values and convert it back to an RRef + auto retForkId = std::move(values.back()); + values.pop_back(); + auto retRRefId = std::move(values.back()); + values.pop_back(); + auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values)); + + return c10::guts::make_unique( + std::move(serializedPyObj), std::move(retRRefId), std::move(retForkId)); +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git 
a/torch/csrc/distributed/rpc/python_remote_call.h b/torch/csrc/distributed/rpc/python_remote_call.h new file mode 100644 index 0000000000000..fbf255e20a762 --- /dev/null +++ b/torch/csrc/distributed/rpc/python_remote_call.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +class TORCH_API PythonRemoteCall : public RpcCommandBase { + public: + PythonRemoteCall( + SerializedPyObj&& serializedPyObj, + at::IValue retRRefId, + at::IValue retForkId); + + inline const SerializedPyObj& serializedPyObj() const { + return serializedPyObj_; + } + + inline const at::IValue& retRRefId() const { + return retRRefId_; + } + + inline const at::IValue& retForkId() const { + return retForkId_; + } + + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); + + private: + const SerializedPyObj serializedPyObj_; + const at::IValue retRRefId_; + const at::IValue retForkId_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp index 8f34b22f6d968..61775c24eeac2 100644 --- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp +++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp @@ -10,6 +10,7 @@ PythonRpcHandler::PythonRpcHandler() { py::module::import("torch.distributed.internal_rpc_utils"); runUDFFunction_ = module.attr("run_python_udf_internal"); loadResultFunction_ = module.attr("load_python_udf_result_internal"); + serializeFunction_ = module.attr("serialize"); } PythonRpcHandler& PythonRpcHandler::getInstance() { @@ -18,23 +19,47 @@ PythonRpcHandler& PythonRpcHandler::getInstance() { } std::vector PythonRpcHandler::generatePythonUDFResult( - const Message& request, - std::vector& tensorTable) { + const std::vector& pickledPayload, + const std::vector& requestTensorTable, + std::vector& responseTensorTable) { AutoGIL ag; - auto pargs = py::bytes(request.payload().data(), request.payload().size()); + auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size()); TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr"); - py::tuple pres = runUDFFunction_(pargs, request.tensors()); + py::tuple pres = + serializeFunction_(runUDFFunction_(pargs, requestTensorTable)); const auto& presStr = pres[0].cast(); - tensorTable = pres[1].cast>(); + responseTensorTable = pres[1].cast>(); std::vector payload(presStr.begin(), presStr.end()); return payload; } -py::object PythonRpcHandler::loadPythonUDFResult(const Message& message) { +py::object PythonRpcHandler::loadPythonUDFResult( + const std::vector& pickledPayload, + const std::vector& tensorTable) { AutoGIL ag; - auto pargs = py::bytes(message.payload().data(), message.payload().size()); + auto pargs = py::bytes(pickledPayload.data(), pickledPayload.size()); TORCH_CHECK(loadResultFunction_ != nullptr, "loadResultFunction_ is nullptr"); - return loadResultFunction_(pargs, message.tensors()); + return loadResultFunction_(pargs, tensorTable); +} + +py::object PythonRpcHandler::runPythonUDF( + const SerializedPyObj& serializedObj) { + AutoGIL ag; + return runUDFFunction_( + py::bytes(serializedObj.payload_), serializedObj.tensors_); +} + +SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) { + AutoGIL ag; + py::tuple t = serializeFunction_(obj); + return SerializedPyObj( + t[0].cast(), t[1].cast>()); +} + +py::object PythonRpcHandler::deserialize(const 
SerializedPyObj& serializedObj) { + AutoGIL ag; + return loadResultFunction_( + py::bytes(serializedObj.payload_), serializedObj.tensors_); } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.h b/torch/csrc/distributed/rpc/python_rpc_handler.h index b2283adee23d3..ce551153bc102 100644 --- a/torch/csrc/distributed/rpc/python_rpc_handler.h +++ b/torch/csrc/distributed/rpc/python_rpc_handler.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include namespace torch { @@ -18,11 +19,20 @@ class PYBIND11_EXPORT PythonRpcHandler { static PythonRpcHandler& getInstance(); // Execute python UDF, result is pickled to binary string std::vector generatePythonUDFResult( - const Message& request, - std::vector& tensorTable); + const std::vector& pickledPayload, + const std::vector& requestTensorTable, + std::vector& responseTensorTable); // Returned python UDF result is pickled binary string, so run python // function to unpickle the python UDF result and return py::object to user - py::object loadPythonUDFResult(const Message& message); + py::object loadPythonUDFResult( + const std::vector& pickledPayload, + const std::vector& tensorTable); + // Run a pickled Python UDF and return the result py::object + py::object runPythonUDF(const SerializedPyObj& serializedObj); + // Serialized a py::object into a string + SerializedPyObj serialize(const py::object& obj); + // Deserialize a string into a py::object + py::object deserialize(const SerializedPyObj& serializedObj); private: PythonRpcHandler(); @@ -35,6 +45,7 @@ class PYBIND11_EXPORT PythonRpcHandler { py::object runUDFFunction_; py::object loadResultFunction_; + py::object serializeFunction_; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/python_udf_call.cpp b/torch/csrc/distributed/rpc/python_udf_call.cpp new file mode 100644 index 0000000000000..1a99ba5051de9 --- /dev/null +++ b/torch/csrc/distributed/rpc/python_udf_call.cpp @@ -0,0 +1,37 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +PythonUDFCall::PythonUDFCall( + std::vector pickledPayload, + std::vector tensors) + : pickledPayload_(std::move(pickledPayload)), + tensors_(std::move(tensors)) {} + +Message PythonUDFCall::toMessage() && { + return Message( + std::move(pickledPayload_), + std::move(tensors_), + MessageType::PYTHON_CALL); +} + +std::unique_ptr PythonUDFCall::fromMessage( + const Message& message) { + return c10::guts::make_unique( + message.payload(), message.tensors()); +} + +const std::vector& PythonUDFCall::pickledPayload() const { + return pickledPayload_; +} + +const std::vector& PythonUDFCall::tensors() const { + return tensors_; +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/python_udf_call.h b/torch/csrc/distributed/rpc/python_udf_call.h new file mode 100644 index 0000000000000..b46162dc32a18 --- /dev/null +++ b/torch/csrc/distributed/rpc/python_udf_call.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// RPC call representing calling a Python UDF over RPC. 
+class TORCH_API PythonUDFCall final : public RpcCommandBase { + public: + explicit PythonUDFCall( + std::vector pickledPayload, + std::vector tensors); + + Message toMessage() && override; + + static std::unique_ptr fromMessage(const Message& message); + + const std::vector& pickledPayload() const; + + const std::vector& tensors() const; + + private: + std::vector pickledPayload_; + std::vector tensors_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/python_udf_resp.cpp b/torch/csrc/distributed/rpc/python_udf_resp.cpp new file mode 100644 index 0000000000000..31ce56dfd7cba --- /dev/null +++ b/torch/csrc/distributed/rpc/python_udf_resp.cpp @@ -0,0 +1,35 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +PythonUDFResp::PythonUDFResp( + std::vector pickledPayload, + std::vector tensors) + : pickledPayload_(std::move(pickledPayload)), + tensors_(std::move(tensors)) {} + +Message PythonUDFResp::toMessage() && { + return Message( + std::move(pickledPayload_), std::move(tensors_), MessageType::PYTHON_RET); +} + +std::unique_ptr PythonUDFResp::fromMessage( + const Message& message) { + return c10::guts::make_unique( + message.payload(), message.tensors()); +} + +const std::vector& PythonUDFResp::pickledPayload() const { + return pickledPayload_; +} + +const std::vector& PythonUDFResp::tensors() const { + return tensors_; +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/python_udf_resp.h b/torch/csrc/distributed/rpc/python_udf_resp.h new file mode 100644 index 0000000000000..dc644b47adba4 --- /dev/null +++ b/torch/csrc/distributed/rpc/python_udf_resp.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// RPC call representing the response of a Python UDF over RPC. 
+class TORCH_API PythonUDFResp final : public RpcCommandBase {
+ public:
+  explicit PythonUDFResp(
+      std::vector<char> pickledPayload,
+      std::vector<torch::Tensor> tensors);
+
+  Message toMessage() && override;
+
+  static std::unique_ptr<PythonUDFResp> fromMessage(const Message& message);
+
+  const std::vector<char>& pickledPayload() const;
+
+  const std::vector<torch::Tensor>& tensors() const;
+
+ private:
+  std::vector<char> pickledPayload_;
+  std::vector<torch::Tensor> tensors_;
+};
+
+} // namespace rpc
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp
new file mode 100644
index 0000000000000..cdb9a8993946b
--- /dev/null
+++ b/torch/csrc/distributed/rpc/request_callback.cpp
@@ -0,0 +1,35 @@
+#include
+#include
+#include
+
+namespace torch {
+namespace distributed {
+namespace rpc {
+
+using namespace torch::distributed::autograd;
+
+namespace {
+
+Message createException(const Message& request, const std::exception& e) {
+  const char* err = e.what();
+  std::vector<char> payload(err, err + strlen(err));
+  return Message(
+      std::move(payload),
+      std::vector<torch::Tensor>(),
+      MessageType::EXCEPTION,
+      request.id());
+}
+
+} // anonymous namespace
+
+Message RequestCallback::operator()(Message& request) const {
+  try {
+    return processMessage(request);
+  } catch (std::exception& e) {
+    return createException(request, e);
+  }
+}
+
+} // namespace rpc
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/rpc/request_callback.h b/torch/csrc/distributed/rpc/request_callback.h
new file mode 100644
index 0000000000000..7771e13820bed
--- /dev/null
+++ b/torch/csrc/distributed/rpc/request_callback.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include
+
+namespace torch {
+namespace distributed {
+namespace rpc {
+
+// Functor which is invoked to process an RPC message. This is an abstract
+// class with some common functionality across all request handlers. Users
+// need to implement this interface to perform the actual business logic.
+class TORCH_API RequestCallback {
+ public:
+  // Invoke the callback.
+  Message operator()(Message& request) const;
+
+  virtual ~RequestCallback() {}
+
+ protected:
+  // RpcAgent implementations should invoke ``RequestCallback`` to process
+  // received requests. There is no restriction on the implementation's
+  // threading model. This function takes a reference to the Message object.
+  // It is expected to return the response message or a message containing an
+  // exception. Different rpc agent implementations are expected to ensure
+  // delivery of the response/exception based on their implementation
+  // specific mechanisms.
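The pure virtual ``processMessage()`` hook declared just below is the only piece a subclass must supply; ``operator()`` above turns any exception it throws into an EXCEPTION message. A minimal sketch of that contract follows (the ``EchoRequestCallback`` name and its echo-the-payload behavior are invented for illustration only and are not part of this patch, which installs RequestCallbackImpl instead; it relies solely on the Message constructors and accessors shown elsewhere in this diff):

#include <torch/csrc/distributed/rpc/request_callback.h>

#include <utility>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

// Hypothetical callback that simply echoes the request payload back to the
// sender. Real agents would use RequestCallbackImpl.
class EchoRequestCallback final : public RequestCallback {
 protected:
  Message processMessage(Message& request) const override {
    // Copy the incoming payload and tensors into a response that reuses the
    // request id, mirroring what createException() does above. PYTHON_RET is
    // an arbitrary placeholder response type for this sketch.
    std::vector<char> payload = request.payload();
    std::vector<torch::Tensor> tensors = request.tensors();
    return Message(
        std::move(payload),
        std::move(tensors),
        MessageType::PYTHON_RET,
        request.id());
  }
};

} // namespace rpc
} // namespace distributed
} // namespace torch
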
+ virtual Message processMessage(Message& request) const = 0; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp new file mode 100644 index 0000000000000..c2221759f3e73 --- /dev/null +++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp @@ -0,0 +1,185 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +using namespace torch::distributed::autograd; + +std::unique_ptr RequestCallbackImpl::processRpc( + RpcCommandBase& rpc, + MessageType messageType) const { + // TODO: RpcCommandBase should have an abstract execute() method that we can + // call here instead of having another switch statement here. Even better we + // could have abstract classes RpcRequest and RpcResp which inherit from + // RpcCommandBase and RpcRequest declares the abstract method execute() that + // we can call here. RpcResponse could have an abstract method to convert it + // to a python object. + switch (messageType) { + case MessageType::SCRIPT_CALL: { + auto& scriptCall = static_cast(rpc); + + // sc is only alive within this block, use reference to avoid copy + auto& stack = scriptCall.stackRef(); + scriptCall.op()->getOperation()(stack); + + TORCH_INTERNAL_ASSERT( + stack.size() == 1, + "Return value of a builtin operator or a " + "TorchScript function should be a single IValue, got a vector of " + "size ", + stack.size()); + + return c10::guts::make_unique(std::move(stack.front())); + } + case MessageType::PYTHON_CALL: { + auto& pyCall = static_cast(rpc); + std::vector responseTensorTable; + auto payload = PythonRpcHandler::getInstance().generatePythonUDFResult( + pyCall.pickledPayload(), pyCall.tensors(), responseTensorTable); + return c10::guts::make_unique( + std::move(payload), std::move(responseTensorTable)); + } + case MessageType::SCRIPT_REMOTE_CALL: { + auto& src = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + + auto ownerRRef = ctx->getOrCreateOwnerRRef(src.retRRefId()); + + // TODO: make this asynchronous + // src is only alive within this block, use reference to avoid copy + auto& stack = src.stackRef(); + src.op()->getOperation()(stack); + TORCH_INTERNAL_ASSERT( + stack.size() == 1, + "Return value of a builtin operator or a " + "TorchScript function should be a single IValue, got a vector of " + "size ", + stack.size()); + + ownerRRef->setValue(std::move(stack.front())); + ctx->addForkOfOwner(src.retRRefId(), src.retForkId()); + return c10::guts::make_unique( + src.retRRefId(), src.retForkId()); + } + case MessageType::PYTHON_REMOTE_CALL: { + auto& prc = static_cast(rpc); + + auto rrefId = RRefId::fromIValue(prc.retRRefId()); + auto forkId = ForkId::fromIValue(prc.retForkId()); + auto& ctx = RRefContext::getInstance(); + + auto ownerRRef = ctx->getOrCreateOwnerRRef(rrefId); + ownerRRef->setValue( + PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj())); + ctx->addForkOfOwner(rrefId, forkId); + return c10::guts::make_unique(rrefId, forkId); + } + case MessageType::SCRIPT_RREF_FETCH_CALL: { + auto& srf = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + // TODO: make this asynchronous + std::shared_ptr> rref = + ctx->getOrCreateOwnerRRef(srf.rrefId()); + return c10::guts::make_unique( + RRefFetchRet({rref->getValue()})); + 
} + case MessageType::PYTHON_RREF_FETCH_CALL: { + auto& prf = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + // TODO: make this asynchronous + std::shared_ptr> rref = + ctx->getOrCreateOwnerRRef(prf.rrefId()); + SerializedPyObj result = + PythonRpcHandler::getInstance().serialize(rref->getValue()); + return c10::guts::make_unique( + RRefFetchRet(result.toIValues())); + } + case MessageType::RREF_USER_DELETE: { + auto& rud = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + ctx->delForkOfOwner(rud.rrefId(), rud.forkId()); + return c10::guts::make_unique(); + } + case MessageType::RREF_CHILD_ACCEPT: { + auto& rca = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + ctx->delPendingChild(rca.forkId()); + return c10::guts::make_unique(); + } + case MessageType::RREF_FORK_REQUEST: { + auto& rfr = static_cast(rpc); + auto& ctx = RRefContext::getInstance(); + ctx->addForkOfOwner(rfr.rrefId(), rfr.forkId()); + return c10::guts::make_unique(); + } + case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: { + auto& rpcWithAutograd = static_cast(rpc); + const auto& autogradMetadata = rpcWithAutograd.autogradMetadata(); + + // Attach 'recv' autograd function. + DistAutogradContext* autogradContext = addRecvRpcBackward( + rpcWithAutograd.autogradMetadata(), rpcWithAutograd.tensors()); + + // Process the original RPC. + auto wrappedMessageType = rpcWithAutograd.wrappedMessageType(); + auto wrappedRpcResponse = + processRpc(rpcWithAutograd.wrappedRpc(), wrappedMessageType); + + // Wrap the response with autograd, need a new autograd message id for + // each send/recv pair. + auto& autogradContainer = DistAutogradContainer::getInstance(); + AutogradMetadata responseAutogradMetadata( + autogradMetadata.autogradContextId, + autogradContainer.newAutogradMessageId()); + + auto response = c10::guts::make_unique( + MessageType::MESSAGE_WITH_AUTOGRAD_RESP, + responseAutogradMetadata, + std::move(wrappedRpcResponse)); + + // Attach the 'send' autograd function if needed. 
+ if (autogradContext != nullptr) { + addSendRpcBackward( + *autogradContext, responseAutogradMetadata, response->tensors()); + } + return std::move(response); + } + default: { + TORCH_INTERNAL_ASSERT( + false, "Request type ", messageType, " not supported."); + } + } +} + +Message RequestCallbackImpl::processMessage(Message& request) const { + std::unique_ptr rpc = deserializeRequest(request); + auto response = processRpc(*rpc, request.type()); + if (response == nullptr) { + return Message(); + } + auto responseMessage = std::move(*response).toMessage(); + responseMessage.setId(request.id()); + return responseMessage; +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/request_callback_impl.h b/torch/csrc/distributed/rpc/request_callback_impl.h new file mode 100644 index 0000000000000..37231761660cd --- /dev/null +++ b/torch/csrc/distributed/rpc/request_callback_impl.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +class TORCH_API RequestCallbackImpl : public RequestCallback { + public: + Message processMessage(Message& request) const override; + + private: + std::unique_ptr processRpc( + RpcCommandBase& rpc, + MessageType messageType) const; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp index 9725513107a5f..fb5767813fb47 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.cpp +++ b/torch/csrc/distributed/rpc/rpc_agent.cpp @@ -1,41 +1,20 @@ #include -#include -#include namespace torch { namespace distributed { namespace rpc { -constexpr size_t WorkerId::MAX_NAME_LEN; -using namespace torch::distributed::autograd; +constexpr size_t WorkerInfo::MAX_NAME_LEN; -RpcAgent::RpcAgent(WorkerId workerId, RequestCallback cb) - : workerId_(std::move(workerId)), cb_(std::move(cb)) {} +RpcAgent::RpcAgent(WorkerInfo workerId, std::unique_ptr cb) + : workerInfo_(std::move(workerId)), cb_(std::move(cb)) {} RpcAgent::~RpcAgent() = default; -const WorkerId& RpcAgent::getWorkerId() const { - return workerId_; +const WorkerInfo& RpcAgent::getWorkerInfo() const { + return workerInfo_; } -std::shared_ptr RpcAgent::send( - const WorkerId& to, - Message&& message) { - // Record appropriate autograd information before sending the message over the - // wire. - auto& autogradContainer = DistAutogradContainer::getInstance(); - if (autogradContainer.hasValidContext()) { - // Attach the appropriate autograd edges to the tensors found in the - // message. - auto grad_fn = addSendRpcBackward(message.tensors()); - - // Record the send function in our current context. 
- auto& currentContext = autogradContainer.currentContext(); - currentContext.addSendFunction(grad_fn); - } - - return sendImpl(to, std::forward(message)); -} } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index f5ebda81efa3f..35eb2b53402fe 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -11,9 +12,9 @@ namespace distributed { namespace rpc { // A globally unique ID to identify an RpcAgent -struct WorkerId { - WorkerId(std::string name, int id) - : WorkerId(std::move(name), (worker_id_t)id) { +struct WorkerInfo { + WorkerInfo(std::string name, int id) + : WorkerInfo(std::move(name), (worker_id_t)id) { TORCH_CHECK( id <= std::numeric_limits::max(), "RPC worker id ", @@ -21,7 +22,8 @@ struct WorkerId { " out of bound of int16_t."); } - WorkerId(std::string name, worker_id_t id) : name_(std::move(name)), id_(id) { + WorkerInfo(std::string name, worker_id_t id) + : name_(std::move(name)), id_(id) { bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0; bool validChar = std::find_if(name_.begin(), name_.end(), [](char c) { @@ -48,22 +50,11 @@ struct WorkerId { // will invoke the given ``RequestCallback`` to process received requests. It // should immediately become ready to serve request and accept response after // construction. -class RpcAgent; - -// RpcAgent implementation should invoke ``RequestCallback`` to process received -// requests. There is no restriction on the implementation's threading model. -// This function takes an rvalue reference of the Message object. -// It is expected to return the response message or message containing an -// exception. Different rpc agent implementations are expected to ensure -// delivery of the response/exception based on their implementation specific -// mechanisms. -using RequestCallback = std::function; - class RpcAgent { public: - // `WorkerId` is the globally unique identifier for this RpcAgent instance. It - // contains a ``name_`` field and an ``id_`` field. ``name_`` is the globally - // unique name for this ``RpcAgent``. It is up to the ``RpcAgent`` + // `WorkerInfo` is the globally unique identifier for this RpcAgent instance. + // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the + // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent`` // implementation to determine how to resolve names. ``id_`` is the globally // unique ID for this ``RpcAgent``. This should be determined by the // ``RpcAgent`` implementation. @@ -71,7 +62,7 @@ class RpcAgent { // ``RpcAgent`` base class makes no assumption on the thread-safeness of the // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that // its threading model conform to ``RequestCallback``'s requirement. - RpcAgent(WorkerId id, RequestCallback cb); + RpcAgent(WorkerInfo id, std::unique_ptr cb); virtual ~RpcAgent(); @@ -82,18 +73,21 @@ class RpcAgent { // If ``message.isRequest()`` is true, the ``FutureMessage`` will be completed // when the response arrives. For other message types, the Future should be // ignored by the caller. - std::shared_ptr send(const WorkerId& to, Message&& message); + virtual std::shared_ptr send( + const WorkerInfo& to, + Message&& message) = 0; - // Return a reference to the ``WorkerId`` of this RpcAgent. + // Return a reference to the ``WorkerInfo`` of this RpcAgent. 
// NB: not using ``c10::optional`` here because we might // need to create a separate RPC API lib and avoid forcing all ``RpcAgent`` // implementations to depend on libtorch. - const WorkerId& getWorkerId() const; + const WorkerInfo& getWorkerInfo() const; - // Return a reference to the ``WorkerId`` of the given ``workerName``. - virtual const WorkerId& getWorkerId(const std::string& workerName) const = 0; + // Return a reference to the ``WorkerInfo`` of the given ``workerName``. + virtual const WorkerInfo& getWorkerInfo( + const std::string& workerName) const = 0; - virtual const WorkerId& getWorkerId(worker_id_t id) const = 0; + virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0; // Call sync and join all internal threads. This method should be called // before every RPC process exits. @@ -104,16 +98,9 @@ class RpcAgent { virtual void sync() = 0; protected: - const WorkerId workerId_; - - // Method that needs to be overridden by all implementations of this - // interface. The public send() method is responsible for common - // pre-processing shared across all implementations. - virtual std::shared_ptr sendImpl( - const WorkerId& to, - Message&& message) = 0; + const WorkerInfo workerInfo_; const std::string workerName_; - const RequestCallback cb_; + const std::unique_ptr cb_; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/rpc_command_base.h b/torch/csrc/distributed/rpc/rpc_command_base.h new file mode 100644 index 0000000000000..b992624d77f76 --- /dev/null +++ b/torch/csrc/distributed/rpc/rpc_command_base.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// Base class for all RPC request and responses. +class RpcCommandBase { + public: + // Need to override this to serialize the RPC. This should destructively + // create a message for the RPC (Hence the &&). 
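The rvalue-qualified ``toMessage()`` declared just below is the heart of this interface: converting a command into a Message consumes the command. A small hypothetical subclass sketches the pattern (``EchoCall`` is invented for illustration; the concrete subclasses added by this patch are PythonUDFCall, PythonUDFResp, PythonRemoteCall, and RpcWithAutograd):

#include <torch/csrc/distributed/rpc/rpc_command_base.h>

#include <utility>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

class EchoCall final : public RpcCommandBase {
 public:
  explicit EchoCall(std::vector<char> payload) : payload_(std::move(payload)) {}

  // The && qualifier means the command is consumed: the payload is moved
  // into the Message rather than copied.
  Message toMessage() && override {
    return Message(
        std::move(payload_),
        std::vector<torch::Tensor>(),
        // UNKNOWN is used only as a placeholder message type for this sketch.
        MessageType::UNKNOWN);
  }

 private:
  std::vector<char> payload_;
};

} // namespace rpc
} // namespace distributed
} // namespace torch

At a call site this reads ``Message m = std::move(call).toMessage();``, which makes the hand-off of the payload explicit.
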
+ virtual Message toMessage() && = 0; + virtual ~RpcCommandBase() = 0; +}; + +inline RpcCommandBase::~RpcCommandBase() {} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rpc_with_autograd.cpp b/torch/csrc/distributed/rpc/rpc_with_autograd.cpp new file mode 100644 index 0000000000000..70296d8e868de --- /dev/null +++ b/torch/csrc/distributed/rpc/rpc_with_autograd.cpp @@ -0,0 +1,163 @@ +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +constexpr int kAutogradMessageSize = 17; + +AutogradMetadata::AutogradMetadata( + int64_t autogradContextId_, + int64_t autogradMessageId_) + : autogradContextId(autogradContextId_), + autogradMessageId(autogradMessageId_) {} + +RpcWithAutograd::RpcWithAutograd( + MessageType messageType, + const AutogradMetadata& autogradMetadata, + std::unique_ptr wrappedRpc) + : messageType_(messageType), autogradMetadata_(autogradMetadata) { + TORCH_INTERNAL_ASSERT(wrappedRpc != nullptr, "wrappedRpc cannot be null!"); + TORCH_INTERNAL_ASSERT( + messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_REQ || + messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_RESP); + wrappedMessage_ = std::move(*wrappedRpc).toMessage(); + tensors_ = wrappedMessage_.tensors(); + wrappedMessageType_ = wrappedMessage_.type(); +} + +RpcWithAutograd::RpcWithAutograd( + MessageType messageType, + const AutogradMetadata& autogradMetadata, + std::unique_ptr wrappedRpc, + MessageType wrappedMessageType, + std::vector tensors) + : messageType_(messageType), + autogradMetadata_(autogradMetadata), + wrappedRpc_(std::move(wrappedRpc)), + wrappedMessageType_(wrappedMessageType), + tensors_(std::move(tensors)) { + TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!"); + TORCH_INTERNAL_ASSERT( + messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_REQ || + messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_RESP); +} + +Message RpcWithAutograd::toMessage() && { + auto messageId = wrappedMessage_.id(); + auto messageType = wrappedMessage_.type(); + + auto payload = std::move(wrappedMessage_).movePayload(); + TORCH_INTERNAL_ASSERT(!payload.empty()); + + // We append the message type (1 byte), autograd context id(8 bytes) and + // autograd message id(8 bytes) to the original message in network byte order + // (big endian). + size_t writableIndex = payload.size(); + + // Need 17 additional bytes. + payload.resize(payload.size() + kAutogradMessageSize); + + // Add message type. + payload[writableIndex++] = messageType; + + // Add autograd ids. + torch::utils::THP_encodeInt64Buffer( + reinterpret_cast(payload.data()) + writableIndex, + &autogradMetadata_.autogradContextId, + torch::utils::THPByteOrder::THP_BIG_ENDIAN, + 1); + writableIndex += sizeof(int64_t); + torch::utils::THP_encodeInt64Buffer( + reinterpret_cast(payload.data()) + writableIndex, + &autogradMetadata_.autogradMessageId, + torch::utils::THPByteOrder::THP_BIG_ENDIAN, + 1); + + return Message( + std::move(payload), std::move(tensors_), messageType_, messageId); +} + +std::unique_ptr RpcWithAutograd::fromMessage( + const Message& message) { + MessageType originalMessageType = message.type(); + TORCH_INTERNAL_ASSERT( + MessageType::MESSAGE_WITH_AUTOGRAD_REQ == originalMessageType || + MessageType::MESSAGE_WITH_AUTOGRAD_RESP == originalMessageType); + + std::vector tensors = message.tensors(); + int64_t messageId = message.id(); + // Decode message type, autograd context id and autograd message id. 
+ auto payload = message.payload(); + TORCH_INTERNAL_ASSERT(payload.size() > kAutogradMessageSize); + + int64_t autogradContextId, autogradMessageId; + // autograd message id. + size_t indexToRead = payload.size() - sizeof(int64_t); + TORCH_INTERNAL_ASSERT(indexToRead >= 0); + torch::utils::THP_decodeInt64Buffer( + &autogradMessageId, + reinterpret_cast(payload.data()) + indexToRead, + torch::utils::THPByteOrder::THP_BIG_ENDIAN, + 1); + + // autograd context id. + indexToRead -= sizeof(int64_t); + TORCH_INTERNAL_ASSERT(indexToRead >= 0); + torch::utils::THP_decodeInt64Buffer( + &autogradContextId, + reinterpret_cast(payload.data()) + indexToRead, + torch::utils::THPByteOrder::THP_BIG_ENDIAN, + 1); + + // message type. + indexToRead -= 1; + TORCH_INTERNAL_ASSERT(indexToRead >= 0); + MessageType wrappedMessageType = + static_cast(payload[indexToRead]); + + // Remove the autograd information. + payload.resize(payload.size() - kAutogradMessageSize); + + // Create new message type and build wrapped RPC. + Message wrappedMessage( + std::move(payload), std::move(tensors), wrappedMessageType, messageId); + + std::unique_ptr wrappedRpc; + if (originalMessageType == MessageType::MESSAGE_WITH_AUTOGRAD_REQ) { + wrappedRpc = deserializeRequest(wrappedMessage); + } else { + wrappedRpc = deserializeResponse(wrappedMessage); + } + + return c10::guts::make_unique( + originalMessageType, + AutogradMetadata(autogradContextId, autogradMessageId), + std::move(wrappedRpc), + wrappedMessageType, + std::move(tensors)); +} + +std::vector& RpcWithAutograd::tensors() { + return tensors_; +} + +const AutogradMetadata& RpcWithAutograd::autogradMetadata() const { + return autogradMetadata_; +} + +RpcCommandBase& RpcWithAutograd::wrappedRpc() { + TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!"); + return *wrappedRpc_; +} + +MessageType RpcWithAutograd::wrappedMessageType() const { + return wrappedMessageType_; +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rpc_with_autograd.h b/torch/csrc/distributed/rpc/rpc_with_autograd.h new file mode 100644 index 0000000000000..41d3a409a0923 --- /dev/null +++ b/torch/csrc/distributed/rpc/rpc_with_autograd.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +struct TORCH_API AutogradMetadata { + AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId); + + // autogradContextId_ is a globally unique integer that identifies a + // particular distributed autograd pass. + int64_t autogradContextId; + // autogradMessageId_ is a globally unique integer that identifies a pair + // of send/recv autograd functions. + int64_t autogradMessageId; +}; + +// Represents an RPC that includes autograd information. This class basically +// wraps another `RpcCommandBase` object which represents the actual RPC and has +// additional autograd information associated with that RPC. +class TORCH_API RpcWithAutograd final : public RpcCommandBase { + public: + // Used when we are sending an RPC over the wire. + RpcWithAutograd( + MessageType messageType, + const AutogradMetadata& autogradMetadata, + std::unique_ptr wrappedRpc); + + // Used when receiving an RPC over the wire. 
+ RpcWithAutograd( + MessageType messageType, + const AutogradMetadata& autogradMetadata, + std::unique_ptr wrappedRpc, + MessageType wrappedMessageType, + std::vector tensors); + + Message toMessage() && override; + + static std::unique_ptr fromMessage(const Message& message); + + // Retrieves tensors as part of this RPC, which need to be considered for + // autograd computations. + std::vector& tensors(); + + const AutogradMetadata& autogradMetadata() const; + + RpcCommandBase& wrappedRpc(); + + // Message type of the wrapped RPC. + MessageType wrappedMessageType() const; + + private: + // Message type for this call. + MessageType messageType_; + + AutogradMetadata autogradMetadata_; + std::unique_ptr wrappedRpc_; + + // Serialized message representing wrappedRpc_. Used mostly as a cache to + // avoid serializing the request twice. + Message wrappedMessage_; + + // message type of the wrappedMessage, this is stored separately since + // wrappedMessage_ is not always guaranteed to be populated. + MessageType wrappedMessageType_; + + // Tensors part of the wrappedRpc that need to be considered for autograd. + std::vector tensors_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rref.cpp b/torch/csrc/distributed/rpc/rref.cpp index e81281b3a7fa2..fc34f1b2d7974 100644 --- a/torch/csrc/distributed/rpc/rref.cpp +++ b/torch/csrc/distributed/rpc/rref.cpp @@ -1,11 +1,27 @@ #include + +#include #include -#include +#include namespace torch { namespace distributed { namespace rpc { +namespace { + +constexpr int OWNER_IDX = 0; // index of ownerId in the tuple +constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple +constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple +constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple +constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple +constexpr int PARENT_IDX = 5; // index of parent in the tuple + +// NB: if more fields are added, make sure this field is also bumped +constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple + +} // namespace + std::atomic RRefContext::nextLocalId_{0}; ////////////////////////// RRefForkData ///////////////////////////////// @@ -13,27 +29,47 @@ std::atomic RRefContext::nextLocalId_{0}; RRefForkData::RRefForkData( worker_id_t ownerId, const RRefId& rrefId, - const ForkId& forkId) - : ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId) {} - -at::IValue RRefForkData::toIValue() const { - std::vector ivalues = { - (int64_t)ownerId_, rrefId_.toIValue(), forkId_.toIValue()}; + const ForkId& forkId, + worker_id_t parent) + : ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId), parent_(parent) {} + +py::tuple RRefForkData::toPyTuple() const { + return py::make_tuple( + ownerId_, + rrefId_.createdOn_, + rrefId_.localId_, + forkId_.createdOn_, + forkId_.localId_, + parent_); +} - return c10::ivalue::Tuple::create(std::move(ivalues)); +RRefForkData RRefForkData::fromPyTuple(const py::tuple& t) { + TORCH_INTERNAL_ASSERT( + t.size() == RFD_TUPLE_SIZE, + "Pickled RRefForkData must contain 6 numbers."); + worker_id_t ownerId = t[OWNER_IDX].cast(); + // const reference will extend the lifetime of the temporary variable + const RRefId& rrefId = RRefId( + t[RREFID_ON_IDX].cast(), + t[RREFID_ID_IDX].cast()); + const RRefId& forkId = RRefId( + t[FORKID_ON_IDX].cast(), + t[FORKID_ID_IDX].cast()); + worker_id_t parent = t[PARENT_IDX].cast(); + return RRefForkData(ownerId, 
rrefId, forkId, parent); } RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) { auto ivalues = ivalue.toTuple()->elements(); - TORCH_CHECK( - ivalues.size() == 3, + TORCH_INTERNAL_ASSERT( + ivalues.size() == 4, "Constructing RRefForkData from ivalue " - "expects a GenericList of 3 elements, but got ", + "expects a GenericList of 4 elements, but got ", ivalues.size()); int64_t ownerId = ivalues[0].toInt(); - TORCH_CHECK( + TORCH_INTERNAL_ASSERT( ownerId < std::numeric_limits::max(), "RRefId createdOn out of range, got ", ownerId); @@ -41,7 +77,12 @@ RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) { RRefId rrefId = RRefId::fromIValue(ivalues[1]); ForkId forkId = ForkId::fromIValue(ivalues[2]); - return RRefForkData(ownerId, rrefId, forkId); + int64_t parent = ivalues[3].toInt(); + TORCH_INTERNAL_ASSERT( + parent < std::numeric_limits::max(), + "RRefId createdOn out of range, got ", + parent); + return RRefForkData(ownerId, rrefId, forkId, parent); } ////////////////////////////// RRef ///////////////////////////////////// @@ -49,71 +90,100 @@ RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) { RRef::RRef(worker_id_t ownerId, const RRefId& rrefId) : ownerId_(ownerId), rrefId_(rrefId) {} -worker_id_t RRef::owner() const { - return ownerId_; -} - -const RRefId& RRef::id() const { - return rrefId_; -} - -at::IValue RRef::fork() const { +RRefForkData RRef::fork() const { + auto& ctx = RRefContext::getInstance(); return RRefForkData( - ownerId_, rrefId_, RRefContext::getInstance()->genRRefId()) - .toIValue(); - // NB: does not support sharing RRefs between users - // TODO: notify the owner + ownerId_, rrefId_, ctx->genGloballyUniqueId(), ctx->getWorkerId()); } ////////////////////////// UserRRef ///////////////////////////////////// -UserRRef::UserRRef( +template +UserRRef::UserRRef( worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId) : RRef(ownerId, rrefId), forkId_(forkId) { - AT_ASSERT( - !(forkId_ == rrefId_), - "User RRef's fork ID should not be the same as its rref Id"); - if (RRefContext::getInstance()->getWorkerId() == rrefId_.createdOn_) { - // creator user, notify owner. - auto& agent = RRefContext::getInstance()->agent(); - agent->send( - agent->getWorkerId(ownerId_), - ScriptRRefCreate(RRefForkData(ownerId_, rrefId_, forkId_).toIValue()) - .toMessage()); - } else { - AT_ERROR("Does not support sharing RRefs between users yet"); - } + // Do nothing, + // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send + // a RREF_FORK_REQUEST message to the owner. + // (2) If this the creator UserRRef, ScriptRemoteCall or PythonRemoteCall will + // properly notify the owner. } -UserRRef::~UserRRef() { +template +UserRRef::~UserRRef() { + // TODO: queue this in RRefContext instead of doing it here. 
auto& ctx = RRefContext::getInstance(); if (ctx->getWorkerId() != ownerId_) { - ctx->agent()->send( - ctx->agent()->getWorkerId(ownerId_), - ScriptRRefDelete(RRefForkData(ownerId_, rrefId_, forkId_).toIValue()) - .toMessage()); + auto fm = ctx->agent()->send( + ctx->agent()->getWorkerInfo(ownerId_), + RRefUserDelete(rrefId_, forkId_).toMessage()); + + fm->addCallback( + [](const Message& message) { RRefContext::handleException(message); }); } } -const ForkId& UserRRef::forkId() const { +template +const ForkId& UserRRef::forkId() const { return forkId_; } -bool UserRRef::isOwner() const { - return false; +template <> +IValue UserRRef::toHere() { + auto& agent = RRefContext::getInstance()->agent(); + std::shared_ptr fm = agent->send( + agent->getWorkerInfo(ownerId_), + ScriptRRefFetchCall(rrefId()).toMessage()); + const Message& message = fm->wait(); + RRefContext::handleException(message); + auto rfr = RRefFetchRet::fromMessage(message); + TORCH_INTERNAL_ASSERT( + rfr->values().size() == 1, + "RRef of IValue should contain a single IValue, but got ", + rfr->values().size()); + return rfr->values().front(); } -IValue UserRRef::toHere() { +template <> +py::object UserRRef::toHere() { auto& agent = RRefContext::getInstance()->agent(); std::shared_ptr fm = agent->send( - agent->getWorkerId(ownerId_), - ScriptRRefFetchCall(id().toIValue()).toMessage()); - auto srv = ScriptRRefFetchRet::fromMessage(fm->wait()); - return srv.value(); + agent->getWorkerInfo(ownerId_), + PythonRRefFetchCall(rrefId()).toMessage()); + const Message& message = fm->wait(); + RRefContext::handleException(message); + auto rfr = RRefFetchRet::fromMessage(message); + return PythonRpcHandler::getInstance().deserialize( + SerializedPyObj::fromIValues(rfr->values())); } +template class UserRRef; +template class UserRRef; + +////////////////////////// OwnerRRef ///////////////////////////////////// + +template +const T& OwnerRRef::getValue() const { + // TODO: use callback to make this non-blocking + std::unique_lock lock(mutex_); + valueCV_.wait(lock, [this] { return value_.has_value(); }); + return value_.value(); +} + +template +void OwnerRRef::setValue(T&& value) { + { + std::lock_guard lock(mutex_); + value_ = std::move(value); + } + valueCV_.notify_all(); +} + +template class OwnerRRef; +template class OwnerRRef; + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/rref.h b/torch/csrc/distributed/rpc/rref.h index 03fa11a25c629..53d557a49deed 100644 --- a/torch/csrc/distributed/rpc/rref.h +++ b/torch/csrc/distributed/rpc/rref.h @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -13,71 +14,245 @@ namespace rpc { class RRef; class RRefContext; +template class UserRRef; // Represents fork of an RRef to be sent over the wire. -// -// In order to preserve correctness of reference counting, each RRefForkData -// **MUST** be deserialized into a RRef. This means that if RRefForkData is to -// be transferred across the network, we need the guarantee that the message -// will *eventually* get to the peer, and that the peer will create a RRef out -// of it. Therefore, no constructor of RRefForkData is exposed, and -// applications should never directly use RRefForkData. All construction are -// done within ``RRef`` and ``RRefContext``. 
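The `OwnerRRef<T>::getValue()`/`setValue()` definitions above follow the standard blocking-cell pattern: the getter waits on a condition variable until the setter has published a value. A self-contained sketch of just that pattern, with `std::optional` standing in for `c10::optional`:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <optional>
#include <thread>

// Sketch only: a single-assignment cell whose reader blocks until a writer
// has stored a value, mirroring the OwnerRRef getValue/setValue pair.
template <typename T>
class BlockingCell {
 public:
  const T& getValue() const {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return value_.has_value(); });
    return *value_;
  }

  void setValue(T&& value) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      value_ = std::move(value);
    }
    cv_.notify_all();
  }

 private:
  mutable std::mutex mutex_;
  mutable std::condition_variable cv_;
  std::optional<T> value_;
};

int main() {
  BlockingCell<int> cell;
  std::thread producer([&cell] { cell.setValue(42); });
  std::cout << cell.getValue() << "\n"; // blocks until the producer runs
  producer.join();
  return 0;
}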
struct RRefForkData { - at::IValue toIValue() const; + py::tuple toPyTuple() const; + static RRefForkData fromPyTuple(const py::tuple& obj); + + const worker_id_t ownerId_; + const RRefId rrefId_; + const ForkId forkId_; + const worker_id_t parent_; private: friend class RRef; friend class RRefContext; + template friend class UserRRef; RRefForkData( worker_id_t ownerId, const RRefId& rrefId_, - const ForkId& forkId_); + const ForkId& forkId_, + worker_id_t parent); static RRefForkData fromIValue(const at::IValue&); - - const worker_id_t ownerId_; - const RRefId rrefId_; - const ForkId forkId_; }; static_assert( C10_IS_TRIVIALLY_COPYABLE(RRefForkData), "RRefForkData must be trivially copyable"); +// Note [RRef Protocol] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// [Background] +// +// RRef stands for Remote REFerence. Each RRef is owned by a single worker +// (i.e., owner) and can be used by multiple users. The owner stores the real +// data referenced by its RRefs. RRef needs to support fast and scalable RPC. +// Hence, in the design, we avoid using a single global master to keep RRef +// states, instead owners will keep track of the global reference counts +// for its RRefs. Every RRef can be uniquely identified by a global RRefId, +// which is assigned at the time it is first created either on a user or on the +// owner. +// +// On the owner worker, there is only one OwnerRRef instance, which contains the +// real data, while on user workers, there can be as many UserRRefs as +// necessary, and UserRRef does not hold the data. All usage on the OwnerRRef +// should retrieve the unique OwnerRRef instance using the globally unique +// RRefId. //A UserRRef will be created when it is used as an argument or return +// value in dist.rpc or dist.remote call, but RRef forking and reference +// counting (RC) are completely transparent to applications. Every UserRRef will +// also have its globally unique ForkId. +// +// [Assumptions] +// +// 1. Transient Network Failures +// +// TODO: current RRef implementation does not tolerate failures +// +// The RRef design aims to handle transient network failures by retrying +// messages. Node crashes or permanent network partition is beyond the scope. +// When those incidents occur, the application may take down all workers, revert +// to the previous checkpoint, and resume training. +// +// 2. Non-idempotent UDFs +// +// We assume UDFs are not idempotent and therefore cannot be retried. However, +// internal RRef control messages will be made idempotent and retryable. +// +// TODO: RRef internal messages are not yet idempotent +// +// 3. Out of Order Message Delivery +// +// We do not assume message delivery order between any pair of nodes, because +// both sender and receiver are using multiple threads. There is no guarantee on +// which message will be processed first. +// +// [RRef Lifetime] +// +// The goal of the protocol is to delete an OwnerRRef at an appropriate time. +// The right time to delete an OwnerRRef is when there are no living UserRRefs +// and Python GC also agrees to delete the OwnerRRef instance on the owner. The +// tricky part is to determine if there are any living UserRRefs. +// +// A user can get a UserRRef in three situations: +// +// (1). Receiving a UserRRef from the owner. +// (2). Receiving a UserRRef from another user. +// (3). Creating a new UserRRef owned by another worker. +// +// (1) is the simplest case where the owner initiates the fork, and hence it can +// easily increment local RC. 
The only requirement is that any UserRRef must +// notify the owner before destruction. Hence, we need the first guarantee: +// +// G1. The owner will be notified when any UserRRef is deleted. +// +// As messages might come delayed or out-of-order, we need more one guarantee to +// make sure the delete message is not sent out too soon. Let us first introduce +// a new concept. If A sends an RPC to B that involves an RRef, we call the RRef +// on A the parent RRef and the RRef on B the child RRef. +// +// G2. Parent RRef cannot be deleted until the child RRef is confirmed by the +// owner. +// +// Under (1), where the caller is UserRRef and callee is OwnerRRef, it simply +// means that the user will not send out the delete message until all previous +// messages are ACKed. Note that ACKed does not mean the owner finishes +// executing the function, instead, it only means the owner has retrieved its +// local OwnerRRef and about to pass it to the function, which is sufficient to +// keep the OwnerRRef alive even if the delete message from the user arrives at +// the owner before the function finishes execution. +// +// With (2) and (3), it is possible that the owner only partially knows the RRef +// fork graph or not even knowing it at all. For example, the RRef could be +// constructed on a user, and before the owner receives the RPC call, the +// creator user might have already shared the RRef with other users, and those +// users could further share the RRef. One invariant is that the fork graph of +// any RRef is always a tree rooted at the owner, because forking an RRef always +// creates a new RRef instance, and hence every RRef has a single parent. One +// nasty detail is that when an RRef is created on a user, technically the owner +// is not its parent but we still consider it that way and it does not break the +// argument below. +// +// The owner's view on any node (fork) in the tree has three stages: +// +// 1) unknown -> 2) known -> 3) deleted. +// +// The owner's view on the entire tree keeps changing. The owner deletes its +// OwnerRRef instance when it thinks there are no living UserRRefs, i.e., when +// OwnerRRef is deleted, all UserRRefs could be either indeed deleted or +// unknown. The dangerous case is when some forks are unknown and others are +// deleted. +// +// G2 trivially guarantees that no parent UserRRef Y can be deleted before the +// owner knows all of Y's children UserRRefs. +// +// However, it is possible that the child UserRRef Z may be deleted before the +// owner knows its parent Y. More specifically, this can happen when all of Z's +// messages are processed by the owner before all messages from Y, including the +// delete message. Nevertheless, this does not cause any problem. Because, at +// least one of Y's ancestor will be alive, and it will prevent the owner from +// deleting the OwnerRRef. Consider the following example: (NB: this scenario +// will no longer relevant when we block UDF until all RRefs are confirmed by +// the owner) +// +// OwnerRRef -> A -> Y -> Z +// +// OwnerRRef forks to A, then A forks to Y, and Y forks to Z. Z can be deleted +// without OwnerRRef knowing Y. However, the OwnerRRef will at least know A, as +// the owner directly forks the RRef to A. A won't die before the owner knows Y. 
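On the owner side, the bookkeeping that these guarantees protect reduces to a per-RRef set of live fork ids, and the owner's value may only be released once that set drains. A minimal sketch of that rule with hypothetical integer ids (the real code keys on RRefId and ForkId):

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>

// Sketch only: the owner tracks, per RRef, the set of forks it knows to be
// alive. Deleting the last known fork is what allows the value to be freed.
struct OwnerBookkeeping {
  std::unordered_map<int64_t, std::unordered_set<int64_t>> forks;

  void addFork(int64_t rrefId, int64_t forkId) {
    forks[rrefId].insert(forkId);
  }

  // Returns true when the last known fork of rrefId is gone, i.e. the owner's
  // value may now be released.
  bool delFork(int64_t rrefId, int64_t forkId) {
    auto& live = forks.at(rrefId);
    live.erase(forkId);
    if (live.empty()) {
      forks.erase(rrefId);
      return true;
    }
    return false;
  }
};

int main() {
  OwnerBookkeeping owner;
  owner.addFork(1, 10);            // owner forks to A
  owner.addFork(1, 11);            // A forks to Y, later confirmed to the owner
  assert(!owner.delFork(1, 10));   // A is deleted, but Y keeps the value alive
  assert(owner.delFork(1, 11));    // last fork gone, value may be released
  return 0;
}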
+// +// Things get a little trickier if the RRef is created on a user: +// +// OwnerRRef +// ^ +// | +// A -> Y -> Z +// +// If Z calls to_here on the UserRRef, the owner at least knows A when Z is +// deleted, because otherwise to_here wouldn't finish. If Z does not call +// to_here, it is possible that the owner receives all messages from Z before +// any message from A and Y. In this case, as the real data of the OwnerRRef has +// not been created yet, there is nothing to be deleted either. It is the same +// as Z does not exist at all Hence, it's still OK. +// +// See #26759 for more details and discussions. +// // TODO: make RRef an IValue, and edit createStackForSchema accordingly +// TODO: make RRef system messages idempotent and retry on failures. +// +// ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``. +// Each ``RRef`` has a globally unique ``RRefId``. class RRef { public: // RRef is made NOT copyable NOT movable to prevent messing up reference - // counting + // counting. RRef(const RRef& other) = delete; RRef(RRef&& other) = delete; + RRef& operator=(RRef&& other) = delete; virtual ~RRef() = default; - worker_id_t owner() const; - const RRefId& id() const; - IValue fork() const; + // returns the worker id of the owner + inline worker_id_t owner() const { + return ownerId_; + } + + // Returns the globally unique RRefId of this RRef + inline const RRefId& rrefId() const { + return rrefId_; + } + // Returns true if this is the ``OwnerRRef`` virtual bool isOwner() const = 0; - virtual IValue toHere() = 0; + + // returns true if this RRef holds an py::object, false if IValue + virtual bool isPyObj() = 0; protected: + friend class RRefContext; + RRef(worker_id_t ownerId, const RRefId& rrefId); + RRefForkData fork() const; + const worker_id_t ownerId_; const RRefId rrefId_; }; +// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user +// also has a globally unique ``ForkId`` to identify this user. ``UserRRef`` +// never owns the real value, the only way to get the value of the ``RRef`` is +// to call ``to_here()`` and get a copy.. +template class UserRRef final : public RRef { public: + UserRRef(const UserRRef& other) = delete; + UserRRef(UserRRef&& other) = delete; + UserRRef& operator=(const UserRRef& other) = delete; + UserRRef& operator=(UserRRef&& other) = delete; + + inline bool isOwner() const override { + return false; + } + + inline bool isPyObj() override { + return std::is_same::value; + } + + // Returns the globally unique ForkId of this RRef const ForkId& forkId() const; - bool isOwner() const override; - IValue toHere() override; + // Get of copy of the value from the ``OwnerRRef``. If the value is not ready + // yet, this call will block. + T toHere(); + + // Upon destruction, this ``UserRRef`` will tell the owner to deref. 
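The destructor declared just below is where the user-to-owner reference counting actually happens: dropping the last local reference to a UserRRef is what emits the delete message. A small RAII sketch of that idea, with a plain callback standing in for the RPC send and hypothetical names throughout:

#include <functional>
#include <iostream>

// Sketch only: a handle whose destructor notifies a remote owner, so
// distributed reference counting rides on ordinary C++ object lifetime.
class RemoteHandle {
 public:
  RemoteHandle(std::function<void(int)> notifyOwner, int forkId)
      : notifyOwner_(std::move(notifyOwner)), forkId_(forkId) {}

  ~RemoteHandle() {
    notifyOwner_(forkId_); // e.g. send an RREF_USER_DELETE style message
  }

 private:
  std::function<void(int)> notifyOwner_;
  int forkId_;
};

int main() {
  {
    RemoteHandle handle(
        [](int id) { std::cout << "delete fork " << id << "\n"; }, 7);
  } // handle goes out of scope here, so the owner is notified
  return 0;
}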
~UserRRef() override; private: @@ -93,28 +268,27 @@ class UserRRef final : public RRef { template class OwnerRRef final : public RRef { public: - bool isOwner() const override { + OwnerRRef(const OwnerRRef& other) = delete; + OwnerRRef(OwnerRRef&& other) = delete; + OwnerRRef& operator=(const OwnerRRef& other) = delete; + OwnerRRef& operator=(OwnerRRef&& other) = delete; + + inline bool isOwner() const override { return true; } - T getValue() const { - // TODO: use callback to make this non-blocking - std::unique_lock lock(mutex_); - valueCV_.wait(lock, [this] { return value_.has_value(); }); - return value_.value(); + inline bool isPyObj() override { + return std::is_same::value; } - void setValue(T&& value) { - { - std::lock_guard lock(mutex_); - value_ = std::move(value); - } - valueCV_.notify_all(); - } + // Get a constant reference of the real value. This method will block if the + // value is not ready. This method does not need GIL as it does not create + // any new py::object. + const T& getValue() const; - IValue toHere() override { - AT_ERROR("OwnerRRef does not support toHere(), use getValue() instead."); - } + // Set the value of this ``OwnerRRef``. This method does not need GIL as it + // does not create any new py::object. + void setValue(T&& value); private: friend class RRefContext; @@ -122,9 +296,6 @@ class OwnerRRef final : public RRef { OwnerRRef(worker_id_t ownerId, const RRefId& rrefId) : OwnerRRef(ownerId, rrefId, {}) {} - OwnerRRef(OwnerRRef&& other) noexcept - : OwnerRRef(other.owner(), other.id(), std::move(other.value_)) {} - OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, c10::optional value) : RRef(ownerId, rrefId) { value_ = std::move(value); diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp index dc7302e8adce9..5f04af9cbf30c 100644 --- a/torch/csrc/distributed/rpc/rref_context.cpp +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -1,10 +1,13 @@ #include +#include + +#include namespace torch { namespace distributed { namespace rpc { -std::unique_ptr RRefContext::context_; +std::unique_ptr RRefContext::context_ = nullptr; void RRefContext::initInstance(std::shared_ptr agent) { TORCH_CHECK(!RRefContext::context_, "Can only initialize RRefContext once."); @@ -20,50 +23,291 @@ std::unique_ptr& RRefContext::getInstance() { return RRefContext::context_; } +void RRefContext::destroyInstance() { + RRefContext::getInstance()->checkRRefLeaks(); + RRefContext::context_.reset(); +} + +void RRefContext::handleException(const Message& message) { + if (message.type() == MessageType::EXCEPTION) { + // TODO: allow users to register an error handler and call it here. 
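Absent such a registered handler, the fallback just below is simply payload bytes to string to throw, so a remote failure surfaces to the caller of an RPC as an ordinary C++ exception. A trivial sketch of that caller-side view, with made-up names:

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch only: an EXCEPTION-style reply carries the error text in its payload,
// which is rethrown locally so callers can use normal try/catch.
void raiseIfException(bool isException, const std::vector<char>& payload) {
  if (isException) {
    throw std::runtime_error(std::string(payload.begin(), payload.end()));
  }
  // otherwise the message would be dispatched on its type as usual
}

int main() {
  std::vector<char> payload{'b', 'o', 'o', 'm'};
  try {
    raiseIfException(true, payload);
  } catch (const std::runtime_error& e) {
    std::cout << "remote error: " << e.what() << "\n";
  }
  return 0;
}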
+ std::string err(message.payload().begin(), message.payload().end()); + VLOG(1) << "Got exception: " << err << std::endl << std::flush; + throw std::runtime_error(err); + } +} + RRefContext::RRefContext(std::shared_ptr agent) : agent_(std::move(agent)) {} -worker_id_t RRefContext::getWorkerId() const { - return agent_->getWorkerId().id_; +RRefContext::~RRefContext() { + if (!owners_.empty()) { + AutoGIL ag; + owners_.clear(); + } +} + +void RRefContext::checkRRefLeaks() { + if (!forks_.empty()) { + std::stringstream ss; + for (auto& entry : forks_) { + const RRefId& rrefId = entry.first; + for (const auto& forkId : entry.second) { + ss << "Leaking RRef " << rrefId << " with fork Id " << forkId + << std::endl; + } + } + AT_ERROR(ss.str()); + } } -RRefId RRefContext::genRRefId() { - return RRefId(getWorkerId(), nextLocalId_++); +template +std::shared_ptr> RRefContext::createUserRRef(worker_id_t ownerId) { + TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner."); + return createUserRRef( + ownerId, genGloballyUniqueId(), genGloballyUniqueId()); } -const std::shared_ptr& RRefContext::agent() const { - return agent_; +template std::shared_ptr> RRefContext::createUserRRef( + worker_id_t ownerId); + +template std::shared_ptr> RRefContext::createUserRRef< + py::object>(worker_id_t ownerId); + +template +std::shared_ptr> RRefContext::createUserRRef( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId) { + TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef."); + // RRefContext does not track user RRefs, it will be destructed when there + // is no shared_ptrs pointing to it. + // + // NB: cannot use make_shared here as the constructor of UserRRef is private. + // NB: This UserRRef has not been confirmed by the owner yet. This function's + // call site is responsible for adding this UserRRef to pendingUsers_. + // Currently, there are two call sites. + // (1) The creator user in python_functions.cpp + // (2) The callee user in RRefContext::notifyOwnerAndParentOfFork. + // + // The reason for not adding the pending user here is to put addPendingUser() + // close to where the RPC occurs, and it is more clear to pair it with + // deletePendingUser() in the response callback at the call site. 
+ return std::shared_ptr>(new UserRRef(ownerId, rrefId, forkId)); } -void RRefContext::addFork(const at::IValue& value) { - auto rfd = RRefForkData::fromIValue(value); - AT_ASSERT( - rfd.ownerId_ == getWorkerId(), - "RRef user should never receive fork notification."); +template std::shared_ptr> RRefContext::createUserRRef( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId); + +template std::shared_ptr> RRefContext::createUserRRef< + py::object>( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId); + +template +std::shared_ptr RRefContext::getOrCreateRRef(const RRefForkData& rfd) { + auto& ownerId = rfd.ownerId_; + auto& rrefId = rfd.rrefId_; + auto& forkId = rfd.forkId_; + if (ownerId == getWorkerId()) { + return getOrCreateOwnerRRef(rrefId); + } else { + return createUserRRef(ownerId, rrefId, forkId); + } +} + +template std::shared_ptr RRefContext::getOrCreateRRef( + const RRefForkData& rfd); + +template std::shared_ptr RRefContext::getOrCreateRRef( + const RRefForkData& rfd); + +template +std::shared_ptr> RRefContext::getOrCreateOwnerRRef( + const RRefId& rrefId) { std::lock_guard lock(mutex_); - auto& rrefForks = forks_[rfd.rrefId_]; - AT_ASSERT( - rrefForks.find(rfd.forkId_) == rrefForks.end(), - "Got fork notification twice on the same RRef ", - rfd.rrefId_); - rrefForks.insert(rfd.forkId_); + const auto iter = owners_.find(rrefId); + if (iter == owners_.end()) { + // Scenario (1) the first time this owner knows about this RRef + // + // NB: cannot use make_shared here as the constructor of OwnerRRef is + // private. + auto rref = + std::shared_ptr>(new OwnerRRef(getWorkerId(), rrefId)); + owners_[rref->rrefId()] = rref; + return rref; + + } else { + // Scenario (2) retrieving an existing RRef + return std::static_pointer_cast>(iter->second); + } +} + +template std::shared_ptr> RRefContext::getOrCreateOwnerRRef< + IValue>(const RRefId& rrefId); + +template std::shared_ptr> RRefContext:: + getOrCreateOwnerRRef(const RRefId& rrefId); + +RRefForkData RRefContext::prepareChildFork(const std::shared_ptr& rref) { + auto rfd = rref->fork(); + if (rref->isOwner()) { + // Note [Early Fork Registration] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // If the parent (caller) is the owner, directly register the fork, instead + // of waiting for another RREF_FORK_REQUEST or RREF_CHILD_ACCEPT message. An + // Alternative is adding the fork when the callee user ACKs. However, before + // that, the owner still have to adds the OwnerRRef into some map to keep it + // alive (e.g., in pendingChildren_). Hence, adding the fork here or in the + // ACK does not making any difference but only add complexity. + // TODO: When adding failure retries and timeout, this fork needs to be + // deleted if the owner does not receive the ACK within the timeout. + addForkOfOwner(rfd.rrefId_, rfd.forkId_); + } else { + // Note [Useful Phantom Fork ID for User to Owner Call] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // If the callee of dist.remote or dist.rpc is the owner of this RRef, the + // callee will not create a fork using this rfd.forkId_, because the owner + // will only keep one `OwnerRRef` instance and will not create any + // `UserRRef` instances. However, this rfd.forkId_ is still necessary, as + // the caller user needs to keep this `UserRRef` alive until it gets the + // ACK from the callee owner. 
Otherwise, the delete message could arrive + // at the owner before this dist.rpc or dist.remote call, which could + // potentially trigger the `OwnerRRef` to be deleted before running the + // user code. + addPendingChild(rfd.forkId_, rref); + } + return rfd; } -void RRefContext::delFork(const at::IValue& value) { - auto rfd = RRefForkData::fromIValue(value); - AT_ASSERT( - rfd.ownerId_ == getWorkerId(), - "RRef user should never receive delete notification."); +void RRefContext::notifyOwnerAndParentOfFork( + const ForkId& forkId, + worker_id_t parent, + const std::shared_ptr& rref) { + if (parent == rref->owner()) { + // If the parent is the owner, this fork has already been added into the + // forks_ map when the owner sends the message to the callee user. Hence, + // it is not necessary to send another RREF_CHILD_ACCEPT or + // RREF_FORK_REQUEST back to the owner. See Note [Early Fork Registration]. + return; + } + + if (rref->isOwner()) { + // See Note [Useful Phantom Fork ID for User to Owner Call] + // In this case, the owner is the caller, and it does not add the fork id + // into forks_. Because, there will be no real `UserRRef` associated with + // this fork ID. + auto fm = agent_->send( + agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); + fm->addCallback([](const Message& message) { handleException(message); }); + } else { + auto fm = agent_->send( + agent_->getWorkerInfo(rref->owner()), + RRefForkRequest(rref->rrefId(), forkId).toMessage()); + + addPendingUser(forkId, rref); + fm->addCallback([this, forkId, parent](const Message& message) { + handleException(message); + this->finishForkRequest(forkId, parent); + }); + } +} + +void RRefContext::addPendingChild( + const ForkId& forkId, + const std::shared_ptr& rref) { + // see Note [Early Fork Registration] + // If the parent is the owner, it should directly add the child UserRRef as a + // fork. 
+ TORCH_INTERNAL_ASSERT( + !rref->isOwner(), "OwnerRRef should not have a pending child."); + std::lock_guard lock(mutex_); + TORCH_INTERNAL_ASSERT( + pendingChildren_.find(forkId) == pendingChildren_.end(), + "Inconsistent states: attempt to add the same child fork twice."); + pendingChildren_[forkId] = rref; +} + +void RRefContext::delPendingChild(const ForkId& forkId) { std::lock_guard lock(mutex_); - auto& rrefForks = forks_[rfd.rrefId_]; - AT_ASSERT( - rrefForks.find(rfd.forkId_) != rrefForks.end(), - "Attempt to delete a non-exist fork ", - rfd.forkId_); - rrefForks.erase(rfd.forkId_); - if (rrefForks.empty()) { - owners_.erase(rfd.rrefId_); - forks_.erase(rfd.rrefId_); + auto iter = pendingChildren_.find(forkId); + TORCH_INTERNAL_ASSERT( + iter != pendingChildren_.end(), + "Inconsistent states: attempt to delete a non-exist child fork."); + pendingChildren_.erase(iter); +} + +void RRefContext::addPendingUser( + const ForkId& forkId, + const std::shared_ptr& rref) { + TORCH_INTERNAL_ASSERT( + !rref->isOwner(), "Attempt to add an OwnerRRef as a pending User."); + std::lock_guard lock(mutex_); + TORCH_INTERNAL_ASSERT( + pendingUsers_.find(forkId) == pendingUsers_.end(), + "Inconsistent states: attempt to add the same UserRRef twice."); + pendingUsers_[forkId] = rref; +} + +void RRefContext::delPendingUser(const ForkId& forkId) { + std::lock_guard lock(mutex_); + auto iter = pendingUsers_.find(forkId); + TORCH_INTERNAL_ASSERT( + iter != pendingUsers_.end(), + "Inconsistent states: attempt to delete a non-exist UserRRef."); + pendingUsers_.erase(iter); +} + +void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) { + delPendingUser(forkId); + auto fm = agent_->send( + agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage()); + + fm->addCallback([](const Message& message) { handleException(message); }); +} + +void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) { + std::lock_guard lock(mutex_); + auto& rrefForks = forks_[rrefId]; + TORCH_INTERNAL_ASSERT( + rrefForks.find(forkId) == rrefForks.end(), + "Got fork notification twice on the same RRef ", + forkId); + rrefForks.insert(forkId); +} + +void RRefContext::delForkOfOwner(const RRefId& rrefId, const ForkId& forkId) { + std::shared_ptr deletedRRef = nullptr; + { + std::lock_guard lock(mutex_); + auto rrefIter = forks_.find(rrefId); + TORCH_INTERNAL_ASSERT( + rrefIter != forks_.end(), + "Inconsistent states, deleting a fork before the owner knows it."); + auto& rrefForks = rrefIter->second; + auto forkIter = rrefForks.find(forkId); + TORCH_INTERNAL_ASSERT( + forkIter != rrefForks.end(), + "Attempt to delete a non-exist fork ", + forkId); + + rrefForks.erase(forkId); + + if (rrefForks.empty()) { + auto ownerIter = owners_.find(rrefId); + if (ownerIter != owners_.end()) { + deletedRRef = ownerIter->second; + owners_.erase(ownerIter); + } + forks_.erase(rrefIter); + } + } + if (deletedRRef && deletedRRef->isPyObj()) { + AutoGIL ag; + deletedRRef.reset(); } } diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h index e18967416eb24..50247e8def4b7 100644 --- a/torch/csrc/distributed/rpc/rref_context.h +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -17,90 +17,103 @@ class RRefContext { public: static void initInstance(std::shared_ptr); static std::unique_ptr& getInstance(); + static void destroyInstance(); + + static void handleException(const Message& message); RRefContext(const RRefContext&) = delete; + RRefContext(RRefContext&& other) = 
delete; void operator=(const RRefContext&) = delete; + RRefContext& operator=(RRefContext&& other) = delete; - worker_id_t getWorkerId() const; - RRefId genRRefId(); - const std::shared_ptr& agent() const; + ~RRefContext(); - // create a new RRef - template - std::shared_ptr> createOwnerRRef(worker_id_t ownerId) { - TORCH_CHECK(ownerId == getWorkerId(), "Cannot create OwnerRRef on user."); - return getOrCreateOwnerRRef(genRRefId()); + // get the worker id of the current worker + inline worker_id_t getWorkerId() const { + return agent_->getWorkerInfo().id_; } - std::shared_ptr createUserRRef(worker_id_t ownerId) { - TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner."); - return createUserRRef(ownerId, genRRefId(), genRRefId()); + // get the worker name of the current worker + inline const std::string& getWorkerName() const { + return agent_->getWorkerInfo().name_; } - std::shared_ptr createUserRRef( - worker_id_t ownerId, - const RRefId& rrefId, - const ForkId& forkId) { - TORCH_CHECK( - ownerId != getWorkerId(), "RRef owner cannot create user RRef."); - // RRefContext does not track user RRefs, it will be destructed when there - // is no shared_ptrs pointing to it. NB: cannot use make_shared here as the - // constructor of UserRRef is private - return std::shared_ptr(new UserRRef(ownerId, rrefId, forkId)); + // generate a globally unique ID + inline GloballyUniqueId genGloballyUniqueId() { + return GloballyUniqueId(getWorkerId(), nextLocalId_++); } - // get an existing RRef or create a new one from a serialized - // ``RRefForkData``. - template - std::shared_ptr getOrCreateRRef(at::IValue&& value) { - auto rfd = RRefForkData::fromIValue(std::move(value)); - return getOrCreateRRef(rfd.ownerId_, rfd.rrefId_, rfd.forkId_); + inline const std::shared_ptr& agent() const { + return agent_; } + // create a ``UserRRef`` owned by the worker ``ownerId`` template - std::shared_ptr getOrCreateRRef( - worker_id_t ownerId, - const RRefId& rrefId, - const ForkId& forkId) { - if (ownerId == getWorkerId()) { - return getOrCreateOwnerRRef(rrefId); - } else { - return createUserRRef(ownerId, rrefId, forkId); - } - } + std::shared_ptr> createUserRRef(worker_id_t ownerId); + // Convert an RRefForkData into an RRef. This RRef could be user or owner. + // This RRef could have already existed before, or could be created in this + // method. template - std::shared_ptr> getOrCreateOwnerRRef(const RRefId& rrefId) { - std::lock_guard lock(mutex_); - const auto iter = owners_.find(rrefId); - if (iter == owners_.end()) { - // Scenario (1) the first time this owner knows about this RRef - // Scenario (2) This owner is also the creator. - // - // NB: cannot use make_shared here as the constructor of OwnerRRef is - // private. - auto rref = std::shared_ptr>( - new OwnerRRef(getWorkerId(), rrefId)); - owners_[rref->id()] = rref; - return rref; - - } else { - // Scenario (3) retrieving an existing RRef - return std::dynamic_pointer_cast>(iter->second); - } - } + std::shared_ptr getOrCreateRRef(const RRefForkData& rfd); - void addFork(const at::IValue& value); - void delFork(const at::IValue& value); + // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new + // one. + template + std::shared_ptr> getOrCreateOwnerRRef(const RRefId& rrefId); + + // Register a fork of the ``OwnerRRef``, and inserts a shared_ptr of the + // ``OwnerRRef`` in a map to keep it alive. + void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId); + // Delete a fork of the ``OwnerRRef``. 
NB: this could trigger deletion on the + // IValue or py::object. For the later, this method will acquire GIL. + void delForkOfOwner(const RRefId& rrefId, const ForkId& forkId); + + // Invoked when pickling an RRef to setup child/fork properly + RRefForkData prepareChildFork(const std::shared_ptr& rref); + // Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and + // send RREF_CHILD_ACCEPT to the parent. + // NB: forkId is necessary here as the rref could be an OwnerRRef + void notifyOwnerAndParentOfFork( + const ForkId& forkId, + worker_id_t parent, + const std::shared_ptr& rref); + + // When a UserRRef is forked to another worker (user or owner), it is added + // into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT + // from the child. + // NB: This is necessary for both user and owner child. As we do not have FIFO + // communication between workers, we need this strategy to make sure that all + // previously submitted rpc/remote calls are acked before sending out the + // RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too + // soon. + void addPendingChild(const ForkId& forkId, const std::shared_ptr& rref); + void delPendingChild(const ForkId& forkId); + + // When a UserRRef is created, it is added into pendingUsers_ to be held alive + // until it receives RREF_USER_ACCEPT from the owner. + void addPendingUser(const ForkId& forkId, const std::shared_ptr& rref); + void delPendingUser(const ForkId& forkId); private: RRefContext(std::shared_ptr); + template + std::shared_ptr> createUserRRef( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId); + + void finishForkRequest(const ForkId& forkId, worker_id_t parent); + + // If there is any leak on any RRef, this method will throw an error. + void checkRRefLeaks(); + static std::unique_ptr context_; static std::atomic nextLocalId_; const std::shared_ptr agent_; - std::mutex mutex_; + mutable std::mutex mutex_; // Keep OwnerRRefs alive until there is no living UserRRefs. std::unordered_map, RRefId::Hash> owners_; // Tracks known living UserRRefs of an OwnerRRef @@ -109,6 +122,26 @@ class RRefContext { std::unordered_set, RRefId::Hash> forks_; + + // The follow two maps keep UserRRefs alive by holding a shared_ptr to the + // RRef instances. A UserRRef must be added into this map if any of the + // following two conditions is ture: + // + // (1) A UserRRef has not been accepted by owner yet. + // + // It can be used or shared, but cannot be deleted, and hence kept alive + // in this map. A message of type RREF_USER_ACCEPT will remove the + // corresponding RRef from this map. + std::unordered_map, ForkId::Hash> pendingUsers_; + + // (2) A UserRRef has forked a child UserRRef which has not been accepted by + // the owner yet. + // + // In this case, this UserRRef cannot send out RREF_USER_DELETE message, + // as it could potentially trigger the OwnerRRef been deleted before the + // owner learns about the forked child. 
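Both maps rely on the same "pin until acknowledged" trick: hold a shared_ptr keyed by fork id so the UserRRef cannot be destroyed before the matching ACK arrives, then drop that entry when the ACK lands. A small sketch of the lifetime mechanics, with hypothetical types:

#include <cassert>
#include <cstdint>
#include <memory>
#include <unordered_map>

struct Resource {}; // stands in for a UserRRef

// Sketch only: entries in the table keep the resource alive; erasing an entry
// on ACK may release the last reference.
class PendingTable {
 public:
  void add(int64_t forkId, std::shared_ptr<Resource> r) {
    pending_[forkId] = std::move(r);
  }
  void onAck(int64_t forkId) {
    pending_.erase(forkId);
  }

 private:
  std::unordered_map<int64_t, std::shared_ptr<Resource>> pending_;
};

int main() {
  PendingTable table;
  auto r = std::make_shared<Resource>();
  std::weak_ptr<Resource> watch = r;

  table.add(7, r);           // pin until the ACK for fork 7 arrives
  r.reset();                 // caller drops its own reference
  assert(!watch.expired());  // still alive: the table pins it

  table.onAck(7);            // ACK arrived
  assert(watch.expired());   // last reference released
  return 0;
}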
+ std::unordered_map, ForkId::Hash> + pendingChildren_; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/rref_proto.cpp b/torch/csrc/distributed/rpc/rref_proto.cpp new file mode 100644 index 0000000000000..b0d39a061b302 --- /dev/null +++ b/torch/csrc/distributed/rpc/rref_proto.cpp @@ -0,0 +1,173 @@ +#include +#include + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +namespace { + +std::vector toIValues(const Message& message, MessageType type) { + TORCH_INTERNAL_ASSERT( + type == message.type(), + "Expecting message of type ", + type, + ", but got ", + message.type()); + auto payload = static_cast(message.payload().data()); + auto payload_size = message.payload().size(); + + auto value = + jit::unpickle(payload, payload_size, nullptr, &message.tensors()); + return value.toTuple()->elements(); +} + +Message fromIValues(std::vector ivalues, MessageType type) { + std::vector tensor_table; + auto payload = jit::pickle( + c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table); + return Message(std::move(payload), std::move(tensor_table), type); +} + +} // namespace + +/////////////////////////// RRefMessageBase ////////////////////////////////// + +const RRefId& RRefMessageBase::rrefId() { + return rrefId_; +} + +Message RRefMessageBase::toMessage() && { + return fromIValues({rrefId_.toIValue()}, type_); +} + +at::IValue RRefMessageBase::fromMessage( + const Message& message, + MessageType type) { + auto values = toIValues(message, type); + + TORCH_INTERNAL_ASSERT( + values.size() == 1, "ScriptUserDelete expects 1 IValue from message."); + return std::move(values.back()); +} + +/////////////////////////// ForkMessageBase ////////////////////////////////// + +const ForkId& ForkMessageBase::forkId() { + return forkId_; +} + +Message ForkMessageBase::toMessage() && { + return fromIValues({rrefId_.toIValue(), forkId_.toIValue()}, type_); +} + +std::pair ForkMessageBase::fromMessage( + const Message& message, + MessageType type) { + auto ivalues = toIValues(message, type); + + TORCH_INTERNAL_ASSERT( + ivalues.size() == 2, "ScriptUserDelete expects 2 IValue from message."); + + return std::make_pair( + RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1])); +} + +/////////////////////////// RRef Protocol ////////////////////////////////// + +std::unique_ptr ScriptRRefFetchCall::fromMessage( + const Message& message) { + return c10::guts::make_unique( + RRefId::fromIValue(RRefMessageBase::fromMessage( + message, MessageType::SCRIPT_RREF_FETCH_CALL))); +} + +std::unique_ptr PythonRRefFetchCall::fromMessage( + const Message& message) { + return c10::guts::make_unique( + RRefId::fromIValue(RRefMessageBase::fromMessage( + message, MessageType::PYTHON_RREF_FETCH_CALL))); +} + +const std::vector& RRefFetchRet::values() { + return values_; +} + +Message RRefFetchRet::toMessage() && { + std::vector ivalues = values_; + std::vector tensor_table; + auto payload = + jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table); + + return Message( + std::move(payload), std::move(tensor_table), MessageType::RREF_FETCH_RET); +} + +std::unique_ptr RRefFetchRet::fromMessage( + const Message& message) { + auto payload = static_cast(message.payload().data()); + auto payload_size = message.payload().size(); + + auto value = + jit::unpickle(payload, payload_size, nullptr, &message.tensors()); + auto values = value.toTuple()->elements(); + + return c10::guts::make_unique(std::move(values)); +} + +std::unique_ptr RRefUserDelete::fromMessage( + const Message& 
message) { + auto pair = + ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE); + return c10::guts::make_unique( + RRefUserDelete(pair.first, pair.second)); +} + +std::unique_ptr RemoteRet::fromMessage(const Message& message) { + auto pair = ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET); + return c10::guts::make_unique(pair.first, pair.second); +} + +const ForkId& RRefChildAccept::forkId() const { + return forkId_; +} + +Message RRefChildAccept::toMessage() && { + return fromIValues({forkId_.toIValue()}, MessageType::RREF_CHILD_ACCEPT); +} + +std::unique_ptr RRefChildAccept::fromMessage( + const Message& message) { + auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT); + TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message."); + + return c10::guts::make_unique( + ForkId::fromIValue(values.back())); +} + +std::unique_ptr RRefForkRequest::fromMessage( + const Message& message) { + auto pair = + ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST); + return c10::guts::make_unique(pair.first, pair.second); +} + +Message RRefAck::toMessage() && { + return Message({}, {}, MessageType::RREF_ACK); +} + +std::unique_ptr RRefAck::fromMessage(const Message& message) { + TORCH_INTERNAL_ASSERT( + message.type() == MessageType::RREF_ACK, + "Message type miss match, expect ", + MessageType::RREF_ACK, + ", but got ", + message.type()); + return c10::guts::make_unique(); +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h new file mode 100644 index 0000000000000..0a3c9ccda3191 --- /dev/null +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// Temporary solution of RRef operations. +// TODO: Remove all these messages and use rpc + registered functions instead. +class TORCH_API RRefMessageBase : public RpcCommandBase { + public: + RRefMessageBase(const RRefId& rrefId, MessageType type) + : rrefId_(rrefId), type_(type) {} + + virtual ~RRefMessageBase() override = default; + + const RRefId& rrefId(); + + virtual Message toMessage() && override; + static at::IValue fromMessage(const Message& message, MessageType type); + + protected: + const RRefId rrefId_; + const MessageType type_; +}; + +class TORCH_API ForkMessageBase : public RRefMessageBase { + public: + ForkMessageBase(const RRefId& rrefId, const ForkId& forkId, MessageType type) + : RRefMessageBase(rrefId, type), forkId_(forkId) {} + + virtual ~ForkMessageBase() override = default; + + const ForkId& forkId(); + + virtual Message toMessage() && override; + static std::pair fromMessage( + const Message& message, + MessageType type); + + protected: + const ForkId forkId_; +}; + +// UserRRef uses this message to fetch the remote RRef value from the owner. 
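ScriptRRefFetchCall and PythonRRefFetchCall below form the request half of the fetch round trip behind UserRRef<T>::toHere(), with RRefFetchRet carrying the value back. A toy, in-process sketch of that request/reply shape, using hypothetical stand-in types and no real transport or blocking:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Sketch only: the user asks for the value behind an RRef id, the owner looks
// it up and replies, and toHere() hands back a copy.
struct FetchRequest { int64_t rrefId; };
struct FetchReply  { std::string value; };

class ToyOwner {
 public:
  void put(int64_t rrefId, std::string value) {
    values_[rrefId] = std::move(value);
  }
  FetchReply onFetch(const FetchRequest& req) const {
    return FetchReply{values_.at(req.rrefId)};
  }

 private:
  std::unordered_map<int64_t, std::string> values_;
};

// User-side counterpart of toHere(): issue the request and return the copy.
std::string toHere(const ToyOwner& owner, int64_t rrefId) {
  return owner.onFetch(FetchRequest{rrefId}).value;
}

int main() {
  ToyOwner owner;
  owner.put(3, "remote value stand-in");
  std::cout << toHere(owner, 3) << "\n";
  return 0;
}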
+class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase { + public: + explicit ScriptRRefFetchCall(const RRefId& rrefId) + : RRefMessageBase(rrefId, MessageType::SCRIPT_RREF_FETCH_CALL) {} + + static std::unique_ptr fromMessage( + const Message& message); +}; + +class TORCH_API PythonRRefFetchCall final : public RRefMessageBase { + public: + explicit PythonRRefFetchCall(const RRefId& rrefId) + : RRefMessageBase(rrefId, MessageType::PYTHON_RREF_FETCH_CALL) {} + + static std::unique_ptr fromMessage( + const Message& message); +}; + +// OwnerRRef uses this message to send the RRef value to a remote UserRRef +class TORCH_API RRefFetchRet final : public RpcCommandBase { + public: + explicit RRefFetchRet(std::vector values) + : values_(std::move(values)) {} + + const std::vector& values(); + + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); + + private: + std::vector values_; +}; + +// UserRRef (regardless it's the creator or not) uses this message to notiify +// OwnerRRef on delete. +class TORCH_API RRefUserDelete final : public ForkMessageBase { + public: + RRefUserDelete(const RRefId& rrefId, const ForkId& forkId) + : ForkMessageBase(rrefId, forkId, MessageType::RREF_USER_DELETE) {} + + static std::unique_ptr fromMessage(const Message& message); +}; + +class TORCH_API RemoteRet final : public ForkMessageBase { + public: + RemoteRet(const RRefId& rrefId, const ForkId& forkId) + : ForkMessageBase(rrefId, forkId, MessageType::REMOTE_RET) {} + + static std::unique_ptr fromMessage(const Message& message); +}; + +// A child RRef uses this message to notify its parent that the child has been +// confirmed by the owner. +class TORCH_API RRefChildAccept final : public RpcCommandBase { + public: + explicit RRefChildAccept(const ForkId& forkId) : forkId_(forkId) {} + + const ForkId& forkId() const; + + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); + + private: + const ForkId forkId_; +}; + +// A child RRef uses this message to send a fork request to the owner. 
+class TORCH_API RRefForkRequest final : public ForkMessageBase { + public: + RRefForkRequest(const RRefId& rrefId, const ForkId& forkId) + : ForkMessageBase(rrefId, forkId, MessageType::RREF_FORK_REQUEST) {} + + static std::unique_ptr fromMessage(const Message& message); +}; + +class TORCH_API RRefAck final : public RpcCommandBase { + public: + RRefAck() {} + + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/script_call.cpp b/torch/csrc/distributed/rpc/script_call.cpp index e8ee647e0a4bb..6d9d3eac2312b 100644 --- a/torch/csrc/distributed/rpc/script_call.cpp +++ b/torch/csrc/distributed/rpc/script_call.cpp @@ -64,7 +64,7 @@ std::shared_ptr ScriptCall::fromIValues( } } -Message ScriptCall::toMessage() { +Message ScriptCall::toMessage() && { std::vector ivalues; toIValues(ivalues); @@ -76,7 +76,7 @@ Message ScriptCall::toMessage() { std::move(payload), std::move(tensor_table), MessageType::SCRIPT_CALL); } -ScriptCall ScriptCall::fromMessage(const Message& message) { +std::unique_ptr ScriptCall::fromMessage(const Message& message) { auto payload = static_cast(message.payload().data()); auto payload_size = message.payload().size(); auto value = @@ -84,7 +84,7 @@ ScriptCall ScriptCall::fromMessage(const Message& message) { auto values = value.toTuple()->elements(); auto op = fromIValues(values); - return ScriptCall(op, std::move(values)); + return c10::guts::make_unique(op, std::move(values)); } std::shared_ptr ScriptCall::matchOperator( diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 4a38eed754f8e..c2954765102c1 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -15,7 +16,7 @@ using torch::jit::Operator; // A ScriptCall instance represents an invocation of a builtin operator for a // TorchScript function (not implemented yet). If it is a builtin operator, it // contains a shared ptr to the `Operator` and a list of arguments. 
-class TORCH_API ScriptCall { +class TORCH_API ScriptCall : public RpcCommandBase { public: ScriptCall(std::shared_ptr op, std::vector&& args); @@ -24,8 +25,8 @@ class TORCH_API ScriptCall { const std::vector& stack() const; std::vector& stackRef(); - Message toMessage(); - static ScriptCall fromMessage(const Message& message); + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); virtual ~ScriptCall() = default; diff --git a/torch/csrc/distributed/rpc/script_remote_call.cpp b/torch/csrc/distributed/rpc/script_remote_call.cpp index 40a8638f19eca..2af16642256c9 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.cpp +++ b/torch/csrc/distributed/rpc/script_remote_call.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace torch { @@ -8,35 +9,30 @@ namespace rpc { ScriptRemoteCall::ScriptRemoteCall( std::shared_ptr op, std::vector&& args, - at::IValue retRRefId, - at::IValue retForkId) + const RRefId& retRRefId, + const ForkId& retForkId) : ScriptCall(std::move(op), std::move(args)), - retRRefId_(std::move(retRRefId)), - retForkId_(std::move(retForkId)) {} + retRRefId_(retRRefId), + retForkId_(retForkId) {} -const at::IValue& ScriptRemoteCall::retRRefId() { - return retRRefId_; -} - -const at::IValue& ScriptRemoteCall::retForkId() { - return retForkId_; -} - -Message ScriptRemoteCall::toMessage() const { +Message ScriptRemoteCall::toMessage() && { std::vector ivalues; ScriptCall::toIValues(ivalues); - ivalues.push_back(retRRefId_); - ivalues.push_back(retForkId_); + ivalues.emplace_back(retRRefId_.toIValue()); + ivalues.emplace_back(retForkId_.toIValue()); std::vector tensor_table; - auto payload = - jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table); + auto payload = jit::pickle( + c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table); return Message( - std::move(payload), std::move(tensor_table), MessageType::REMOTE_CALL); + std::move(payload), + std::move(tensor_table), + MessageType::SCRIPT_REMOTE_CALL); } -ScriptRemoteCall ScriptRemoteCall::fromMessage(const Message& message) { +std::unique_ptr ScriptRemoteCall::fromMessage( + const Message& message) { auto payload = static_cast(message.payload().data()); auto payload_size = message.payload().size(); @@ -45,13 +41,13 @@ ScriptRemoteCall ScriptRemoteCall::fromMessage(const Message& message) { auto values = value.toTuple()->elements(); // remove the last element from values and convert it back to an RRef - auto retForkId = std::move(values.back()); + auto retForkId = RRefId::fromIValue(values.back()); values.pop_back(); - auto retRRefId = std::move(values.back()); + auto retRRefId = ForkId::fromIValue(values.back()); values.pop_back(); auto op = ScriptCall::fromIValues(values); - return ScriptRemoteCall( + return c10::guts::make_unique( op, std::move(values), std::move(retRRefId), std::move(retForkId)); } diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index 0602884f6e40d..46bd957bbbaa3 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -11,26 +12,32 @@ namespace rpc { using torch::jit::Operator; -// A ScriptCall instance represents an invocation of a builtin operator for a -// TorchScript function (not implemented yet). If it is a builtin operator, it -// contains a shared ptr to the `Operator` and a list of arguments. 
+// A ScriptRemoteCall instance represents an invocation of `dist.remote` on a +// builtin operator. Currently, it does not yet support RRefs as arguments. +// Besides the operator and a vector of arguments, ScriptRemoteCall also +// contains the RRefId and the ForkId of the return value RRef. class TORCH_API ScriptRemoteCall final : public ScriptCall { public: ScriptRemoteCall( std::shared_ptr op, std::vector&& args, - at::IValue retRRefId, - at::IValue retForkId); + const RRefId& retRRefId, + const ForkId& retForkId); - const at::IValue& retRRefId(); - const at::IValue& retForkId(); + inline const RRefId& retRRefId() const { + return retRRefId_; + } - Message toMessage() const; - static ScriptRemoteCall fromMessage(const Message& message); + inline const ForkId& retForkId() const { + return retForkId_; + } + + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); private: - const at::IValue retRRefId_; - const at::IValue retForkId_; + const RRefId retRRefId_; + const ForkId retForkId_; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/script_ret.cpp b/torch/csrc/distributed/rpc/script_resp.cpp similarity index 66% rename from torch/csrc/distributed/rpc/script_ret.cpp rename to torch/csrc/distributed/rpc/script_resp.cpp index 07c71518c822c..7165bccd9cf41 100644 --- a/torch/csrc/distributed/rpc/script_ret.cpp +++ b/torch/csrc/distributed/rpc/script_resp.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -13,13 +14,13 @@ using torch::jit::Unpickler; } // namespace -ScriptRet::ScriptRet(at::IValue&& value) : value_(value) {} +ScriptResp::ScriptResp(at::IValue&& value) : value_(value) {} -const at::IValue& ScriptRet::value() { +const at::IValue& ScriptResp::value() { return value_; } -Message ScriptRet::toMessage() { +Message ScriptResp::toMessage() && { std::vector tensor_table; auto payload = jit::pickle(value_, &tensor_table); ; @@ -27,12 +28,12 @@ Message ScriptRet::toMessage() { std::move(payload), std::move(tensor_table), MessageType::SCRIPT_RET); } -ScriptRet ScriptRet::fromMessage(const Message& message) { +std::unique_ptr ScriptResp::fromMessage(const Message& message) { auto payload = static_cast(message.payload().data()); auto payload_size = message.payload().size(); auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors()); - return ScriptRet(std::move(value)); + return c10::guts::make_unique(std::move(value)); } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/script_ret.h b/torch/csrc/distributed/rpc/script_resp.h similarity index 57% rename from torch/csrc/distributed/rpc/script_ret.h rename to torch/csrc/distributed/rpc/script_resp.h index 6632b9e589828..5e184cb6232c0 100644 --- a/torch/csrc/distributed/rpc/script_ret.h +++ b/torch/csrc/distributed/rpc/script_resp.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include namespace torch { @@ -8,13 +9,13 @@ namespace distributed { namespace rpc { // Return value of a builtin operator or a TorchScript function.
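A small standalone illustration of the ordering convention ScriptRemoteCall relies on above: toMessage() appends the RRefId and then the ForkId after the argument values, so fromMessage() pops the ForkId off the back first and the RRefId second. The int64_t stand-ins below are hypothetical; the real code stores GloballyUniqueId values as IValues.

    // Append the ids after the args, then pop them back off in reverse order.
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      std::vector<int64_t> values = {10, 20, 30};  // pretend these are the op args
      const int64_t rrefId = 7;
      const int64_t forkId = 8;

      // serialize: args..., rrefId, forkId
      values.push_back(rrefId);
      values.push_back(forkId);

      // deserialize: the last element is the ForkId, the one before it the RRefId
      const int64_t decodedForkId = values.back();
      values.pop_back();
      const int64_t decodedRRefId = values.back();
      values.pop_back();

      assert(decodedForkId == forkId);
      assert(decodedRRefId == rrefId);
      assert(values.size() == 3);  // only the original args remain
      return 0;
    }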
-class TORCH_API ScriptRet final { +class TORCH_API ScriptResp final : public RpcCommandBase { public: - explicit ScriptRet(at::IValue&& values); + explicit ScriptResp(at::IValue&& values); const at::IValue& value(); - Message toMessage(); - static ScriptRet fromMessage(const Message& message); + Message toMessage() && override; + static std::unique_ptr fromMessage(const Message& message); private: const at::IValue value_; diff --git a/torch/csrc/distributed/rpc/script_rref_proto.cpp b/torch/csrc/distributed/rpc/script_rref_proto.cpp index 4b8be7d7d5518..4e0f55d8d9f1f 100644 --- a/torch/csrc/distributed/rpc/script_rref_proto.cpp +++ b/torch/csrc/distributed/rpc/script_rref_proto.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace torch { @@ -13,7 +14,7 @@ at::IValue& RRefMessageBase::valueRef() { return value_; } -Message RRefMessageBase::toMessage() const { +Message RRefMessageBase::toMessage() && { std::vector ivalues; ivalues.push_back(value_); std::vector tensor_table; @@ -35,20 +36,26 @@ at::IValue RRefMessageBase::fromMessage(const Message& message) { return std::move(values.front()); } -ScriptRRefFetchCall ScriptRRefFetchCall::fromMessage(const Message& message) { - return ScriptRRefFetchCall(RRefMessageBase::fromMessage(message)); +std::unique_ptr ScriptRRefFetchCall::fromMessage( + const Message& message) { + return c10::guts::make_unique( + RRefMessageBase::fromMessage(message)); } ScriptRRefFetchRet ScriptRRefFetchRet::fromMessage(const Message& message) { return ScriptRRefFetchRet(RRefMessageBase::fromMessage(message)); } -ScriptRRefCreate ScriptRRefCreate::fromMessage(const Message& message) { - return ScriptRRefCreate(RRefMessageBase::fromMessage(message)); +std::unique_ptr ScriptRRefCreate::fromMessage( + const Message& message) { + return c10::guts::make_unique( + RRefMessageBase::fromMessage(message)); } -ScriptRRefDelete ScriptRRefDelete::fromMessage(const Message& message) { - return ScriptRRefDelete(RRefMessageBase::fromMessage(message)); +std::unique_ptr ScriptRRefDelete::fromMessage( + const Message& message) { + return c10::guts::make_unique( + RRefMessageBase::fromMessage(message)); } } // namespace rpc diff --git a/torch/csrc/distributed/rpc/script_rref_proto.h b/torch/csrc/distributed/rpc/script_rref_proto.h index de35b7e72a251..45d6c5fc42e4a 100644 --- a/torch/csrc/distributed/rpc/script_rref_proto.h +++ b/torch/csrc/distributed/rpc/script_rref_proto.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -11,7 +12,7 @@ namespace rpc { // Temporary solution of RRef operations. // TODO: Remove all these messages and use rpc + registered functions instead. 
-class TORCH_API RRefMessageBase { +class TORCH_API RRefMessageBase : public RpcCommandBase { public: RRefMessageBase(at::IValue value, MessageType type) : value_(std::move(value)), type_(type) {} @@ -19,7 +20,7 @@ class TORCH_API RRefMessageBase { const at::IValue& value(); at::IValue& valueRef(); - Message toMessage() const; + Message toMessage() && override; static at::IValue fromMessage(const Message& message); private: @@ -34,7 +35,8 @@ class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase { : RRefMessageBase(std::move(rrefForkData), MessageType::RREF_FETCH_CALL) { } - static ScriptRRefFetchCall fromMessage(const Message& message); + static std::unique_ptr fromMessage( + const Message& message); }; // OwnerRRef uses this message to send the RRef value to a remote UserRRef @@ -52,7 +54,7 @@ class TORCH_API ScriptRRefCreate final : public RRefMessageBase { ScriptRRefCreate(at::IValue value) : RRefMessageBase(std::move(value), MessageType::RREF_USER_CREATE) {} - static ScriptRRefCreate fromMessage(const Message& message); + static std::unique_ptr fromMessage(const Message& message); }; // UserRRef (regardless of it's the creator or not) uses this message to notify @@ -62,7 +64,7 @@ class TORCH_API ScriptRRefDelete final : public RRefMessageBase { ScriptRRefDelete(at::IValue value) : RRefMessageBase(std::move(value), MessageType::RREF_USER_DELETE) {} - static ScriptRRefDelete fromMessage(const Message& message); + static std::unique_ptr fromMessage(const Message& message); }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/types.cpp b/torch/csrc/distributed/rpc/types.cpp index 2071a6a59ae3c..1d106301bd001 100644 --- a/torch/csrc/distributed/rpc/types.cpp +++ b/torch/csrc/distributed/rpc/types.cpp @@ -4,6 +4,17 @@ namespace torch { namespace distributed { namespace rpc { +static_assert( + std::numeric_limits::max() <= + std::numeric_limits::max(), + "The max value of local_id_t must be within the range of int64_t"); +static_assert( + std::numeric_limits::max() <= + std::numeric_limits::max(), + "The max value of worker_id_t must be within the range of int64_t"); + +/////////////////////////// GloballyUniqueId /////////////////////////// + GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId) : createdOn_(createdOn), localId_(localId) {} @@ -16,8 +27,8 @@ bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const { } at::IValue GloballyUniqueId::toIValue() const { - std::vector ivalues = {(int64_t)createdOn_, (int64_t)localId_}; - return c10::ivalue::Tuple::create(std::move(ivalues)); + return c10::ivalue::Tuple::create( + {static_cast(createdOn_), static_cast(localId_)}); } GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) { @@ -28,18 +39,17 @@ GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) { "expects a GenericList of two elements, but got ", ivalues.size()); - worker_id_t createdOn = ivalues[0].toInt(); - local_id_t localId = ivalues[1].toInt(); - TORCH_CHECK( - createdOn < std::numeric_limits::max(), + ivalues[0].toInt() <= std::numeric_limits::max(), "GloballyUniqueId createdOn out of range, got ", - createdOn); + ivalues[0].toInt()); + worker_id_t createdOn = ivalues[0].toInt(); TORCH_CHECK( - localId < std::numeric_limits::max(), + ivalues[1].toInt() <= std::numeric_limits::max(), "GloballyUniqueId localId out of range, got ", - localId); + ivalues[1].toInt()); + local_id_t localId = ivalues[1].toInt(); return GloballyUniqueId(createdOn, localId); } @@ -49,6 +59,29 @@ 
std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) { << globalId.localId_ << ")"; } +/////////////////////////// SerializedPyObj /////////////////////////// + +std::vector SerializedPyObj::toIValues() const { + std::vector ivalues; + ivalues.reserve(tensors_.size() + 1); + for (auto& tensor : tensors_) { + ivalues.emplace_back(tensor); + } + ivalues.emplace_back(payload_); + return ivalues; +} + +SerializedPyObj SerializedPyObj::fromIValues(std::vector values) { + std::string payload = values.back().toStringRef(); + values.pop_back(); + std::vector tensors; + tensors.reserve(values.size()); + for (auto& value : values) { + tensors.emplace_back(value.toTensor()); + } + return SerializedPyObj(std::move(payload), std::move(tensors)); +} + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/types.h b/torch/csrc/distributed/rpc/types.h index 47de3d09f205a..78b869df261c3 100644 --- a/torch/csrc/distributed/rpc/types.h +++ b/torch/csrc/distributed/rpc/types.h @@ -8,11 +8,12 @@ namespace distributed { namespace rpc { using worker_id_t = int16_t; -using local_id_t = uint64_t; +using local_id_t = int64_t; -struct GloballyUniqueId final { +struct TORCH_API GloballyUniqueId final { GloballyUniqueId(worker_id_t createdOn, local_id_t localId); GloballyUniqueId(const GloballyUniqueId& other) = default; + GloballyUniqueId& operator=(const GloballyUniqueId& other) = delete; bool operator==(const GloballyUniqueId& other) const; bool operator!=(const GloballyUniqueId& other) const; @@ -32,11 +33,24 @@ struct GloballyUniqueId final { const local_id_t localId_; }; -std::ostream& operator<<(std::ostream& os, const GloballyUniqueId& globalId); +TORCH_API std::ostream& operator<<( + std::ostream& os, + const GloballyUniqueId& globalId); using RRefId = GloballyUniqueId; using ForkId = GloballyUniqueId; +struct TORCH_API SerializedPyObj final { + SerializedPyObj(std::string&& payload, std::vector&& tensors) + : payload_(std::move(payload)), tensors_(std::move(tensors)) {} + + std::vector toIValues() const; + static SerializedPyObj fromIValues(std::vector value); + + const std::string payload_; + const std::vector tensors_; +}; + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp new file mode 100644 index 0000000000000..def48925cc8db --- /dev/null +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +std::unique_ptr deserializeRequest(const Message& request) { + switch (request.type()) { + case MessageType::SCRIPT_CALL: { + return ScriptCall::fromMessage(request); + } + case MessageType::PYTHON_CALL: { + return PythonUDFCall::fromMessage(request); + } + case MessageType::SCRIPT_REMOTE_CALL: { + return ScriptRemoteCall::fromMessage(request); + } + case MessageType::PYTHON_REMOTE_CALL: { + return PythonRemoteCall::fromMessage(request); + } + case MessageType::SCRIPT_RREF_FETCH_CALL: { + return ScriptRRefFetchCall::fromMessage(request); + } + case MessageType::PYTHON_RREF_FETCH_CALL: { + return PythonRRefFetchCall::fromMessage(request); + } + case MessageType::RREF_USER_DELETE: { + return RRefUserDelete::fromMessage(request); + } + case MessageType::RREF_CHILD_ACCEPT: { + return RRefChildAccept::fromMessage(request); + } + case MessageType::RREF_FORK_REQUEST: { + return 
RRefForkRequest::fromMessage(request); + } + case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: { + return RpcWithAutograd::fromMessage(request); + } + default: { + TORCH_INTERNAL_ASSERT( + false, "Request type ", request.type(), " not supported."); + } + } +} + +std::unique_ptr deserializeResponse(const Message& response) { + switch (response.type()) { + case MessageType::SCRIPT_RET: { + return ScriptResp::fromMessage(response); + } + case MessageType::PYTHON_RET: { + return PythonUDFResp::fromMessage(response); + } + case MessageType::REMOTE_RET: { + return RemoteRet::fromMessage(response); + } + case MessageType::RREF_FETCH_RET: { + return RRefFetchRet::fromMessage(response); + } + case MessageType::RREF_ACK: { + return RRefAck::fromMessage(response); + } + case MessageType::EXCEPTION: { + std::string err(response.payload().begin(), response.payload().end()); + throw std::runtime_error(err); + } + case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: { + return RpcWithAutograd::fromMessage(response); + } + default: { + TORCH_INTERNAL_ASSERT( + false, "Response type ", response.type(), " not supported."); + } + } +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h new file mode 100644 index 0000000000000..d028bb3a7fa7c --- /dev/null +++ b/torch/csrc/distributed/rpc/utils.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// Given an RPC message received as a request over the wire, deserialize it into +// the appropriate 'RpcCommandBase' type. +TORCH_API std::unique_ptr deserializeRequest( + const Message& request); + +// Given an RPC message received as a response over the wire, deserialize it +// into the appropriate 'RpcCommandBase' type. +TORCH_API std::unique_ptr deserializeResponse( + const Message& response); + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index 0789d59bf75ba..7a6c614816471 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -105,13 +105,13 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO } #if !(defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)) - THPByteOrder byte_order; + torch::utils::THPByteOrder byte_order; if (strcmp(byte_order_str, "native") == 0) { - byte_order = THP_nativeByteOrder(); + byte_order = torch::utils::THP_nativeByteOrder(); } else if (strcmp(byte_order_str, "big") == 0) { - byte_order = THP_BIG_ENDIAN; + byte_order = torch::utils::THP_BIG_ENDIAN; } else if (strcmp(byte_order_str, "little") == 0) { - byte_order = THP_LITTLE_ENDIAN; + byte_order = torch::utils::THP_LITTLE_ENDIAN; } else { PyErr_Format(PyExc_ValueError, "invalid byte_order '%s' (expected 'big', 'little', or 'native')", @@ -158,22 +158,30 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO // Because of ASAN checks, that are failing in the THStorage.cpp whenever // we are trying to get a value which is not 0 or 1, we have to manually // convert original values to boolean ones. 
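The byte-order helpers used throughout the surrounding hunks (torch::utils::THP_nativeByteOrder and the THP_decode*Buffer family) are only being re-qualified with their new torch::utils namespace here; their behavior is unchanged. A rough standalone sketch of that kind of decode, assuming the source buffer's byte order is passed in and bytes are swapped only when it differs from the host's:

    // Decode little-endian int32 wire data into host order (sketch, not PyTorch code).
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    enum class ByteOrder { Little, Big };

    static ByteOrder nativeByteOrder() {
      const uint32_t probe = 1;
      uint8_t first_byte;
      std::memcpy(&first_byte, &probe, 1);
      return first_byte == 1 ? ByteOrder::Little : ByteOrder::Big;
    }

    static void decodeInt32Buffer(int32_t* dst, const uint8_t* src,
                                  ByteOrder src_order, std::size_t count) {
      for (std::size_t i = 0; i < count; ++i, src += 4) {
        uint32_t v;
        std::memcpy(&v, src, 4);
        if (src_order != nativeByteOrder()) {
          v = (v << 24) | ((v & 0xff00u) << 8) | ((v >> 8) & 0xff00u) | (v >> 24);
        }
        std::memcpy(&dst[i], &v, 4);
      }
    }

    int main() {
      const uint8_t wire[8] = {1, 0, 0, 0, 2, 0, 0, 0};  // two little-endian int32s
      int32_t out[2];
      decodeInt32Buffer(out, wire, ByteOrder::Little, 2);
      std::printf("%d %d\n", out[0], out[1]);  // 1 2 on any host
    }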
- THP_decodeBoolBuffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeBoolBuffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_SHORT) - THP_decodeInt16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeInt16Buffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_INT) - THP_decodeInt32Buffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeInt32Buffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_LONG) // TODO: remove the cast - THP_decodeInt64Buffer((int64_t*) THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeInt64Buffer( + (int64_t*)THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_HALF) - THP_decodeHalfBuffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeHalfBuffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_BFLOAT16) - THP_decodeBFloat16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeBFloat16Buffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_FLOAT) - THP_decodeFloatBuffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeFloatBuffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #elif defined(TH_REAL_IS_DOUBLE) - THP_decodeDoubleBuffer(THWStorage_(data)(storage), src + offset, byte_order, count); + torch::utils::THP_decodeDoubleBuffer( + THWStorage_(data)(storage), src + offset, byte_order, count); #else #error "Unknown type" #endif diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp index 7e2c8685082e6..b8e701fd96a78 100644 --- a/torch/csrc/generic/serialization.cpp +++ b/torch/csrc/generic/serialization.cpp @@ -22,15 +22,22 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd) data = (scalar_t*)cpu_data.get(); THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost)); #endif - if (THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) + if (torch::utils::THP_nativeByteOrder() == + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) doWrite(fd, &size, sizeof(int64_t)); else { int64_t nsize; // convert big endian cpu to little endian storage - THP_encodeInt64Buffer((uint8_t*)&nsize, (const int64_t *)&size, THPByteOrder::THP_LITTLE_ENDIAN, 1); + torch::utils::THP_encodeInt64Buffer( + (uint8_t*)&nsize, + (const int64_t*)&size, + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, + 1); doWrite(fd, &nsize, sizeof(int64_t)); } // fast track for bytes and little endian - if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) { + if (sizeof(scalar_t) == 1 || + torch::utils::THP_nativeByteOrder() == + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) { doWrite(fd, data, sizeof(scalar_t) * size); } else { int64_t buffer_size = std::min(size, (int64_t)5000); @@ -38,19 +45,22 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd) for (int64_t i = 0; i < size; i += buffer_size) { size_t to_convert = std::min(size - i, buffer_size); if (sizeof(scalar_t) == 2) { - THP_encodeInt16Buffer((uint8_t*)le_buffer.get(), + torch::utils::THP_encodeInt16Buffer( + (uint8_t*)le_buffer.get(), (const int16_t*)data + i, - 
THPByteOrder::THP_LITTLE_ENDIAN, + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (sizeof(scalar_t) == 4) { - THP_encodeInt32Buffer((uint8_t*)le_buffer.get(), + torch::utils::THP_encodeInt32Buffer( + (uint8_t*)le_buffer.get(), (const int32_t*)data + i, - THPByteOrder::THP_LITTLE_ENDIAN, + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } else if (sizeof(scalar_t) == 8) { - THP_encodeInt64Buffer((uint8_t*)le_buffer.get(), + torch::utils::THP_encodeInt64Buffer( + (uint8_t*)le_buffer.get(), (const int64_t*)data + i, - THPByteOrder::THP_LITTLE_ENDIAN, + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN, to_convert); } doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t)); @@ -74,10 +84,12 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage) scalar_t *data; int64_t size; doRead(file, &size, sizeof(int64_t)); - if (THP_nativeByteOrder() == THPByteOrder::THP_BIG_ENDIAN) { + if (torch::utils::THP_nativeByteOrder() == + torch::utils::THPByteOrder::THP_BIG_ENDIAN) { int64_t nsize; // convert little endian storage to big endian cpu nsize = size; - THP_decodeInt64Buffer(&size, (const uint8_t*)&nsize, THP_nativeByteOrder(), 1); + torch::utils::THP_decodeInt64Buffer( + &size, (const uint8_t*)&nsize, torch::utils::THP_nativeByteOrder(), 1); } THWStoragePtr storage; if (_storage == nullptr) { @@ -97,7 +109,9 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage) #endif // fast track for bytes and little endian - if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) { + if (sizeof(scalar_t) == 1 || + torch::utils::THP_nativeByteOrder() == + torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) { doRead(file, data, sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage)); } else { int64_t buffer_size = std::min(size, (int64_t)5000); @@ -109,19 +123,22 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage) doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert); if (sizeof(scalar_t) == 2) { - THP_decodeInt16Buffer((int16_t*)data + i, + torch::utils::THP_decodeInt16Buffer( + (int16_t*)data + i, le_buffer.get(), - THP_nativeByteOrder(), + torch::utils::THP_nativeByteOrder(), to_convert); } else if (sizeof(scalar_t) == 4) { - THP_decodeInt32Buffer((int32_t*)data + i, + torch::utils::THP_decodeInt32Buffer( + (int32_t*)data + i, le_buffer.get(), - THP_nativeByteOrder(), + torch::utils::THP_nativeByteOrder(), to_convert); } else if (sizeof(scalar_t) == 8) { - THP_decodeInt64Buffer((int64_t*)data + i, + torch::utils::THP_decodeInt64Buffer( + (int64_t*)data + i, le_buffer.get(), - THP_nativeByteOrder(), + torch::utils::THP_nativeByteOrder(), to_convert); } } diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index bd2d1151ec872..f8d1daed1d6da 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -543,7 +543,9 @@ class ScriptModuleSerializer { bool bytecode_format) { C10_LOG_API_USAGE_ONCE("torch.script.save"); writeExtraFiles(module, extra_files); - // Serialize all code info. + // Serialize the model object + writeArchive("data", module.module_object()); + // Then we werialize all code info. 
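// Presumably the "data" archive has to come before the code: writeArchive()
// (below) records every run-time class type encountered while pickling the
// module object and hands each one to convertNamedType(), so the code
// serialized next can cover those types as well.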
writeCode(module.type()); // The tensor constants from the code are written to a separate archive // so loading the code does not depend on loading the data @@ -553,18 +555,19 @@ class ScriptModuleSerializer { if (bytecode_format) { writeByteCode(module); } - // finally we serialize the model - writeArchive("data", module.module_object()); } private: void writeArchive(const std::string& archive_name, const IValue& value) { std::vector data; + // Vector to capture the run-time class types during pickling the IValues + std::vector memorizedClassTypes; Pickler data_pickle( [&](const char* buf, size_t size) { data.insert(data.end(), buf, buf + size); }, - nullptr); + nullptr, + &memorizedClassTypes); data_pickle.protocol(); data_pickle.pushIValue(value); data_pickle.stop(); @@ -577,6 +580,11 @@ class ScriptModuleSerializer { std::stringstream fname; fname << archive_name << ".pkl"; writer_.writeRecord(fname.str(), data.data(), data.size()); + + // serialize all the captured run-time class types + for (const c10::ClassTypePtr& wroteType : memorizedClassTypes) { + convertNamedType(wroteType); + } } void writeExtraFiles( @@ -668,12 +676,41 @@ class ScriptModuleSerializer { for (const auto& method : methods) { const auto& func = method.function(); torch::jit::Code code(func.graph()); + // Make a copy of opnames. Some of them may be changed for mobile later. + std::vector opnames; + for (size_t i = 0; i < code.instructions().size(); ++i) { + Instruction ins = code.instructions()[i]; + if (ins.op == OP) { + auto node = code.instructions_source()[i]; + opnames.emplace_back(node->schema().operator_name()); + } + } // instructions std::vector inss; - for (const auto& ins : code.instructions()) { + for (size_t i = 0; i < code.instructions().size(); ++i) { + Instruction ins = code.instructions()[i]; TORCH_CHECK(isOpSupportedInMobile(ins.op), toString(ins.op), " is not supported in mobile module."); + if (ins.op == OP) { + if (opnames[ins.X].name == "prim::ListConstruct") { + auto node = code.instructions_source()[i]; + ins.op = OPN; + ins.N = node->inputs().size(); + ListTypePtr lt = node->output()->type()->expect(); + if (lt->getElementType() == IntType::get()) { + opnames[ins.X].overload_name = "int"; + } else if (lt->getElementType() == FloatType::get()) { + opnames[ins.X].overload_name = "float"; + } else if (lt->getElementType() == BoolType::get()) { + opnames[ins.X].overload_name = "bool"; + } else if (lt->getElementType()->isSubtypeOf(TensorType::get())) { + opnames[ins.X].overload_name = "Tensor"; + } else { + opnames[ins.X].overload_name = "generic"; + } + } + } std::vector insv{toString(ins.op), ins.X, ins.N}; inss.emplace_back(c10::ivalue::Tuple::create(std::move(insv))); } @@ -682,7 +719,7 @@ class ScriptModuleSerializer { // operators std::vector opss; - for (const auto& opname : code.opname_table()) { + for (const auto& opname : opnames) { opss.emplace_back(c10::ivalue::Tuple::create({opname.name, opname.overload_name})); } auto operators = c10::ivalue::Tuple::create(std::move(opss)); diff --git a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp index 4f2f7ebb9b045..b46adaf34ab2c 100644 --- a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp @@ -141,13 +141,16 @@ FusedKernelCUDA::FusedKernelCUDA( // Computes max blocks #ifdef __HIP_PLATFORM_HCC__ - // XXX this is a temporary hack until the occupancy API is supported in ROCm - maxBlocks_ = 16 * prop_->multiProcessorCount; + // XXX HIP function signature is not 
compatible yet + uint32_t max_blocks; + AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor( + &max_blocks, function_, 128, 0)); + maxBlocks_ = max_blocks; #else AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocks_, function_, 128, 0)); - maxBlocks_ *= prop_->multiProcessorCount; #endif +maxBlocks_ *= prop_->multiProcessorCount; // Resets device (end of hacked at::DeviceGuard) at::cuda::set_device(prior_device); diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h index bcbd817602566..14cc8376b3f16 100644 --- a/torch/csrc/jit/graph_executor.h +++ b/torch/csrc/jit/graph_executor.h @@ -1,11 +1,13 @@ #pragma once +#include +#include + #include #include #include #include #include -#include namespace torch { namespace jit { @@ -58,7 +60,7 @@ TORCH_API void runRequiredPasses(const std::shared_ptr& g); TORCH_API void debugSetAutodiffSubgraphInlining(bool state); TORCH_API std::shared_ptr lastExecutedOptimizedGraph(); -TORCH_API bool& getProfilingMode(); +TORCH_API std::atomic &getProfilingMode(); struct TORCH_API GraphOptimizerEnabledGuard { GraphOptimizerEnabledGuard(bool state) diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 330dc58258f74..5e77c39b8a179 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -131,7 +131,7 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { (*type.type_->getMethod("__setstate__"))({obj, input}); setGraphExecutorOptimize(true); postSetStateValidate(obj); - return std::move(obj); + return obj; } else { auto dict = std::move(input).toGenericDict(); auto obj = c10::ivalue::Object::create(type, n); diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp index 26ed8fe3508e2..682b14f73d5f5 100644 --- a/torch/csrc/jit/import_source.cpp +++ b/torch/csrc/jit/import_source.cpp @@ -109,9 +109,9 @@ struct SourceImporterImpl : public Resolver, // Constants present in the model. 
Used to resolve "CONSTANTS.n" to the // actual value {"CONSTANTS", std::make_shared(tensor_table)}, - {"fork", std::make_shared()}, - {"annotate", std::make_shared()}, - {"uninitialized", std::make_shared()}, + {"fork", SpecialFormValue::create(prim::fork)}, + {"annotate", SpecialFormValue::create(prim::annotate)}, + {"uninitialized", SpecialFormValue::create(prim::Uninitialized)}, {"inf", std::make_shared( std::numeric_limits::infinity())}, diff --git a/torch/csrc/jit/instruction.cpp b/torch/csrc/jit/instruction.cpp index e3134d96b66d5..4dc0c5e959097 100644 --- a/torch/csrc/jit/instruction.cpp +++ b/torch/csrc/jit/instruction.cpp @@ -70,7 +70,7 @@ OpCode parseOpCode(const char *str) { bool isOpSupportedInMobile(OpCode op) { static constexpr OpCode supported_ops_in_mobile[] { - OP, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, LOOP, RET, GET_ATTR, SET_ATTR + OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, LOOP, RET, GET_ATTR, SET_ATTR }; for (auto sop : supported_ops_in_mobile) { diff --git a/torch/csrc/jit/instruction.h b/torch/csrc/jit/instruction.h index e16c434a323c9..0bc49b4bed7ab 100644 --- a/torch/csrc/jit/instruction.h +++ b/torch/csrc/jit/instruction.h @@ -19,6 +19,7 @@ namespace jit { #define FORALL_OPCODES(_) \ _(OP, "O") /* invoke operator X */ \ + _(OPN, "OI") /* invoke vararg operator X with N arguments */ \ _(LOAD, "R") /* push a value from a register X */ \ _(MOVE, "R") /* push a value from register X, clearing the register */ \ _(STOREN, "RI") /* store N values to registers [X, X+N) */ \ diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 5bee911846d27..ba8cde05e6ec9 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -339,8 +339,6 @@ struct CodeImpl { std::vector constant_table_; std::vector operator_table_; - // opname_table_ has the same order as operator_table_. 
- std::vector opname_table_; std::vector function_table_; std::vector type_table_; int register_size_ = 0; @@ -400,8 +398,8 @@ struct CodeImpl { return instructions_; } - const std::vector& opname_table() const { - return opname_table_; + const std::vector& instructions_source() const { + return instructions_source_; } void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) { @@ -491,7 +489,6 @@ struct CodeImpl { emitLoadInputs(node->inputs()); insertInstruction(OP, operator_table_.size()); operator_table_.emplace_back(getOperation(node)); - opname_table_.emplace_back(node->schema().operator_name()); } void emitWait(Node* node) { @@ -848,14 +845,17 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { ActiveFrame af(frames.back()); try { while (true) { - // std::cout << "RUNNING "; - // frames.back().function->dump(std::cout, af.pc); +// std::cout << "RUNNING "; +// frames.back().function->dump(std::cout, af.pc); Instruction inst = af.instructions[af.pc]; switch (inst.op) { case OP: af.operators[inst.X](stack); ++af.pc; break; + case OPN: + AT_ERROR("OPN is currently supported in mobile mode only."); + break; case LOAD: stack.emplace_back(reg(inst.X)); ++af.pc; @@ -1132,8 +1132,8 @@ const std::vector& Code::instructions() const { return pImpl->instructions(); } -const std::vector& Code::opname_table() const { - return pImpl->opname_table(); +const std::vector& Code::instructions_source() const { + return pImpl->instructions_source(); } size_t Code::register_size() const { diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index cd683939091a5..f11c2d544967c 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -45,7 +45,7 @@ struct TORCH_API Code { size_t num_outputs() const; const std::vector& constant_table() const; const std::vector& instructions() const; - const std::vector& opname_table() const; + const std::vector& instructions_source() const; size_t register_size() const; private: diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 15edbe625e7fb..09c1f3ba55137 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -928,6 +928,7 @@ bool Node::hasSideEffects() const { case prim::CallFunction: case prim::CallMethod: case prim::BailoutTemplate: + case prim::profile: return true; } @@ -1710,6 +1711,14 @@ Node* ProfileOp::allocNewInstance(Graph* g) { return new ProfileOp(g, {nullptr}); } +TypePtr NamedValue::type() const { + if (value_) { + return value_->type(); + } else { + return ivalue_.type(); + } +} + constexpr Symbol ProfileOp::Kind; } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index bc0efa60303f4..5405b87f0f70e 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -1,28 +1,81 @@ #include "function.h" #include "interpreter.h" +#include +#include namespace torch{ namespace jit{ + +namespace { +template // int64_t, bool, double +void listConstruct(int num_inputs, Stack& stack) { + auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); + c10::List vals = + c10::impl::toList(fmap(inputs, [](const IValue& v) { return v.to(); })); + drop(stack, num_inputs); + push(stack, std::move(vals)); +} + +void tensorListConstruct(int num_inputs, Stack& stack) { + const size_t stack_size = stack.size(); + c10::List vals; + vals.reserve(num_inputs); + for (size_t i = stack_size - num_inputs; i < stack_size; ++i) { + vals.emplace_back(std::move(stack[i]).toTensor()); + } + drop(stack, 
num_inputs); + push(stack, std::move(vals)); +} +} + +char const * toString(OpCode op); namespace mobile { Function::Function(c10::QualifiedName name) : name_(name), code_(std::make_shared()) {} -void Function::append_instruction(OpCode op, int N, int X) { - code_->instructions_.emplace_back(op, N, X); +void Function::append_instruction(OpCode op, int X, int N) { + TORCH_CHECK(isOpSupportedInMobile(op), toString(op), + " is not supported in mobile module."); + code_->instructions_.emplace_back(op, X, N); } void Function::append_operator(const std::string& name, - const std::string& overload_name) { + const std::string& overload_name) { + // Keep the original opname in code_ code_->op_names_.emplace_back(name, overload_name); auto opname = code_->op_names_.back(); // Add "_" prefix to work around the double registration both of jit/generated // and here. TODO: remove it when we have separate build for lite interpreter. opname.name = "_" + opname.name; auto op = c10::Dispatcher::singleton().findSchema(opname); - assert(op.has_value()); + TORCH_CHECK(op.has_value(), opname.name, ".", opname.overload_name, " cannot be found."); code_->operators_.emplace_back(op); } +void Function::build_vararg_operator_table() { + for (auto& ins : code_->instructions_) { + if (ins.op == OPN) { + auto opname = code_->op_names_[ins.X]; + if (opname.name == "prim::ListConstruct") { + if (opname.overload_name == "int") { + code_->vararg_operators_.emplace_back(listConstruct); + } else if (opname.overload_name == "float") { + code_->vararg_operators_.emplace_back(listConstruct); + } else if (opname.overload_name == "bool") { + code_->vararg_operators_.emplace_back(listConstruct); + } else if (opname.overload_name == "Tensor") { + code_->vararg_operators_.emplace_back(tensorListConstruct); + } else { + AT_ERROR("Type of ListConstruct is not supported."); + } + } else { + AT_ERROR("OPN operator ", opname.name, " is not supported."); + } + ins.X = code_->vararg_operators_.size() - 1; + } + } +} + void Function::append_constant(const c10::IValue& constant) { code_->constants_.push_back(constant); } diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h index e79ce867f3e0e..28de4b453c801 100644 --- a/torch/csrc/jit/mobile/function.h +++ b/torch/csrc/jit/mobile/function.h @@ -17,8 +17,12 @@ class Function{ bool run(Stack& stack) const; const std::string& name() const; const c10::QualifiedName& qualname() const; - void append_instruction(OpCode op, int N, int X); - void append_operator(const std::string& name, const std::string& overload_name); + void append_instruction(OpCode op, int X, int N); + void append_operator(const std::string& name, + const std::string& overload_name); + void append_vararg_operator(const std::string& name, + const std::string& overload_name); + void build_vararg_operator_table(); void append_constant(const c10::IValue& constant); void set_register_size(size_t size); diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index cfce8e7d4baba..72b484e3a5ffa 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -54,20 +55,23 @@ void parseMethods(const std::vector& vals, std::shared_ptrelements(); + + auto named_ops = comps[1].toTuple()->elements(); + auto ops_name = named_ops[0].toString()->string(); + TORCH_CHECK(ops_name == "operators", + "operator is expected, but get", ops_name); + auto ops_list = named_ops[1].toTuple()->elements(); + for (const 
auto& ins : ins_list) { auto ins_item = ins.toTuple()->elements(); TORCH_CHECK(ins_item.size() == 3, "There should be three parts in an instruction."); OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str()); - function->append_instruction(op_code, ins_item[1].toInt(), - ins_item[2].toInt()); + int X = ins_item[1].toInt(); + int N = ins_item[2].toInt(); + function->append_instruction(op_code, X, N); } - auto named_ops = comps[1].toTuple()->elements(); - auto ops_name = named_ops[0].toString()->string(); - TORCH_CHECK(ops_name == "operators", - "operator is expected, but get", ops_name); - auto ops_list = named_ops[1].toTuple()->elements(); for (const auto& op : ops_list) { auto op_item = op.toTuple()->elements(); TORCH_CHECK(op_item.size() == 2, @@ -76,6 +80,9 @@ void parseMethods(const std::vector& vals, std::shared_ptrstring()); } + // vararg operators are stored in a separate table. + function->build_vararg_operator_table(); + auto named_consts = comps[2].toTuple()->elements(); auto consts_name = named_consts[0].toString()->string(); TORCH_CHECK(consts_name == "constants", diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index bad5c4b62fa91..1cdd33fd0b356 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -5,36 +5,45 @@ namespace torch{ namespace jit{ char const * toString(OpCode op); +std::ostream& operator<<(std::ostream& out, Instruction inst); namespace mobile { InterpreterState::InterpreterState(std::shared_ptr code) : code_(code) { registers_.resize(code_->register_size_); } -//InterpreterState::InterpreterState(Function* function) -// : function_(function) { -// registers_.resize(function->register_size()); -//} +namespace { +template // int64_t, bool, double +void listConstruct(Stack& stack, int num_inputs) { + auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); + c10::List vals = + c10::impl::toList(fmap(inputs, [](const IValue& v) { return v.to(); })); + drop(stack, num_inputs); + push(stack, std::move(vals)); +} +} bool InterpreterState::run(Stack& stack) { size_t pc = 0; while (true) { - // std::cout << "RUNNING " << pc << " " << instructions_[pc]; - // std::cout << std::endl; - // for (auto val : stack) { - // if (val.isTensor()) { - // std::cout << val.toTensor().sizes() << std::endl; - // } else { - // std::cout << val << std::endl; - // } - // } +// std::cout << "RUNNING " << pc << " " << code_->instructions_[pc]; +// std::cout << std::endl; +// for (auto val : stack) { +// if (val.isTensor()) { +// std::cout << val.toTensor().sizes() << std::endl; +// } else { +// std::cout << val << std::endl; +// } +// } Instruction inst = code_->instructions_[pc]; - TORCH_CHECK(isOpSupportedInMobile(inst.op), toString(inst.op), - " is not supported in mobile module."); switch (inst.op) { case OP: { c10::Dispatcher::singleton().callBoxed(*code_->operators_[inst.X], &stack); ++pc; } break; + case OPN: { + code_->vararg_operators_[inst.X](inst.N, stack); + ++pc; + } break; case LOAD: stack.emplace_back(reg(inst.X)); ++pc; diff --git a/torch/csrc/jit/mobile/interpreter.h b/torch/csrc/jit/mobile/interpreter.h index e976ca2d11454..5017e59dc0fa7 100644 --- a/torch/csrc/jit/mobile/interpreter.h +++ b/torch/csrc/jit/mobile/interpreter.h @@ -8,11 +8,12 @@ namespace torch{ namespace jit{ namespace mobile { using Stack = std::vector; - +using VarargFuncton = std::function; struct Code { std::vector instructions_; std::vector op_names_; std::vector> operators_; + std::vector 
vararg_operators_; std::vector constants_; size_t register_size_; // Aggregated output size. }; diff --git a/torch/csrc/jit/mobile/register_mobile_ops.cpp b/torch/csrc/jit/mobile/register_mobile_ops.cpp index 917b07b735e81..5d26a130ebfaf 100644 --- a/torch/csrc/jit/mobile/register_mobile_ops.cpp +++ b/torch/csrc/jit/mobile/register_mobile_ops.cpp @@ -23,5 +23,41 @@ static auto registry0 = torch::RegisterOperators().op( [](at::Tensor a, at::Scalar b, at::Scalar c) ->at::Tensor { return at::add(a, b, c); }) +).op( + "_aten::_convolution", + torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + [](at::Tensor input, at::Tensor weight, c10::optional bias, + std::vector stride, std::vector padding, + std::vector dilation, bool transposed, std::vector output_padding, + int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + return at::_convolution(input, weight, optional_to_tensor(bias), stride, padding, dilation, + transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + }) +).op( + // Dummy operator that does nothing. Used to reserve a location of an operator table. + "_prim::ListConstruct.int", + torch::RegisterOperators::options().catchAllKernel( + []() { + }) +).op( + "_prim::ListConstruct.float", + torch::RegisterOperators::options().catchAllKernel( + []() { + }) +).op( + "_prim::ListConstruct.bool", + torch::RegisterOperators::options().catchAllKernel( + []() { + }) +).op( + "_prim::ListConstruct.tensor", + torch::RegisterOperators::options().catchAllKernel( + []() { + }) +).op( + "_prim::ListConstruct.generic", + torch::RegisterOperators::options().catchAllKernel( + []() { + }) ); diff --git a/torch/csrc/jit/named_value.h b/torch/csrc/jit/named_value.h index e9a7a2a4b30c4..c27abdea37ebd 100644 --- a/torch/csrc/jit/named_value.h +++ b/torch/csrc/jit/named_value.h @@ -70,6 +70,8 @@ struct NamedValue { return *loc_; } + at::TypePtr type() const; + private: c10::optional loc_; c10::optional name_; diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 991f97478f0c4..05b8940bf776b 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -345,12 +345,17 @@ void AliasDb::analyzeImpl(Node* node) { case prim::SetAttr: return analyzeSetAttr(node); case prim::profile: - AT_ERROR("Analyzing prim::profile isn't yet implemented"); - // TODO: simply mapping inputs' aliases to outputs' - // should work but a) we should probably avoid exposing - // prim::profile to optimizations b) the alias semantics - // might be more complicated than just mapAliases - // mapAliases(node->inputs(), node->outputs()); + if (node->inputs().size() > 0) { + makePointerTo(node->output(), node->inputs().at(0)); + } + return; + case prim::BailOut: + TORCH_INTERNAL_ASSERT(node->inputs().at(0)->node()->kind() == + prim::BailoutTemplate); + makePointerTo(node->output(), node->inputs().at(1)); + return; + case prim::Guard: + makePointerTo(node->output(), node->inputs().at(0)); return; case prim::CallFunction: case prim::CallMethod: diff --git a/torch/csrc/jit/passes/bailout_graph.cpp b/torch/csrc/jit/passes/bailout_graph.cpp index 28c3543f0fbba..95eb6d5e66fa6 100644 --- a/torch/csrc/jit/passes/bailout_graph.cpp +++ b/torch/csrc/jit/passes/bailout_graph.cpp @@ -23,7 +23,7 @@ static std::unordered_set collectLoopCounts(Node *n) { it = outerNode->owningBlock(); } - return std::move(loopCounts); + return loopCounts; } struct BailOutGraphBuilderForNode { diff --git 
a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 0eb66b0b1faa3..813ae711b967b 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -791,7 +791,7 @@ class ShapePropagator { "aten::asin(Tensor self) -> Tensor", "aten::atan(Tensor self) -> Tensor", "aten::ceil(Tensor self) -> Tensor", - "aten::clone(Tensor self) -> Tensor", + "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor", "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor", "aten::celu(Tensor self, Scalar alpha) -> Tensor", diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp index 462120e5c4d4f..0559e905f9d8f 100644 --- a/torch/csrc/jit/pickler.cpp +++ b/torch/csrc/jit/pickler.cpp @@ -105,6 +105,12 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { } else if (ivalue.isObject()) { auto obj = ivalue.toObject(); auto type = obj->type(); + if (memorized_class_types_ != nullptr) { + // Memorize every class type the Pickler encountered + // This is used to make sure we capture all the run-time types + // and serialize them properly for class/interface polymorphism + memorized_class_types_->emplace_back(type); + } pushGlobal(type->name()->prefix(), type->name()->name()); push(PickleOpCode::EMPTY_TUPLE); push(PickleOpCode::NEWOBJ); @@ -576,7 +582,7 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls) { auto set_type = set_schema.arguments().at(1).type(); TORCH_CHECK( - set_type->isSubtypeOf(get_type), + get_type->isSubtypeOf(set_type), "'__getstate__'s return type (", get_type->python_str(), ") does not match '__setstate__'s argument type (", diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h index 1fa522a12c25c..1477bc51b5333 100644 --- a/torch/csrc/jit/pickler.h +++ b/torch/csrc/jit/pickler.h @@ -124,8 +124,11 @@ class Pickler { public: Pickler( std::function writer, - std::vector* tensor_table) - : writer_(writer), tensor_table_(tensor_table) {} + std::vector* tensor_table, + std::vector* memorized_class_types = nullptr) + : writer_(writer), + tensor_table_(tensor_table), + memorized_class_types_(memorized_class_types) {} // Push protocol onto the stack void protocol(); @@ -141,6 +144,7 @@ class Pickler { const std::vector& tensorData() { return tensor_data_; } + void pushEmptyDict(); void pushDict(const IValue& ivalue); void pushInt(int64_t value); @@ -213,6 +217,9 @@ class Pickler { // object, and we will alias it to the old object at that address. std::vector memoized_ivalues_; + // List of all the types that it wrote, inspect from the IValues it wrote. 
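// (Non-owning and optional: it defaults to nullptr in the constructor above,
// so existing Pickler call sites that don't need the captured class types are
// unaffected.)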
+ std::vector* memorized_class_types_; + // List of tensor storages to serialize in the same binary as the pickle data // similar to ivalues, they are memoized using BINPUT std::vector tensor_data_; diff --git a/torch/csrc/jit/profiling_graph_executor_impl.cpp b/torch/csrc/jit/profiling_graph_executor_impl.cpp index 26030fca5014e..4b3c9779b610c 100644 --- a/torch/csrc/jit/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/profiling_graph_executor_impl.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -7,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -15,8 +17,8 @@ namespace torch { namespace jit { -thread_local bool profiling_mode = false; -bool& getProfilingMode() { +static std::atomic profiling_mode{false}; +std::atomic& getProfilingMode() { return profiling_mode; } @@ -58,7 +60,12 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { if (!pr_) { pr_ = ProfilingRecord::instrumentGraph(prepareGraph(graph, stack)); - profiling_plan_ = ExecutionPlan(pr_->profiled_graph_); + auto copy = pr_->graph()->copy(); + LowerGradOf(*copy); + RemoveExpands(copy); + CanonicalizeOps(copy); + EliminateDeadCode(copy); + profiling_plan_ = ExecutionPlan(copy); // fall-through } @@ -68,6 +75,12 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { // copy already has differentiableGraphs auto copy = pr_->graph()->copy(); + if (!getGraphExecutorOptimize()) { + runRequiredPasses(copy); + optimized_plan_ = ExecutionPlan(copy); + return *optimized_plan_; + } + // insert bailouts InsertGuards(copy); // get rid of autograd specific ops @@ -83,10 +96,29 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) { CanonicalizeOps(copy); EliminateRedundantGuards(copy); InsertBailOuts(copy); - // regular optimizations + // TODO: this runs specializeAutogradZero ?? + GRAPH_DUMP("After InsertBailOuts: ", copy); + runRequiredPasses(copy); ConstantPropagation(copy); runOptimization(copy); - runNondiffOptimization(copy); + if (needsGradient(copy)) { + auto diff_nodes = CreateAutodiffSubgraphs( + copy, + getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1); + for (Node *dnode : diff_nodes) { + auto diff_graph = std::move(dnode->g(attr::Subgraph)); + Gradient gradient = differentiate(diff_graph); + runOptimization(gradient.f); + // run non diff optimization on the forward graph + runNondiffOptimization(gradient.f); + packGradient(gradient, dnode); + } + InlineAutodiffSubgraphs(copy, getAutodiffSubgraphInlining() + ? 
autodiffSubgraphInlineThreshold + : 1); + } else { + runNondiffOptimization(copy); + } EliminateDeadCode(copy); // cache optimized_plan_ = ExecutionPlan(copy); diff --git a/torch/csrc/jit/profiling_record.cpp b/torch/csrc/jit/profiling_record.cpp index 07f0b203b18d2..dc2355ce8adef 100644 --- a/torch/csrc/jit/profiling_record.cpp +++ b/torch/csrc/jit/profiling_record.cpp @@ -5,7 +5,7 @@ namespace torch { namespace jit { ProfilingRecord::ProfilingRecord(std::shared_ptr g) - : profiled_graph_(std::move(g)), profiling_count_(3) {} + : profiled_graph_(std::move(g)), profiling_count_(1) {} ProfileOp* ProfilingRecord::createProfileNode( const std::function& fp, @@ -18,7 +18,71 @@ ProfileOp* ProfilingRecord::createProfileNode( return pn; } -void ProfilingRecord::instrumentBlock(Block* block) { +static void unprofileGraphInputs(const std::shared_ptr &graph) { + for (auto i : graph->inputs()) { + if (i->type()->isSubtypeOf(TensorType::get())) { + i->setType(unshapedType(i->type())); + } + } +} + +static void unprofileBlock(Block* start_block) { + std::vector stack; + stack.push_back(start_block); + + while (!stack.empty()) { + Block* block = stack.back(); + stack.pop_back(); + + for (auto n : block->nodes()) { + for (auto o : n->outputs()) { + if (o->type()->isSubtypeOf(TensorType::get())) { + o->setType(unshapedType(o->type())); + } + } + stack.insert(stack.end(), n->blocks().begin(), n->blocks().end()); + } + } +} + +void ProfilingRecord::insertShapeProfile(Node *n, Value *i) { + + auto pn = createProfileNode(nullptr, {i}); + auto pno = pn->addOutput(); + bool first = true; + pno->setType(TensorType::get()); + std::function shape_profiler = [this, pno, + first](Stack &stack) mutable { + IValue t; + pop(stack, t); + if (t.isTensor()) { + + if (t.toTensor().defined()) { + auto pttp = TensorType::create(t.toTensor()); + std::lock_guard lock(this->mutex_); + if (auto type = pno->type()->cast()) { + if (!first) { + pttp = pttp->merge(type); + } + pno->setType(pttp); + first = false; + } + } else { + pno->setType(TensorType::get()->withUndefined()); + } + } + + // passing t through + push(stack, t); + + }; + + pn->setCallback(shape_profiler); + pn->insertBefore(n); + n->replaceInputWith(i, pn->output()); +} + +void ProfilingRecord::instrumentBlock(Block *block) { for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { auto n = *it; for (auto i : n->inputs()) { @@ -27,32 +91,7 @@ void ProfilingRecord::instrumentBlock(Block* block) { continue; } - auto pn = createProfileNode(nullptr, {i}); - auto pno = pn->addOutput(); - bool first = true; - pno->setType(TensorType::get()); - std::function shape_profiler = - [this, pno, first](Stack& stack) mutable { - IValue t; - pop(stack, t); - if (t.isTensor()) { - auto pttp = TensorType::create(t.toTensor()); - std::lock_guard lock(this->mutex_); - if (auto type = pno->type()->cast()) { - if (!first) { - pttp = pttp->merge(type); - } - pno->setType(pttp); - first = false; - } - } - // passing t through - push(stack, t); - }; - - pn->setCallback(shape_profiler); - pn->insertBefore(n); - n->replaceInputWith(i, pn->output()); + insertShapeProfile(n, i); } for (auto b : n->blocks()) { @@ -66,8 +105,15 @@ std::unique_ptr ProfilingRecord::instrumentGraph( auto new_g = graph->copy(); auto pr = std::unique_ptr(new ProfilingRecord(new_g)); auto raw_pr = pr.get(); - + unprofileGraphInputs(new_g); + unprofileBlock(new_g->block()); pr->instrumentBlock(new_g->block()); + + for (auto i : new_g->return_node()->inputs()) { + if 
(i->type()->isSubtypeOf(TensorType::get())) { + pr->insertShapeProfile(new_g->return_node(), i); + } + } std::function counter = [raw_pr](Stack&) { std::lock_guard lock(raw_pr->mutex_); if (raw_pr->profiling_count_ > 0) diff --git a/torch/csrc/jit/profiling_record.h b/torch/csrc/jit/profiling_record.h index 8a15ff4c55dbe..73a332de1f59a 100644 --- a/torch/csrc/jit/profiling_record.h +++ b/torch/csrc/jit/profiling_record.h @@ -39,6 +39,7 @@ struct ProfilingRecord { const std::function& fp, at::ArrayRef inputs); void instrumentBlock(Block* block); + void insertShapeProfile(Node *n, Value *i); ProfilingRecord(std::shared_ptr g); }; diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index 33e8b2604239a..362122e9c13fd 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -480,6 +480,37 @@ inline IValue toIValue( } return userObj; } + case TypeKind::InterfaceType: { + auto interfaceType = type->expect(); + // When converting an pyobj to an interface, we inspect the value + // to found the compiled TorchScript class, check if it conform + // with the interface or not, and then create a ivalue::Object + // from that class type. + py::str qualified_name = py::module::import("torch.jit") + .attr("_qualified_name")(obj.get_type()); + auto pyCu = get_python_cu(); + const auto classType = + pyCu->get_class(c10::QualifiedName(qualified_name)); + if (!classType) { + throw std::runtime_error(c10::str( + "Assigning the object ", + py::str(obj), + " to an interface fails because the value is not " + "a TorchScript compatible type, did you forget to", + "turn it into a user defined TorchScript class?")); + } + std::stringstream why_not; + if (!classType->isSubtypeOfExt(interfaceType, &why_not)) { + throw py::cast_error(c10::str( + "Object ", + py::str(obj), + " is not compatible with interface ", + interfaceType->python_str(), + "\n", + why_not.str())); + } + return toIValue(std::move(obj), classType); + } case TypeKind::NumberType: { if (THPDtype_Check(obj.ptr())) { auto dtype = reinterpret_cast(obj.ptr()); @@ -502,7 +533,6 @@ inline IValue toIValue( case TypeKind::GeneratorType: case TypeKind::VarType: case TypeKind::FutureType: - case TypeKind::InterfaceType: break; case TypeKind::FunctionType: AT_ERROR("Function Values aren't yet supported"); diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 201ec0f2bda28..921e24f93ecb0 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -722,6 +722,12 @@ void initPythonIRBindings(PyObject* module_) { .def(py::init([](const std::string& qualified_name) { return get_python_cu()->get_class(c10::QualifiedName(qualified_name)); })); + py::class_>( + m, "InterfaceType") + .def(py::init([](const std::string& qualified_name) { + return get_python_cu()->get_interface( + c10::QualifiedName(qualified_name)); + })); py::class_(m, "Use") .def_readonly("user", &Use::user) diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h index 06c3727e25622..9e1ce88a3c72b 100644 --- a/torch/csrc/jit/script/compilation_unit.h +++ b/torch/csrc/jit/script/compilation_unit.h @@ -176,6 +176,14 @@ struct TORCH_API CompilationUnit { return type->cast(); } + c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const { + auto type = get_type(name); + if (!type) { + return nullptr; + } + return type->cast(); + } + c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const { for (const auto& cls : classes_) { if 
(cls->name()->qualifiedName() == name.qualifiedName()) { diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 245e20fd37e2c..531f58240e66b 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -421,7 +421,7 @@ struct Environment { if (!retval) { static std::unordered_map globals = { {"print", std::make_shared()}, - {"tuple", std::make_shared()}, + {"tuple", SpecialFormValue::create(prim::TupleConstruct)}, {"float", makeMagic( "__float__", @@ -438,8 +438,8 @@ struct Environment { makeMagic( "__str__", std::make_shared(StringType::get(), aten::str))}, - {"getattr", std::make_shared()}, - {"isinstance", std::make_shared()}, + {"getattr", SpecialFormValue::create(prim::GetAttr)}, + {"isinstance", SpecialFormValue::create(prim::isinstance)}, // todo(zach): remove when we can correctly export torch.full via ONNX // or we have implicit conversion that can convert numbers to tensors {"_to_tensor", @@ -471,9 +471,9 @@ struct Environment { {"ord", std::make_shared(aten::ord, at::nullopt)}, {"chr", std::make_shared(aten::chr, at::nullopt)}, {"bin", std::make_shared(aten::bin, at::nullopt)}, - {"range", std::make_shared(prim::range)}, - {"zip", std::make_shared(prim::zip)}, - {"enumerate", std::make_shared(prim::enumerate)}, + {"range", SpecialFormValue::create(prim::range)}, + {"zip", SpecialFormValue::create(prim::zip)}, + {"enumerate", SpecialFormValue::create(prim::enumerate)}, {"rangelist", std::make_shared(prim::rangelist, at::nullopt)}, {"sorted", @@ -1135,8 +1135,9 @@ struct to_ir { list_type = getListCompType(lc, IntType::get()); } else { throw ErrorReport(lc.range()) - << "iterator expression is expected to be a list, iterable, or range, found " - << (siv ? siv->getValue()->type()->python_str() : siv->kind()); + << "iterator expression is expected to be a list, iterable, or " + "range, found " + << sv->kind(); } // given `[x*2 for x in my_list]` this generates the following AST: @@ -2345,104 +2346,186 @@ struct to_ir { std::shared_ptr emitApplyExpr(Apply& apply, size_t n_binders) { auto sv = emitSugaredExpr(apply.callee(), 1); auto loc = apply.callee().range(); - if (auto fork_value = dynamic_cast(sv.get())) { - auto& trees = apply.inputs().tree()->trees(); - if (trees.size() < 1) { - throw ErrorReport(loc) << "Expected at least one argument to fork()"; - } - auto forked = emitSugaredExpr(Expr(trees[0]), 1); - TreeList sliced_trees(trees.begin() + 1, trees.end()); - auto inputs = getNamedValues(sliced_trees, true); - auto attributes = emitAttributes(apply.attributes()); - return emitForkExpr(loc, forked, inputs, attributes); - } else if (auto annotate_value = dynamic_cast(sv.get())) { - checkApplyNumInputs(apply, 2); - TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); - Value* expr = tryConvertToType( - apply.range(), - *graph, - type, - emitExpr(apply.inputs()[1], type), - /*allow_conversions=*/true); + if (auto special_form = dynamic_cast(sv.get())) { + return emitApplySpecialForm(special_form->form(), apply); + } + auto inputs = getNamedValues(apply.inputs(), true); + auto attributes = emitAttributes(apply.attributes()); + return sv->call(loc, method, inputs, attributes, n_binders); + } + + // this function handles expressions that look like apply statements + // but have special evaluation rules for the arguments. + // when adding a new case, only add a special form if it cannot be expressed + // using the standard SugaredValue::call function, which enforces normal + // evaluation order. 
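The change here and in the hunk that follows replaces the old chain of dynamic_casts (fork_value, annotate_value, getattr, uninitialized_value, tuple_call, isinstance, classNew) with a single SpecialFormValue carrying a Symbol, dispatched in one switch inside emitApplySpecialForm. A standalone sketch of that refactor pattern; every name below is hypothetical rather than the real compiler API.

    // One tagged "special form" value plus a switch, instead of a cast chain.
    #include <iostream>
    #include <memory>

    enum class Form { Fork, Annotate, GetAttr, IsInstance, TupleConstruct };

    struct SugaredValue {
      virtual ~SugaredValue() = default;
    };

    struct SpecialFormValue : SugaredValue {
      explicit SpecialFormValue(Form f) : form(f) {}
      Form form;
      static std::shared_ptr<SpecialFormValue> create(Form f) {
        return std::make_shared<SpecialFormValue>(f);
      }
    };

    void emitApply(const std::shared_ptr<SugaredValue>& callee) {
      if (auto special = std::dynamic_pointer_cast<SpecialFormValue>(callee)) {
        switch (special->form) {  // one cast, then dispatch on the tag
          case Form::Fork:           std::cout << "emit fork\n"; break;
          case Form::Annotate:       std::cout << "emit annotate\n"; break;
          case Form::GetAttr:        std::cout << "emit getattr\n"; break;
          case Form::IsInstance:     std::cout << "emit isinstance\n"; break;
          case Form::TupleConstruct: std::cout << "emit tuple\n"; break;
        }
        return;
      }
      std::cout << "ordinary call with normal evaluation order\n";
    }

    int main() {
      emitApply(SpecialFormValue::create(Form::GetAttr));
      emitApply(std::make_shared<SugaredValue>());
    }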
+ std::shared_ptr emitApplySpecialForm( + Symbol form, + Apply& apply) { + switch (form) { + case prim::fork: { + auto& trees = apply.inputs().tree()->trees(); + if (trees.size() < 1) { + throw ErrorReport(apply) + << "Expected at least one argument to fork()"; + } + auto forked = emitSugaredExpr(Expr(trees[0]), 1); + TreeList sliced_trees(trees.begin() + 1, trees.end()); + auto inputs = getNamedValues(sliced_trees, true); + auto attributes = emitAttributes(apply.attributes()); + return emitForkExpr(apply.range(), forked, inputs, attributes); + } + case prim::annotate: { + checkApplyNumInputs(apply, 2); + TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); + Value* expr = tryConvertToType( + apply.range(), + *graph, + type, + emitExpr(apply.inputs()[1], type), + /*allow_conversions=*/true); - std::stringstream why_not; - if (!expr->type()->isSubtypeOfExt(type, &why_not)) { - throw ErrorReport(apply.inputs()) - << "expected an expression of type " << type->python_str() - << " but found " << expr->type()->python_str() << "\n" - << why_not.str(); - } - - // None is a subtype of Optional[T], but we want to remember what T is, - // after annotation so that variables assigned to this None will still - // get the right type. To do this, we make a None constant that - // has the type Optional[T] - if (type->kind() == OptionalType::Kind && - expr->type()->isSubtypeOf(NoneType::get())) { - Node* none = graph->createNone(); - none->output()->setType(type); - graph->insertNode(none); - expr = none->output(); - } - - return std::make_shared(expr); - } else if (auto getattr = dynamic_cast(sv.get())) { - checkApplyNumInputs(apply, 2); - auto obj = emitSugaredExpr(apply.inputs()[0], 1); - auto selector = apply.inputs()[1]; - if (selector.kind() != TK_STRINGLITERAL) { - throw ErrorReport(loc) - << "getattr's second argument must be a string literal"; - } - const std::string& name = StringLiteral(selector).text(); - return obj->attr(apply.range(), method, name); - } else if ( - auto uninitialized_value = - dynamic_cast(sv.get())) { - checkApplyNumInputs(apply, 1); - TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); - auto out = graph->insertNode(graph->createUninitialized(type)) - ->setSourceRange(loc); - return std::make_shared(out->output()); - } else if (auto tuple_call = dynamic_cast(sv.get())) { - checkApplyNumInputs(apply, 1); - auto arg = emitSugaredExpr(apply.inputs()[0], 1); - auto inputs = arg->asTuple(apply.range(), method); - auto inp_values = fmap(inputs, [&](const SugaredValuePtr& sv) { - return sv->asValue(loc, method); - }); - return std::make_shared( - graph->insertNode(graph->createTuple(inp_values))->output()); - } else if (auto isinstance = dynamic_cast(sv.get())) { - checkApplyNumInputs(apply, 2); - auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]); - return std::make_shared(result.value()); - } else if (auto classNew = dynamic_cast(sv.get())) { - if (apply.inputs().size() != 1) { - throw ErrorReport(loc) << "Only one argument to __new__ allowed"; - } - auto arg = emitSugaredExpr(apply.inputs()[0], 1); - auto class_arg = dynamic_cast(arg.get()); - if (!class_arg) { - throw ErrorReport(loc) - << "Expected class value as argument to __new__, got " - << arg->kind() << " instead"; - } - if (class_arg->type_ != classNew->type_) { - throw ErrorReport(loc) - << "Argument to __new__() must match the class " - << "you are calling __new__() on. 
" - << "Got: " << class_arg->type_->python_str() - << ", expected: " << classNew->type_->python_str(); + std::stringstream why_not; + if (!expr->type()->isSubtypeOfExt(type, &why_not)) { + throw ErrorReport(apply.inputs()) + << "expected an expression of type " << type->python_str() + << " but found " << expr->type()->python_str() << "\n" + << why_not.str(); + } + + // None is a subtype of Optional[T], but we want to remember what T is, + // after annotation so that variables assigned to this None will still + // get the right type. To do this, we make a None constant that + // has the type Optional[T] + if (type->kind() == OptionalType::Kind && + expr->type()->isSubtypeOf(NoneType::get())) { + Node* none = graph->createNone(); + none->output()->setType(type); + graph->insertNode(none); + expr = none->output(); + } + + return std::make_shared(expr); } + case prim::GetAttr: { + checkApplyNumInputs(apply, 2); + auto obj = emitSugaredExpr(apply.inputs()[0], 1); + auto selector = apply.inputs()[1]; + if (selector.kind() != TK_STRINGLITERAL) { + throw ErrorReport(apply) + << "getattr's second argument must be a string literal"; + } + const std::string& name = StringLiteral(selector).text(); + return obj->attr(apply.range(), method, name); + } + case prim::Uninitialized: { + checkApplyNumInputs(apply, 1); + TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); + auto out = graph->insertNode(graph->createUninitialized(type)) + ->setSourceRange(apply.range()); + return std::make_shared(out->output()); + } + case prim::TupleConstruct: { + checkApplyNumInputs(apply, 1); + auto arg = emitSugaredExpr(apply.inputs()[0], 1); + auto inputs = arg->asTuple(apply.range(), method); + auto inp_values = fmap(inputs, [&](const SugaredValuePtr& sv) { + return sv->asValue(apply.range(), method); + }); + return std::make_shared( + graph->insertNode(graph->createTuple(inp_values))->output()); + } + case prim::isinstance: { + checkApplyNumInputs(apply, 2); + auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]); + return std::make_shared(result.value()); + } + // This represents the "__new__" method on classes + // because it takes a ClassValue as input. + // So if we see: + // Foo.__new__(Foo) + // Foo is a ClassValue, calling `attr("__new__")` will return a + // CreateObject special form. + case prim::CreateObject: { + if (apply.inputs().size() != 1) { + throw ErrorReport(apply) << "Only one argument to __new__ allowed"; + } + auto arg = emitSugaredExpr(apply.inputs()[0], 1); + auto class_arg = dynamic_cast(arg.get()); + if (!class_arg) { + throw ErrorReport(apply) + << "Expected class value as argument to __new__, got " + << arg->kind() << " instead"; + } + auto createNode = + graph->insertNode(graph->createObject(class_arg->type_)); + return std::make_shared(createNode->output()); + } + // We construct the iterable tree here using the IterableTree + // SugaredValue, The tree consists of SimpleValue, RangeValue or + // IterableValue: For SimpleValues(List, Dict, etc) or RangeValue. We will + // make them as tree leaves since we could get the loop information from + // len() and get_item(). 
For IterableValue like zip(), enumerate(), we can + // model them as a combination of leaves, and we emit a IterableTree value + // to record the tree information + case prim::range: { + std::vector input_vals = + getValues(apply.inputs(), /*maybe_unpack=*/true); + return std::make_shared(apply.range(), method, input_vals); + } + case prim::enumerate: { + const SourceRange& loc = apply.range(); + auto inputs = apply.inputs(); + auto input_size = apply.inputs().size(); + // enumerate(x) can be rewrite as subtrees: + // IterableTree(RangeValue(0, math.inf), SimpleValue(x)) + Value* start_index = nullptr; + if (input_size == 0) { + throw ErrorReport(loc) + << "enumerate expected at least 1 arguments, got 0"; + } - return classNew->createObject(apply.range(), method); - } else if (auto iterable = std::dynamic_pointer_cast(sv)) { - return emitIterableTree(loc, apply.inputs(), iterable); - } else { - auto inputs = getNamedValues(apply.inputs(), true); - auto attributes = emitAttributes(apply.attributes()); - return sv->call(loc, method, inputs, attributes, n_binders); + if (input_size == 2) { + start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method); + } + + if (input_size > 2) { + throw ErrorReport(loc) + << "enumerate expected at most 2 arguments, got " << input_size; + } + std::vector range_inputs; + if (start_index != nullptr) { + range_inputs.emplace_back(start_index); + } + Value* end = materializeConstant( + std::numeric_limits::max(), + *graph, + loc, + integral_constants); + range_inputs.emplace_back(end); + SugaredValuePtr range_sv = + std::make_shared(loc, method, range_inputs); + SugaredValuePtr expr_sv = emitSugaredExpr(inputs[0], 1); + return std::make_shared( + std::vector({range_sv, expr_sv})); + } + case prim::zip: { + // zip(x, y) can be rewrite as subtrees: + // IterableTree(IterableTree(x), IterableTree(y)) + auto inputs = apply.inputs(); + if (inputs.size() == 0) { + throw ErrorReport(apply) + << "zip expected at least 1 arguments, got 0"; + } + auto iterable_tree = std::make_shared(); + for (Expr expr : inputs) { + auto expr_sv = emitSugaredExpr(expr, 1); + iterable_tree->addChild(expr_sv); + } + return iterable_tree; + } + default: + TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form); } } @@ -2521,69 +2604,6 @@ struct to_ir { return graph->insertConstant(stack[0], tree->range()); } - - // We construct the iterable tree here using the IterableTree SugaredValue, - // The tree consists of SimpleValue, RangeValue or IterableValue: - // For SimpleValues(List, Dict, etc) or RangeValue. We will make them as tree - // leaves since we could get the loop information from len() and get_item(). 
- // For IterableValue like zip(), enumerate(), we can model them as a - // combination of leaves, and we emit a IterableTree value to record the tree - // information - SugaredValuePtr emitIterableTree( - SourceRange& loc, - const List& inputs, - const std::shared_ptr& iterable) { - std::shared_ptr iterable_tree = nullptr; - size_t input_size = inputs.size(); - - // Handling different iterable values - if (iterable->symbol_ == prim::range) { - std::vector input_vals = getValues(inputs, /*maybe_unpack=*/true); - return std::make_shared(loc, method, input_vals); - } else if (iterable->symbol_ == prim::enumerate) { - // enumerate(x) can be rewrite as subtrees: - // IterableTree(RangeValue(0, math.inf), SimpleValue(x)) - Value* start_index = nullptr; - if (input_size == 0) { - throw ErrorReport(loc) - << "enumerate expected at least 1 arguments, got 0"; - } - - if (input_size == 2) { - start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method); - } - - if (input_size > 2) { - throw ErrorReport(loc) - << "enumerate expected at most 2 arguments, got " << input_size; - } - std::vector range_inputs; - if (start_index != nullptr) { - range_inputs.emplace_back(start_index); - } - Value* end = materializeConstant( - std::numeric_limits::max(), *graph, loc, integral_constants); - range_inputs.emplace_back(end); - SugaredValuePtr range_sv = - std::make_shared(loc, method, range_inputs); - SugaredValuePtr expr_sv = emitSugaredExpr(inputs[0], 1); - iterable_tree = std::make_shared( - std::vector({range_sv, expr_sv})); - } else if (iterable->symbol_ == prim::zip) { - // zip(x, y) can be rewrite as subtrees: - // IterableTree(IterableTree(x), IterableTree(y)) - if (inputs.size() == 0) { - throw ErrorReport(loc) << "zip expected at least 1 arguments, got 0"; - } - iterable_tree = std::make_shared(); - for (Expr expr : inputs) { - auto expr_sv = emitSugaredExpr(expr, 1); - iterable_tree->addChild(expr_sv); - } - } - return iterable_tree; - } - std::shared_ptr emitForkExpr( SourceRange loc, const std::shared_ptr& forked, @@ -3238,6 +3258,7 @@ c10::QualifiedName CompilationUnit::mangle( newAtom.reserve(atom.size()); // Append the part of the name up to the end of the prefix newAtom.append(atom, 0, pos); + newAtom.append(manglePrefix); newAtom.append(std::to_string(mangleIndex_++)); atom = newAtom; return QualifiedName(atoms); @@ -3269,8 +3290,14 @@ std::unique_ptr CompilationUnit::define( auto creator = [def, _resolver, self](Function& method) { // Store the function name so that it can be referenced if there is an error // while compiling this function - ErrorReport::CallStack call( - self ? method.qualname().qualifiedName() : method.qualname().name()); + std::string call_name = method.qualname().name(); + if (self) { + auto atoms = method.qualname().atoms(); + // There should be at least a ClassName.method_name + TORCH_INTERNAL_ASSERT(atoms.size() >= 2); + call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1); + } + ErrorReport::CallStack call(call_name); to_ir(def, _resolver, self, method); }; auto name = prefix ? 
QualifiedName(*prefix, def.name().name()) diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 01faac859d13b..cc90dde4404c6 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -685,7 +685,7 @@ void initJitScriptBindings(PyObject* module) { cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr); }); - py::class_(m, "Function", py::dynamic_attr()) + py::class_(m, "ScriptFunction", py::dynamic_attr()) .def( "__call__", [](py::args args, py::kwargs kwargs) { @@ -709,6 +709,14 @@ void initJitScriptBindings(PyObject* module) { const std::string& filename, const ExtraFilesMap& _extra_files = ExtraFilesMap()) { Module module("__torch__.PlaceholderModule"); + // [issue 27343] + // Modules have 'training' attributes by defualt, but due to + // https://github.com/pytorch/pytorch/issues/27343, functions end + // up having a training attribute when they are loaded. This adds + // a fake 'training' attribute that shouldn't be used, but prevents + // jitter on saving and loading. Once that issue is fixed this can + // be deleted. + module.register_attribute("training", BoolType::get(), true); addFunctionToModule(module, self); module.save(filename, _extra_files); }, @@ -720,6 +728,8 @@ void initJitScriptBindings(PyObject* module) { const ExtraFilesMap& _extra_files = ExtraFilesMap()) { std::ostringstream buf; Module module("__torch__.PlaceholderModule"); + // see [issue 27343] + module.register_attribute("training", BoolType::get(), true); addFunctionToModule(module, self); module.save(buf, _extra_files); return py::bytes(buf.str()); diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index 0ebc35859e8f5..e572629ca5542 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -334,13 +334,9 @@ Module Module::clone_impl( const Module& orig = s.to_module(); Module cloned = orig.clone_impl(type_remap); type_remap[orig.type()] = cloned.type(); - r.set_or_add_slot( - s.name(), - type_remap.at(s.type()), - cloned.module_object(), - s.entity_type()); + r.register_module(s.name(), cloned); } else { - r.set_or_add_slot(s.name(), s.type(), s.value(), s.entity_type()); + r.register_attribute(s.name(), s.type(), s.value(), s.is_parameter()); } } @@ -358,7 +354,7 @@ void Module::train(bool on) { if (auto slot = find_attribute("training")) { slot->setValue(on); } else { - register_attribute("training", BoolType::get(), on); + TORCH_INTERNAL_ASSERT("'training' attribute not found"); } } diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 35a5a80b733ab..dd36df7f9174b 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -144,28 +144,27 @@ struct TORCH_API Module { // register_buffer method. With this simplification, we only need to track // whether a slot is a parameter to be able to classify it. void register_buffer(const std::string& name, autograd::Variable v) { - set_or_add_slot(name, TensorType::get(), v, EntityType::ATTRIBUTE); + type()->addOrCheckAttribute(name, TensorType::get()); + module_object()->setAttr(name, v); } - void register_parameter( const std::string& name, autograd::Variable v, bool is_buffer) { - set_or_add_slot( - name, - TensorType::get(), - v, - is_buffer ? 
EntityType::ATTRIBUTE : EntityType::PARAMETER); + type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer); + module_object()->setAttr(name, v); } void register_attribute( const std::string& name, - const TypePtr type, - IValue ivalue) { - set_or_add_slot(name, type, ivalue, EntityType::ATTRIBUTE); + const TypePtr t, + IValue v, + bool is_param = false) { + type()->addOrCheckAttribute(name, t, is_param); + module_object()->setAttr(name, v); } void register_module(const std::string& name, const Module& module) { - set_or_add_slot( - name, module.type(), module.module_object(), EntityType::MODULE); + type()->addOrCheckAttribute(name, module.type()); + module_object()->setAttr(name, module.module_object()); } void set_parameter(const std::string& name, at::Tensor v) { @@ -260,6 +259,7 @@ struct TORCH_API Module { if (auto p = find_attribute("training")) { return p->value().toBool(); } + // We are in training mode by default return true; } @@ -396,40 +396,20 @@ struct TORCH_API Module { } return nullptr; } - void check_entity(EntityType expected, size_t slot) const { - EntityType actual = get_slot(slot).entity_type(); + + Slot get_slot(const std::string& name, EntityType etype) const { + size_t slot_idx = type()->getAttributeSlot(name); + Slot slot = get_slot(slot_idx); TORCH_CHECK( - expected == actual, + etype == slot.entity_type(), "The field '", - type()->getAttributeName(slot), + type()->getAttributeName(slot_idx), "' is a ", - toString(actual), + toString(slot.entity_type()), " but this call is" " trying to use it as a ", - toString(expected)); - } - - void set_or_add_slot( - const std::string& name, - const TypePtr& slot_type, - IValue v, - EntityType etype) { - auto slot = type()->findAttributeSlot(name); - if (!slot) { - slot = - type()->addAttribute(name, slot_type, etype == EntityType::PARAMETER); - } else { - check_entity(etype, *slot); - } - TypePtr atype = type()->getAttribute(*slot); - TORCH_CHECK(slot_type->isSubtypeOf(atype)); - module_object()->setSlot(*slot, std::move(v)); - } - - Slot get_slot(const std::string& name, EntityType etype) const { - size_t slot = type()->getAttributeSlot(name); - check_entity(etype, slot); - return get_slot(slot); + toString(etype)); + return slot; } c10::optional find_slot(const std::string& name, EntityType etype) const { diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp index 9e28bab926751..275b716bfe787 100644 --- a/torch/csrc/jit/script/python_sugared_value.cpp +++ b/torch/csrc/jit/script/python_sugared_value.cpp @@ -339,21 +339,6 @@ std::shared_ptr ModuleValue::attr( const SourceRange& loc, Function& m, const std::string& field) { - // workaround to make self.training work - // it adds a buffer 'training' to the model if one doesn't exist - // and then loads that parameter, casting it to bool - if (field == "training") { - c10::optional v = module_.find_attribute(field); - if (!v) { - bool training = py::cast(py::getattr(py_module_, "training")); - module_.register_attribute( - "training", BoolType::get(), std::move(training)); - v = module_.find_attribute(field); - } - Value* the_bool = m.graph()->insertGetAttr(self_, "training"); - return std::make_shared(the_bool); - } - if (auto v = module_.find_module(field)) { return std::make_shared( m.graph()->insertGetAttr(self_, field), @@ -503,19 +488,25 @@ std::shared_ptr BooleanDispatchValue::call( auto index = py::cast(dispatched_fn_["index"]); auto arg_name = py::str(dispatched_fn_["arg_name"]); + ErrorReport error(loc); if 
(index < inputs.size()) { // Dispatch flag is in arg list result = constant_as(inputs.at(index).value(graph)); + error << "Argument for boolean dispatch at position " << index + << " was not constant"; } else if (auto i = findInputWithName(arg_name, attributes)) { // Dispatch flag is in kwargs result = constant_as(attributes[*i].value(graph)); + error << "Keyword argument '" << arg_name + << "' for boolean dispatch at position was not constant"; } else { // Didn't find dispatch flag, so use default value result = py::cast(dispatched_fn_["default"]); + TORCH_INTERNAL_ASSERT(result); } if (!result) { - throw ErrorReport(loc) << "value for boolean dispatch was not constant"; + throw error; } std::shared_ptr value; @@ -578,10 +569,10 @@ std::shared_ptr toSugaredValue( } else if (py::isinstance(obj)) { return std::make_shared(obj); } else if (obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr()) { - return std::make_shared(); + return SpecialFormValue::create(prim::fork); } else if ( obj.ptr() == py::module::import("torch.jit").attr("annotate").ptr()) { - return std::make_shared(); + return SpecialFormValue::create(prim::annotate); } else if (auto callee = as_module(obj)) { throw ErrorReport(loc) << "Cannot call a ScriptModule that is not" << " a submodule of the caller"; diff --git a/torch/csrc/jit/script/slot.h b/torch/csrc/jit/script/slot.h index def1710492e9c..df3c71fb676e4 100644 --- a/torch/csrc/jit/script/slot.h +++ b/torch/csrc/jit/script/slot.h @@ -47,6 +47,9 @@ struct TORCH_API Slot { bool is_module() const { return entity_type() == EntityType::MODULE; } + bool is_parameter() const { + return container_->type()->is_parameter(offset_); + } Module to_module() const; bool operator==(const Slot& rhs) const { diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index fb7c16a18efcf..a07dd95061076 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -87,51 +87,53 @@ std::shared_ptr SimpleValue::attr( return std::make_shared(r); } } - if (value_->type()->isSubtypeOf(NumberType::get())) { - throw ErrorReport(loc) << "Cannot call methods on numbers"; - } + + // accessing fields of named tuples if (auto tuple_type = value_->type()->cast()) { - if (!tuple_type->schema()) { - throw ErrorReport(loc) << "Getting attributes of tuples is not supported"; - } - auto attrs = tuple_type->schema()->arguments(); - for (size_t i = 0; i < attrs.size(); i++) { - if (attrs[i].name() == field) { - auto idx = m.graph()->insertConstant(IValue(static_cast(i))); - auto out_type = tuple_type->elements().at(i); - auto r = - m.graph() - ->insertNode(m.graph()->createTupleIndex(value_, idx, out_type)) - ->output(); - return std::make_shared(r); + if (tuple_type->schema()) { + auto attrs = tuple_type->schema()->arguments(); + for (size_t i = 0; i < attrs.size(); i++) { + if (attrs[i].name() == field) { + auto idx = m.graph()->insertConstant(IValue(static_cast(i))); + auto out_type = tuple_type->elements().at(i); + auto r = m.graph() + ->insertNode( + m.graph()->createTupleIndex(value_, idx, out_type)) + ->output(); + return std::make_shared(r); + } } } - throw ErrorReport(loc) << "Unknown attribute to named tuple"; - } - - if (auto classType = value_->type()->cast()) { + } else if (auto classType = value_->type()->cast()) { // This is a class, emit the proper attribute lookup if (auto method = classType->getMethod(field)) { return std::make_shared(getValue(), field); } - if (!classType->hasAttribute(field)) { - throw 
ErrorReport(loc) - << "Tried to access nonexistent attribute " << field - << ". Did you forget to initialize it in __init__()?"; + if (classType->hasAttribute(field)) { + auto& g = *m.graph(); + auto n = g.insertNode(g.createGetAttr(value_, field)); + return std::make_shared(n->output()); } - auto& g = *m.graph(); - auto n = g.insertNode(g.createGetAttr(value_, field)); - return std::make_shared(n->output()); - } - - if (auto iface = value_->type()->cast()) { + } else if (auto iface = value_->type()->cast()) { + // accessing methods of interfaces if (auto schema = iface->getMethod(field)) { return std::make_shared(getValue(), field); } } - return std::make_shared( - Symbol::aten(field), NamedValue(loc, "self", value_)); + // none of the more-specific cases worked, so see if this is a builtin method + if (auto builtin = BuiltinFunction::tryCreate( + Symbol::aten(field), NamedValue(loc, "self", value_))) { + return builtin; + } + + ErrorReport report(loc); + report << "Tried to access nonexistent attribute or method '" << field + << "' of type '" << value_->type()->python_str() << "'."; + if (value_->type()->kind() == ClassType::Kind) { + report << " Did you forget to initialize an attribute in __init__()?"; + } + throw report; } std::vector> SimpleValue::asTuple( @@ -474,7 +476,7 @@ std::shared_ptr ClassValue::attr( if (field != "__new__") { throw ErrorReport(loc) << "Tried to lookup unknown attribute on class"; } - return std::make_shared(type_); + return SpecialFormValue::create(prim::CreateObject); } std::shared_ptr NamedTupleConstructor::call( @@ -501,6 +503,31 @@ std::shared_ptr NamedTupleConstructor::call( return std::make_shared(self); } +std::shared_ptr BuiltinFunction::tryCreate( + Symbol symbol, + c10::optional self) { + for (const std::shared_ptr& op : getAllOperatorsFor(symbol)) { + if (!self) { + return std::make_shared(symbol, nullptr); + } + if (auto index = op->schema().argumentIndexWithName("self")) { + std::unordered_map type_env; + TypePtr formal_type = op->schema().arguments().at(*index).type(); + const MatchTypeReturn matched = + matchTypeVariables(formal_type, self->type(), type_env); + if (!matched.success()) { + continue; + } + const auto concrete_type = tryEvalTypeVariables(formal_type, type_env); + if (!concrete_type || !self->type()->isSubtypeOf(concrete_type)) { + continue; + } + return std::make_shared(symbol, self); + } + } + return nullptr; +} + } // namespace script } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/script/sugared_value.h b/torch/csrc/jit/script/sugared_value.h index 2e1fbe92cf0dd..edf543c02f2ec 100644 --- a/torch/csrc/jit/script/sugared_value.h +++ b/torch/csrc/jit/script/sugared_value.h @@ -168,6 +168,13 @@ struct TORCH_API BuiltinFunction : public SugaredValue { at::ArrayRef attributes, at::ArrayRef inputs, size_t n_binders) override; + + // try to create this builtin but if it doesn't exist or the self argument + // cannot possibly match, then return nullptr. 
Use in situations where it is + // not clear if it is a valid builtin + static std::shared_ptr tryCreate( + Symbol symbol, + c10::optional self); }; struct TORCH_API BuiltinModule : public SugaredValue { @@ -283,7 +290,7 @@ struct TORCH_API ClosureValue : public SugaredValue { Value* value_; }; -// defines how a method obtained from a module behaves in script +// defines how a method obtained from a module/class/interface behaves in script struct MethodValue : public SugaredValue { MethodValue(Value* self, std::string method_name) : self_(std::move(self)), method_name_(std::move(method_name)) {} @@ -386,53 +393,27 @@ struct TORCH_API MagicMethod : public SugaredValue { std::string desugared_name_; }; -// These SugaredValues have special handling in the compiler because they -// change the normal evalution order of the expression they participate in. -// They are exposed here so that the python frontend can inject them -// when it sees the equivalent thing in python - -struct TORCH_API ForkValue : public SugaredValue { - ForkValue() = default; - std::string kind() const override { - return "fork"; - } -}; -struct TORCH_API AnnotateValue : public SugaredValue { - AnnotateValue() = default; - std::string kind() const override { - return "annotate"; - } -}; - -struct TORCH_API UninitializedValue : public SugaredValue { - UninitializedValue() = default; +// things that look like function applications, but +// perform non-standard evaluation are represented +// with SpecialFormValues, e.g. +// isinstance(x, int) +// fork(fn) +// annotate(int, 3) +// The implementation of each value is handled by a case inside emitApplyExpr +struct TORCH_API SpecialFormValue : public SugaredValue { + SpecialFormValue(Symbol form) : form_(form) {} std::string kind() const override { - return "uninitialized"; + return form_.toUnqualString(); } -}; - -// matched against for special handling of getattr expressions -struct TORCH_API GetAttrValue : SugaredValue { - GetAttrValue() = default; - std::string kind() const override { - return "getattr"; + Symbol form() const { + return form_; } -}; - -// matched against for special handling of isinstance expressions -struct TORCH_API IsInstanceValue : SugaredValue { - IsInstanceValue() = default; - std::string kind() const override { - return "isinstance"; + static std::shared_ptr create(Symbol form) { + return std::make_shared(form); } -}; -// matched against for special handling of tuple() call -struct TORCH_API TupleCallValue : SugaredValue { - TupleCallValue() = default; - std::string kind() const override { - return "tuple"; - } + private: + Symbol form_; }; // matched against for special handling of range expressions @@ -455,15 +436,6 @@ struct TORCH_API RangeValue : SugaredValue { bool has_only_end_; }; -// matched against for special handling of iterables like zip(), enumerate() -struct TORCH_API IterableValue : SugaredValue { - IterableValue(Symbol symbol) : symbol_(symbol) {} - std::string kind() const override { - return "iterable"; - } - Symbol symbol_; -}; - // Specialized Tree structure to matched against for special handling // of builtin functions iterables expressions like zip(), enumerate(), etc. // zip and enumerate can be modeled as a tree of SimpleValue/RangeValue: @@ -501,28 +473,6 @@ struct TORCH_API IterableTree : SugaredValue { std::vector children_; }; -// This represents the "__new__" method on classes, which can't be a MethodValue -// because it takes a ClassValue as input. 
-// So if we see: -// Foo.__new__(Foo) -// Foo is a ClassValue, calling `attr("__new__")` will return a ClassNewMethod. -struct TORCH_API ClassNewMethod : public SugaredValue { - ClassNewMethod(ClassTypePtr type) : type_(type) {} - std::string kind() const override { - return "class.__new__"; - } - - std::shared_ptr createObject( - const SourceRange& loc, - Function& m) { - auto& g = *m.graph(); - auto createNode = g.insertNode(g.createObject(type_)); - return std::make_shared(createNode->output()); - } - - ClassTypePtr type_; -}; - static inline std::vector toValues( Graph& g, at::ArrayRef nvs) { diff --git a/torch/csrc/jit/source_range.h b/torch/csrc/jit/source_range.h index 5a17bd61a4dab..28ccae3c9a8b5 100644 --- a/torch/csrc/jit/source_range.h +++ b/torch/csrc/jit/source_range.h @@ -112,7 +112,7 @@ struct CAFFE2_API SourceRange { size_t size() const { return end() - start(); } - static const size_t CONTEXT = 10; + static const size_t CONTEXT = 3; void highlight(std::ostream& out) const; const std::shared_ptr& source() const { return source_; diff --git a/torch/csrc/byte_order.cpp b/torch/csrc/utils/byte_order.cpp similarity index 91% rename from torch/csrc/byte_order.cpp rename to torch/csrc/utils/byte_order.cpp index cf347e6a4e8c3..fbc2763c584a8 100644 --- a/torch/csrc/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -6,6 +6,8 @@ #include #endif +namespace { + static inline void swapBytes16(void *ptr) { uint16_t output; @@ -99,6 +101,11 @@ static inline uint64_t decodeUInt64BE(const uint8_t *data) { return output; } +} // anonymous namespace + +namespace torch { +namespace utils { + THPByteOrder THP_nativeByteOrder() { uint32_t x = 1; @@ -108,7 +115,8 @@ THPByteOrder THP_nativeByteOrder() void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len) { for (size_t i = 0; i < len; i++) { - dst[i] = (int16_t) (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); + dst[i] = (int16_t)( + order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); src += sizeof(int16_t); } } @@ -116,7 +124,8 @@ void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len) { for (size_t i = 0; i < len; i++) { - dst[i] = (int32_t) (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src)); + dst[i] = (int32_t)( + order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src)); src += sizeof(int32_t); } } @@ -124,7 +133,8 @@ void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len) { for (size_t i = 0; i < len; i++) { - dst[i] = (int64_t) (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src)); + dst[i] = (int64_t)( + order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src)); src += sizeof(int64_t); } } @@ -143,7 +153,8 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s void THP_decodeBFloat16Buffer(at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len) { for (size_t i = 0; i < len; i++) { - uint16_t x = (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src)); + uint16_t x = + (order == THP_BIG_ENDIAN ? 
decodeUInt16BE(src) : decodeUInt16LE(src)); std::memcpy(&dst[i], &x, sizeof(dst[i])); src += sizeof(uint16_t); } @@ -232,3 +243,6 @@ void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, } } } + +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h new file mode 100644 index 0000000000000..cc3995c72eeea --- /dev/null +++ b/torch/csrc/utils/byte_order.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace utils { + +enum THPByteOrder { + THP_LITTLE_ENDIAN = 0, + THP_BIG_ENDIAN = 1 +}; + +TORCH_API THPByteOrder THP_nativeByteOrder(); + +TORCH_API void THP_decodeInt16Buffer( + int16_t* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeInt32Buffer( + int32_t* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeInt64Buffer( + int64_t* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeHalfBuffer( + THHalf* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeFloatBuffer( + float* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeDoubleBuffer( + double* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeBoolBuffer( + bool* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_decodeBFloat16Buffer( + at::BFloat16* dst, + const uint8_t* src, + THPByteOrder order, + size_t len); + +TORCH_API void THP_encodeInt16Buffer( + uint8_t* dst, + const int16_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_encodeInt32Buffer( + uint8_t* dst, + const int32_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_encodeInt64Buffer( + uint8_t* dst, + const int64_t* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_encodeFloatBuffer( + uint8_t* dst, + const float* src, + THPByteOrder order, + size_t len); +TORCH_API void THP_encodeDoubleBuffer( + uint8_t* dst, + const double* src, + THPByteOrder order, + size_t len); + +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 08124c1a3411e..c2982468cf703 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -671,19 +671,4 @@ Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* throw std::runtime_error("new_ones(): invalid arguments"); } -Tensor new_zeros(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { - static PythonArgParser parser({ - "new_zeros(IntArrayRef size, *, ScalarType dtype=None, Device? 
device=None, bool requires_grad=False)", - }, /*traceable=*/true); - - ParsedArgs<4> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - if (r.idx == 0) { - const auto actual_type_id = typeIdWithDefault(r, 2, type_id); - const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type); - return dispatch_zeros(actual_type_id, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); - } - throw std::runtime_error("new_zeros(): invalid arguments"); -} - }} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index 76a9386986086..1fa7d75c131ac 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -18,6 +18,5 @@ at::Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, Py at::Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); at::Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); at::Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor new_zeros(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); }} // namespace torch::utils diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 987b529cdee15..10f745b6b13cc 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -54,4 +54,4 @@ def init_model_parallel(self_name, """ _init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads) from .rpc_api import _agent - autograd._init(_agent.get_worker_id().id) + autograd._init(_agent.get_worker_info().id) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index f4b455b9daf7c..098d9d5b554a5 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -129,7 +129,8 @@ class GroupMember(object): _default_pg_init_method = None # Default process group wide timeout, if applicable. -# This currently only applies to the gloo backend. To make an attempt at +# This only applies to the gloo and nccl backends +# (only if NCCL_BLOCKING_WAIT is set to 1). To make an attempt at # backwards compatibility with THD, we use an extraordinarily high default # timeout, given that THD did not have timeouts. _default_pg_timeout = timedelta(minutes=30) @@ -346,7 +347,9 @@ def init_process_group(backend, Mutually exclusive with ``init_method``. timeout (timedelta, optional): Timeout for operations executed against the process group. Default value equals 30 minutes. - This is only applicable for the ``gloo`` backend. + This is applicable for the ``gloo`` backend. For ``nccl``, this is + applicable only if the environment variable ``NCCL_BLOCKING_WAIT`` + is set to 1. group_name (str, optional, deprecated): Group name. 
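The updated timeout documentation covers both backends: with gloo the timeout always applies, while with nccl the value (now forwarded to ProcessGroupNCCL further down in this diff) only takes effect when NCCL_BLOCKING_WAIT is set to 1. A usage sketch, not part of the patch, assuming the usual env:// rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are set:

import os
from datetime import timedelta
import torch.distributed as dist

os.environ["NCCL_BLOCKING_WAIT"] = "1"   # required for the nccl backend to honor the timeout
dist.init_process_group(
    backend="nccl",
    init_method="env://",                # rank and world size read from the environment
    timeout=timedelta(seconds=60),       # forwarded to ProcessGroupNCCL by this change
)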
To enable ``backend == Backend.MPI``, PyTorch needs to built from source @@ -485,7 +488,8 @@ def _new_process_group_helper(world_size, pg = ProcessGroupNCCL( prefix_store, rank, - world_size) + world_size, + timeout) _pg_map[pg] = (Backend.NCCL, store) _pg_names[pg] = group_name else: diff --git a/torch/distributed/internal_rpc_utils.py b/torch/distributed/internal_rpc_utils.py index 6475366e471b0..99834a929701e 100644 --- a/torch/distributed/internal_rpc_utils.py +++ b/torch/distributed/internal_rpc_utils.py @@ -4,7 +4,6 @@ import copyreg import io import pickle -import six import threading import traceback @@ -26,10 +25,10 @@ class _InternalRPCPickler: e.g. attach tensor to distributed autograd graph in C++ """ def __init__(self): - # python2 does not have dispatch_table, add "if six.PY3" condition, + # python2 does not have dispatch_table, add "if torch._six.PY3" condition, # as _InternalRPCPickler still got build in python2 even # we skipped python 2 tests for rpc_test - if six.PY3: + if torch._six.PY3: self._dispatch_table = copyreg.dispatch_table.copy() self._dispatch_table[torch.Tensor] = self._tensor_reducer @@ -100,6 +99,8 @@ def deserialize(self, binary_data, tensor_table): # Create _internal_rpc_pickler only once to initialize _dispatch_table only once _internal_rpc_pickler = _InternalRPCPickler() +def serialize(obj): + return _internal_rpc_pickler.serialize(obj) def run_python_udf_internal(pickled_python_udf, tensors): r""" @@ -114,7 +115,8 @@ def run_python_udf_internal(pickled_python_udf, tensors): # except str = exception info + traceback string except_str = "{}\n{}".format(repr(e), traceback.format_exc()) result = RemoteException(except_str) - return _internal_rpc_pickler.serialize(result) + # return _internal_rpc_pickler.serialize(result) + return result def load_python_udf_result_internal(pickled_python_result, tensors): diff --git a/torch/distributed/rpc_api.py b/torch/distributed/rpc_api.py index 8b033229a76fb..6c391490ae074 100644 --- a/torch/distributed/rpc_api.py +++ b/torch/distributed/rpc_api.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 -from . import invoke_rpc_builtin, invoke_rpc_python_udf, invoke_remote_builtin -from . import init_rref_context +from . import invoke_rpc_builtin, invoke_rpc_python_udf +from . import invoke_remote_builtin, invoke_remote_python_udf +from . import _init_rref_context, _destroy_rref_context from . import ProcessGroupAgent -from . import WorkerId -from .rpc_backend_registry import is_rpc_backend_registered, init_rpc_backend +from . import WorkerInfo from .internal_rpc_utils import _internal_rpc_pickler, PythonUDF +from .rpc_backend_registry import is_rpc_backend_registered, init_rpc_backend import functools import sys @@ -38,6 +39,7 @@ def join_rpc(): if _agent: _agent.join() _agent = None + _destroy_rref_context() @_require_initialized @@ -78,40 +80,41 @@ def _init_rpc(backend=RpcBackend.PROCESS_GROUP, self_rank, group.rank())) # TODO: add try-except and destroy _agent in all processes if any fails. 
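The new module-level serialize() helper in internal_rpc_utils is a thin wrapper over the shared _InternalRPCPickler instance; it returns the pickled payload together with a tensor table so the agent can send tensors out of band. A rough illustration, not part of the patch, with module paths as they stand in this tree:

import torch
from torch.distributed.internal_rpc_utils import serialize

# Any picklable object works here; rpc/remote wrap user calls in a PythonUDF
# namedtuple before serializing them through the same pickler.
binary_data, tensor_table = serialize((torch.ones(2), 3))
print(len(binary_data), len(tensor_table))  # pickle bytes plus the extracted tensors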
_agent = ProcessGroupAgent(self_name, group, num_send_recv_threads) - init_rref_context(_agent) + _init_rref_context(_agent) elif is_rpc_backend_registered(backend): _agent = init_rpc_backend( backend, self_rank=self_rank, self_name=self_name, - init_method=init_method, + init_method=init_method ) - init_rref_context(_agent) + _init_rref_context(_agent) else: raise RuntimeError("Unrecognized RPC backend ", backend) @_require_initialized -def get_worker_id(worker_name=None): +def get_worker_info(worker_name=None): r""" - Get worker id of a given worker name. Use this worker id to avoid passing - an expensive string to ``rpc`` on every invocation. + Get WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing + an expensive string to ``rpc`` on every invocation. The WorkerInfo contains + the name of the worker and the id of the worker. Arguments: worker_name (str): the string name of a worker. If ``None``, return the the id of the current worker. (default ``None``) """ if worker_name: - return _agent.get_worker_id(worker_name) + return _agent.get_worker_info(worker_name) else: - return _agent.get_worker_id() + return _agent.get_worker_info() -def _to_worker_id(name_or_id): - if isinstance(name_or_id, WorkerId): +def _to_worker_info(name_or_id): + if isinstance(name_or_id, WorkerInfo): return name_or_id elif isinstance(name_or_id, str): - return get_worker_id(name_or_id) + return get_worker_info(name_or_id) else: raise ValueError("Unsupported RPC worker ID type {}".format(name_or_id)) @@ -142,7 +145,7 @@ def remote(to, func, args=None, kwargs=None): >>> import torch.distributed as dist >>> dist.init_process_group(backend='gloo', rank=0, world_size=2) >>> dist.init_rpc("worker0") - >>> worker1 = dist.get_worker_id("worker1") + >>> worker1 = dist.get_worker_info("worker1") >>> rref1 = dist.remote(worker1, torch.add, args=(torch.ones(2), 3)) >>> rref2 = dist.remote(worker1, torch.add, args=(torch.ones(2), 1)) >>> x = rref1.to_here() + rref2.to_here() @@ -159,8 +162,15 @@ def remote(to, func, args=None, kwargs=None): args = args if args else () kwargs = kwargs if kwargs else {} - return invoke_remote_builtin( - _agent, _to_worker_id(to), qualified_name, *args, **kwargs) + info = _to_worker_info(to) + if qualified_name is not None: + return invoke_remote_builtin( + _agent, info, qualified_name, *args, **kwargs) + else: + (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize( + PythonUDF(func, args, kwargs)) + return invoke_remote_python_udf( + _agent, info, pickled_python_udf, tensors) def _invoke_rpc(to, func, args=None, kwargs=None): @@ -172,15 +182,16 @@ def _invoke_rpc(to, func, args=None, kwargs=None): args = args if args else () kwargs = kwargs if kwargs else {} + info = _to_worker_info(to) if qualified_name is not None: fut = invoke_rpc_builtin( - _agent, _to_worker_id(to), qualified_name, *args, **kwargs + _agent, info, qualified_name, *args, **kwargs ) else: (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize( PythonUDF(func, args, kwargs)) fut = invoke_rpc_python_udf( - _agent, _to_worker_id(to), pickled_python_udf, tensors) + _agent, info, pickled_python_udf, tensors) return fut @@ -314,7 +325,7 @@ def rpc(to, func, args=None, kwargs=None, async_call=False): >>> import torch.distributed as dist >>> dist.init_process_group(backend='gloo', rank=0, world_size=2) >>> dist.init_model_parallel("worker0") - >>> worker1 = dist.get_worker_id("worker1") + >>> worker1 = dist.get_worker_info("worker1") >>> fut1 = dist.rpc(worker1, torch.add, args=(torch.ones(2), 
3), async_call=True) >>> fut2 = dist.rpc(worker1, min, args=(1, 2), async_call=True) >>> result = fut1.wait() + fut2.wait() @@ -330,6 +341,7 @@ def rpc(to, func, args=None, kwargs=None, async_call=False): """dist.rpc is deprecated. Use dist.rpc_async for asynchronous calls or dist.rpc_sync for synchronous calls instead.""" ) + if async_call: return rpc_async(to, func, args, kwargs) else: diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index f69cb858f4afa..72ee285e7bee8 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -14,6 +14,7 @@ from torch._six import PY2, PY37, with_metaclass, string_classes from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \ _list_with_default +from torch.utils import set_module import collections import contextlib @@ -65,6 +66,7 @@ def _parse_env(name, default, true_message, false_message): _python_cu = torch._C.CompilationUnit() Future = torch._C.Future +set_module(Future, "torch.jit") _fork = torch._C.fork _wait = torch._C.wait @@ -724,14 +726,14 @@ def trace(func, _module_class=None, _compilation_unit=_python_cu): """ - Trace a function and return an executable ``ScriptModule`` or ``torch._C.Function`` + Trace a function and return an executable ``ScriptModule`` or ``torch.jit.ScriptFunction`` that will be optimized using just-in-time compilation. Using ``torch.jit.trace`` and :func:`torch.jit.trace_module`, you can turn an existing module or Python - function into a TorchScript ``torch._C.Function`` or ``ScriptModule``. You must provide example inputs, + function into a TorchScript ``torch.jit.ScriptFunction`` or ``ScriptModule``. You must provide example inputs, and we run the function, recording the operations performed on all the tensors. - * The resulting recording of a standalone function produces ``torch._C.Function``. + * The resulting recording of a standalone function produces ``torch.jit.ScriptFunction``. * The resulting recording of ``forward`` function of ``nn.Module`` or ``nn.Module`` produces ``ScriptModule``. This module also contains any parameters that the original @@ -799,7 +801,7 @@ def trace(func, a ``ScriptModule`` object with a single ``forward()`` method containing the traced code. The returned ``ScriptModule`` will have the same set of sub-modules and parameters as the original ``nn.Module``. - If ``callable`` is a standalone function, ``trace`` returns ``torch._C.Function`` + If ``callable`` is a standalone function, ``trace`` returns ``torch.jit.ScriptFunction`` Example (tracing a function): @@ -1080,7 +1082,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): r""" Scripting a function or ``nn.Module`` will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a ``ScriptModule`` or - ``torch._C.Function``. TorchScript itself is a subset of the Python language, so not all + ``torch.jit.ScriptFunction``. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations. For a complete guide, see the `TorchScript Language Reference`_. @@ -1089,7 +1091,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): ``@torch.jit.script`` for `TorchScript Classes `_ and functions. **Scripting a function** - The ``@torch.jit.script`` decorator will construct a ``torch._C.Function`` + The ``@torch.jit.script`` decorator will construct a ``torch.jit.ScriptFunction`` by compiling the body of the function. 
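Because the pybind class is now exposed as torch.jit.ScriptFunction rather than torch._C.Function, compiled standalone functions report the public type. A small check, illustrative only and assuming a build that includes this rename:

import torch

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

assert isinstance(add_one, torch.jit.ScriptFunction)
print(type(add_one))        # reported as torch.jit.ScriptFunction after set_module above
print(add_one.graph)        # TorchScript IR for the compiled function
add_one.save("add_one.pt")  # goes through the placeholder-module path patched in init.cpp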
Example (scripting a function): @@ -1255,6 +1257,7 @@ def interface(obj): ast = get_jit_class_def(obj, obj.__name__) rcb = _jit_internal.createResolutionCallback(1) torch._C._jit_script_interface_compile(qualified_name, ast, rcb) + obj.__torch_script_interface__ = True return obj ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method')) @@ -1512,8 +1515,15 @@ def __init__(self, optimize=None, _qualified_name=None, _compilation_unit=None, else: self.__dict__['_c'] = torch._C.ScriptModule(_qualified_name, _compilation_unit, True) - Module._Module__construct(self) - Module.__setattr__(self, "training", True) + training = None + if self._c._has_attribute('training'): + training = self._c._get_attribute('training') + super(ScriptModule, self).__init__() + if training is not None: + self.training = training + self._c._register_attribute('training', torch._C.BoolType.get(), training) + elif not self._c._has_attribute('training'): + self._c._register_attribute('training', torch._C.BoolType.get(), self.training) self._parameters = OrderedParameterDict(self._c) self._buffers = OrderedBufferDict(self._c) @@ -1579,11 +1589,6 @@ def __getattr__(self, attr): def __setattr__(self, attr, value): if attr not in self._constants_set: - if attr == 'training': - if self._c._has_attribute('training'): - self.__dict__['training'] = value - self._c._set_attribute('training', value) - return if isinstance(value, Attribute): the_type = torch.jit.annotations.ann_to_type(value.type) try: @@ -1592,6 +1597,8 @@ def __setattr__(self, attr, value): raise RuntimeError("Could not register attribute '{}' of type '{}' for a value of type '{}'" .format(attr, value.type, type(value.value))) return + if self._c._has_attribute(attr): + self._c._set_attribute(attr, value) return super(ScriptModule, self).__setattr__(attr, value) if hasattr(self, attr): @@ -2058,6 +2065,10 @@ def _check_directly_compile_overloaded(obj): # torch.jit.Error Error = torch._C.JITException +set_module(Error, "torch.jit") +# This is not perfect but works in common cases +Error.__name__ = "Error" +Error.__qualname__ = "Error" def _get_named_tuple_properties(obj): assert issubclass(obj, tuple) and hasattr(obj, '_fields') @@ -2105,8 +2116,9 @@ def _graph_for(self, *args, **kwargs): return last_executed_optimized_graph() torch._C.ScriptMethod.graph_for = _graph_for -torch._C.Function.graph_for = _graph_for -Function = torch._C.Function +torch._C.ScriptFunction.graph_for = _graph_for +ScriptFunction = torch._C.ScriptFunction +set_module(ScriptFunction, "torch.jit") if not torch._C._jit_init(): raise RuntimeError("JIT initialization failed") diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index e888137e192ec..68d3b3a9f5bd1 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -58,10 +58,6 @@ def copy_to_script_module(original, stubs): # the type if possible class_annotations = getattr(original, '__annotations__', {}) for name in dir(original): - if name in ("training", "__dict__"): - # TODO: removing this skip should let us remove the code to add training as an - # attribute in python_sugared_value.cpp - continue if hasattr(script_module, name): # Don't re-copy properties continue @@ -70,6 +66,7 @@ def copy_to_script_module(original, stubs): the_type = torch.jit.annotations.ann_to_type(class_annotations[name]) else: the_type = torch._C._jit_try_infer_type(item) + if the_type is not None: try: script_module._c._register_attribute(name, the_type, item) diff --git 
a/torch/jit/annotations.py b/torch/jit/annotations.py index 0a5283ed4cf0f..8096f5150e6c1 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -7,7 +7,7 @@ BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \ is_optional, _qualified_name from torch._C import TensorType, TupleType, FloatType, IntType, \ - ListType, StringType, DictType, BoolType, OptionalType, ClassType + ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType from textwrap import dedent from torch._six import builtins from torch._utils_internal import get_source_lines_and_file @@ -249,6 +249,8 @@ def ann_to_type(ann, resolver=None): return BoolType.get() elif hasattr(ann, "__torch_script_class__"): return ClassType(_qualified_name(ann)) + elif hasattr(ann, "__torch_script_interface__"): + return InterfaceType(_qualified_name(ann)) elif resolver is not None: # Maybe resolve a NamedTuple to a Tuple Type rcb, loc = resolver diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt index 2557d233cd849..497bdcb5d29a3 100644 --- a/torch/lib/c10d/CMakeLists.txt +++ b/torch/lib/c10d/CMakeLists.txt @@ -16,6 +16,10 @@ else() message(STATUS "Building c10d without CUDA/ROCm support") endif() +if(USE_TBB) +include_directories(${TBB_ROOT_DIR}/include) +endif() + if(USE_GLOO) option(USE_C10D_GLOO "USE C10D GLOO" ON) endif() @@ -52,7 +56,7 @@ set(C10D_SRCS set(C10D_LIBS torch) if(USE_C10D_NCCL) - list(APPEND C10D_SRCS ProcessGroupNCCL.cpp) + list(APPEND C10D_SRCS ProcessGroupNCCL.cpp NCCLUtils.cpp) list(APPEND C10D_LIBS __caffe2_nccl) endif() diff --git a/torch/lib/c10d/NCCLUtils.cpp b/torch/lib/c10d/NCCLUtils.cpp new file mode 100644 index 0000000000000..c00d383ae69d0 --- /dev/null +++ b/torch/lib/c10d/NCCLUtils.cpp @@ -0,0 +1,31 @@ +#include +#include + +namespace c10d { + +std::string getNcclVersion() { + static std::once_flag ncclGetVersionFlag; + static std::string versionString; + + std::call_once(ncclGetVersionFlag, []() { + int version; + ncclResult_t status = ncclGetVersion(&version); + if (status != ncclSuccess) { + versionString = "Unknown NCCL version"; + } + auto ncclMajor = version / 1000; + auto ncclMinor = (version % 1000) / 100; + auto ncclPatch = version % (ncclMajor * 1000 + ncclMinor * 100); + versionString = std::to_string(ncclMajor) + "." + + std::to_string(ncclMinor) + "." + std::to_string(ncclPatch); + }); + + return versionString; +} + +std::string ncclGetErrorWithVersion(ncclResult_t error) { + return std::string(ncclGetErrorString(error)) + ", NCCL version " + + getNcclVersion(); +} + +} // namespace c10d diff --git a/torch/lib/c10d/NCCLUtils.hpp b/torch/lib/c10d/NCCLUtils.hpp index 4ab1f5757691f..395b8cab49d62 100644 --- a/torch/lib/c10d/NCCLUtils.hpp +++ b/torch/lib/c10d/NCCLUtils.hpp @@ -1,28 +1,32 @@ #pragma once -#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ - (NCCL_MINOR < 4) -#error "Need NCCL version 2.4+" -#elif defined(NCCL_MAJOR) && (NCCL_MAJOR < 2) -#error "Need NCCL version 2.4+" -#endif - #include #include -#define C10D_NCCL_CHECK(cmd) \ - do { \ - ncclResult_t error = cmd; \ - if (error != ncclSuccess) { \ - std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ - std::to_string(__LINE__) + ", " + \ - std::string(ncclGetErrorString(error)); \ - throw std::runtime_error(err); \ - } \ +// Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort() +// and ncclCommGetAsyncError() are not supported in earlier versions. 
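getNcclVersion() above decodes the packed integer returned by ncclGetVersion(); the same arithmetic in Python, as a quick sanity sketch (the sample value is hypothetical):

def decode_nccl_version(version: int) -> str:
    # mirrors the integer math in NCCLUtils.cpp, e.g. 2408 -> "2.4.8"
    major = version // 1000
    minor = (version % 1000) // 100
    patch = version % (major * 1000 + minor * 100)
    return "{}.{}.{}".format(major, minor, patch)

assert decode_nccl_version(2408) == "2.4.8"
# kNcclErrorHandlingVersion = 2400 in the tests encodes the same 2.4 floor that
# gates ncclCommAbort / ncclCommGetAsyncError support.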
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ + (NCCL_MINOR >= 4) +#define ENABLE_NCCL_ERROR_CHECKING +#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) +#define ENABLE_NCCL_ERROR_CHECKING +#endif + +#define C10D_NCCL_CHECK(cmd) \ + do { \ + ncclResult_t error = cmd; \ + if (error != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(error); \ + throw std::runtime_error(err); \ + } \ } while (0) namespace c10d { +std::string getNcclVersion(); +std::string ncclGetErrorWithVersion(ncclResult_t error); + // RAII wrapper for NCCL communicator class NCCLComm { public: @@ -33,9 +37,14 @@ class NCCLComm { ~NCCLComm() noexcept(false) { if (ncclComm_ && !aborted_) { - // Use ncclCommAbort instead of ncclCommDestroy here since ncclCommDestroy - // could block forever waiting for work to complete on the communicator. +#ifdef ENABLE_NCCL_ERROR_CHECKING + // Use ncclCommAbort instead of ncclCommDestroy here since + // ncclCommDestroy could block forever waiting for work to complete on + // the communicator. ncclCommAbort(); +#else + C10D_NCCL_CHECK(::ncclCommDestroy(ncclComm_)); +#endif } } @@ -76,6 +85,7 @@ class NCCLComm { } void ncclCommAbort() { +#ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_) { // Should not abort twice. return; @@ -89,6 +99,10 @@ class NCCLComm { if (ncclAsyncErr_ == ncclSuccess) { ncclAsyncErr_ = ncclSystemError; } +#else + // This is a NOOP, if error checks are disabled. + return; +#endif } bool isAborted() const { @@ -96,11 +110,16 @@ class NCCLComm { } ncclResult_t checkForNcclError() { +#ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; } C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_)); return ncclAsyncErr_; +#else + // Always return success, if error checks are disabled. 
+ return ncclSuccess; +#endif } protected: diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 85ea14c14e5c2..f26b42b57c978 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -246,14 +246,18 @@ ProcessGroupNCCL::ProcessGroupNCCL( std::string(NCCL_BLOCKING_WAIT)); } +#ifdef ENABLE_NCCL_ERROR_CHECKING ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); +#endif } ProcessGroupNCCL::~ProcessGroupNCCL() { terminateWatchdog_.store(true); watchdogCV_.notify_one(); +#ifdef ENABLE_NCCL_ERROR_CHECKING ncclCommWatchdogThread_.join(); +#endif } void ProcessGroupNCCL::ncclCommWatchdog() { @@ -314,7 +318,7 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); if (ncclAsyncErr != ncclSuccess) { return std::make_exception_ptr(std::runtime_error( - "NCCL error: " + std::string(ncclGetErrorString(ncclAsyncErr)))); + "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr))); } } diff --git a/torch/lib/c10d/test/CMakeLists.txt b/torch/lib/c10d/test/CMakeLists.txt index 5ed12a4ef3f99..ff34bc952887f 100644 --- a/torch/lib/c10d/test/CMakeLists.txt +++ b/torch/lib/c10d/test/CMakeLists.txt @@ -23,7 +23,8 @@ if(USE_CUDA) endif() if(USE_C10D_NCCL) c10d_add_test(ProcessGroupNCCLTest.cpp c10d c10d_cuda_test) - c10d_add_test(ProcessGroupNCCLErrorsTest.cpp c10d c10d_cuda_test gtest_main) + c10d_add_test(ProcessGroupNCCLErrorsTest.cpp c10d c10d_cuda_test + gtest_main) endif() else() if(USE_C10D_GLOO) diff --git a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp index 4541111bbfd2c..9745324cfb011 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp @@ -3,9 +3,12 @@ #include #include #include +#include using namespace c10d::test; +constexpr int kNcclErrorHandlingVersion = 2400; + class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { public: WorkNCCLSimulateErrors( @@ -71,8 +74,14 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { class ProcessGroupNCCLErrorsTest : public ::testing::Test { protected: bool skipTest() { - // Skip test if no cuda devices found. 
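A hedged sketch in Python of the lifecycle pattern the #ifdef ENABLE_NCCL_ERROR_CHECKING guards express in ProcessGroupNCCL above: the watchdog thread is created only when error checking is available, and joined only if it was actually started. The class and method names here are illustrative stand-ins, not the real C++ types.

import threading

class WatchdogOwner:
    def __init__(self, error_checking_enabled):
        self._stop = threading.Event()
        self._thread = None
        if error_checking_enabled:
            self._thread = threading.Thread(target=self._watchdog, daemon=True)
            self._thread.start()

    def _watchdog(self):
        # Poll communicators for asynchronous errors until asked to stop.
        while not self._stop.wait(timeout=1.0):
            pass

    def shutdown(self):
        self._stop.set()
        if self._thread is not None:   # join only if the thread was ever started
            self._thread.join()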
- return cudaNumDevices() == 0; + if (cudaNumDevices() == 0) { + return true; + } +#ifdef USE_C10D_NCCL + return torch::cuda::nccl::version() < kNcclErrorHandlingVersion; +#else + return false; +#endif } void SetUp() override { diff --git a/torch/nn/_intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py similarity index 100% rename from torch/nn/_intrinsic/__init__.py rename to torch/nn/intrinsic/__init__.py diff --git a/torch/nn/_intrinsic/modules/__init__.py b/torch/nn/intrinsic/modules/__init__.py similarity index 100% rename from torch/nn/_intrinsic/modules/__init__.py rename to torch/nn/intrinsic/modules/__init__.py diff --git a/torch/nn/_intrinsic/modules/fused.py b/torch/nn/intrinsic/modules/fused.py similarity index 100% rename from torch/nn/_intrinsic/modules/fused.py rename to torch/nn/intrinsic/modules/fused.py diff --git a/torch/nn/_intrinsic/qat/__init__.py b/torch/nn/intrinsic/qat/__init__.py similarity index 100% rename from torch/nn/_intrinsic/qat/__init__.py rename to torch/nn/intrinsic/qat/__init__.py diff --git a/torch/nn/_intrinsic/qat/modules/__init__.py b/torch/nn/intrinsic/qat/modules/__init__.py similarity index 100% rename from torch/nn/_intrinsic/qat/modules/__init__.py rename to torch/nn/intrinsic/qat/modules/__init__.py diff --git a/torch/nn/_intrinsic/qat/modules/conv_fused.py b/torch/nn/intrinsic/qat/modules/conv_fused.py similarity index 98% rename from torch/nn/_intrinsic/qat/modules/conv_fused.py rename to torch/nn/intrinsic/qat/modules/conv_fused.py index dab0bf75b6f28..6356f066fbb4f 100644 --- a/torch/nn/_intrinsic/qat/modules/conv_fused.py +++ b/torch/nn/intrinsic/qat/modules/conv_fused.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import torch import torch.nn as nn -import torch.nn._intrinsic +import torch.nn.intrinsic import torch.nn.qat as nnqat import torch.nn.functional as F from torch.nn import init @@ -28,7 +28,7 @@ class ConvBn2d(nn.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = torch.nn._intrinsic.ConvBn2d + _FLOAT_MODULE = torch.nn.intrinsic.ConvBn2d def __init__(self, # Conv2d args @@ -192,7 +192,7 @@ class ConvBnReLU2d(ConvBn2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = torch.nn._intrinsic.ConvBnReLU2d + _FLOAT_MODULE = torch.nn.intrinsic.ConvBnReLU2d def __init__(self, # Conv2d args @@ -236,7 +236,7 @@ class ConvReLU2d(nnqat.Conv2d): weight_fake_quant: fake quant module for weight """ - _FLOAT_MODULE = torch.nn._intrinsic.ConvReLU2d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, diff --git a/torch/nn/_intrinsic/qat/modules/linear_relu.py b/torch/nn/intrinsic/qat/modules/linear_relu.py similarity index 88% rename from torch/nn/_intrinsic/qat/modules/linear_relu.py rename to torch/nn/intrinsic/qat/modules/linear_relu.py index 04037555483af..a9c4db6db1b99 100644 --- a/torch/nn/_intrinsic/qat/modules/linear_relu.py +++ b/torch/nn/intrinsic/qat/modules/linear_relu.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals import torch.nn.qat as nnqat -import torch.nn._intrinsic +import torch.nn.intrinsic import torch.nn.functional as F class LinearReLU(nnqat.Linear): @@ -11,7 +11,7 @@ class LinearReLU(nnqat.Linear): We adopt the same interface as :class:`torch.nn.Linear`. 
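A minimal migration sketch for the torch.nn._intrinsic -> torch.nn.intrinsic rename applied throughout this section; downstream code updates its imports as shown (old private spellings kept as comments).

import torch.nn.intrinsic as nni                 # was: import torch.nn._intrinsic as nni
import torch.nn.intrinsic.qat as nniqat          # was: torch.nn._intrinsic.qat
import torch.nn.intrinsic.quantized as nniq      # was: torch.nn._intrinsic.quantized

assert hasattr(nni, "ConvBn2d") and hasattr(nni, "LinearReLU")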
- Similar to `torch.nn._intrinsic.LinearReLU`, with FakeQuantize modules initialized to + Similar to `torch.nn.intrinsic.LinearReLU`, with FakeQuantize modules initialized to default. Attributes: @@ -27,7 +27,7 @@ class LinearReLU(nnqat.Linear): >>> print(output.size()) torch.Size([128, 30]) """ - _FLOAT_MODULE = torch.nn._intrinsic.LinearReLU + _FLOAT_MODULE = torch.nn.intrinsic.LinearReLU def __init__(self, in_features, out_features, bias=True, qconfig=None): diff --git a/torch/nn/_intrinsic/quantized/__init__.py b/torch/nn/intrinsic/quantized/__init__.py similarity index 100% rename from torch/nn/_intrinsic/quantized/__init__.py rename to torch/nn/intrinsic/quantized/__init__.py diff --git a/torch/nn/_intrinsic/quantized/modules/__init__.py b/torch/nn/intrinsic/quantized/modules/__init__.py similarity index 100% rename from torch/nn/_intrinsic/quantized/modules/__init__.py rename to torch/nn/intrinsic/quantized/modules/__init__.py diff --git a/torch/nn/_intrinsic/quantized/modules/conv_relu.py b/torch/nn/intrinsic/quantized/modules/conv_relu.py similarity index 88% rename from torch/nn/_intrinsic/quantized/modules/conv_relu.py rename to torch/nn/intrinsic/quantized/modules/conv_relu.py index 53fca2781781b..4143a22613c9d 100644 --- a/torch/nn/_intrinsic/quantized/modules/conv_relu.py +++ b/torch/nn/intrinsic/quantized/modules/conv_relu.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import torch.nn.quantized as nnq -import torch.nn._intrinsic -import torch.nn._intrinsic.qat +import torch.nn.intrinsic +import torch.nn.intrinsic.qat from torch.nn.utils import fuse_conv_bn_weights import torch @@ -15,7 +15,7 @@ class ConvReLU2d(nnq.Conv2d): Same as torch.nn.quantized.Conv2d """ - _FLOAT_MODULE = torch.nn._intrinsic.ConvReLU2d + _FLOAT_MODULE = torch.nn.intrinsic.ConvReLU2d def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, @@ -35,9 +35,12 @@ def forward(self, input): self.dilation, self.groups, self.scale, self.zero_point) + def _get_name(self): + return 'QuantizedConvReLU2d' + @classmethod def from_float(cls, mod): - if type(mod) == torch.nn._intrinsic.qat.ConvBnReLU2d: + if type(mod) == torch.nn.intrinsic.qat.ConvBnReLU2d: mod.weight, mod.bias = \ fuse_conv_bn_weights(mod.weight, mod.bias, mod.running_mean, mod.running_var, mod.eps, mod.gamma, mod.beta) diff --git a/torch/nn/_intrinsic/quantized/modules/linear_relu.py b/torch/nn/intrinsic/quantized/modules/linear_relu.py similarity index 83% rename from torch/nn/_intrinsic/quantized/modules/linear_relu.py rename to torch/nn/intrinsic/quantized/modules/linear_relu.py index e18686dba6f55..28bb9d8c1a785 100644 --- a/torch/nn/_intrinsic/quantized/modules/linear_relu.py +++ b/torch/nn/intrinsic/quantized/modules/linear_relu.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals import torch.nn.quantized as nnq -import torch.nn._intrinsic +import torch.nn.intrinsic import torch class LinearReLU(nnq.Linear): @@ -14,13 +14,13 @@ class LinearReLU(nnq.Linear): Examples:: - >>> m = nn._intrinsic.LinearReLU(20, 30) + >>> m = nn.intrinsic.LinearReLU(20, 30) >>> input = torch.randn(128, 20) >>> output = m(input) >>> print(output.size()) torch.Size([128, 30]) """ - _FLOAT_MODULE = torch.nn._intrinsic.LinearReLU + _FLOAT_MODULE = torch.nn.intrinsic.LinearReLU def __init__(self, in_features, out_features, bias=True): super(LinearReLU, self).__init__(in_features, out_features, bias) @@ -32,6 +32,9 @@ def 
forward(self, input): int(self.zero_point)) return Y_q + def _get_name(self): + return 'QuantizedLinearReLU' + @classmethod def from_float(cls, mod): return super(LinearReLU, cls).from_float(mod) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 282d2dc30819d..7762768f3819b 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -68,16 +68,12 @@ def forward(self, x): _version = 1 def __init__(self): - self.__construct() - # initialize self.training separately from the rest of the internal - # state, as it is managed differently by nn.Module and ScriptModule - self.training = True - - def __construct(self): """ Initializes internal Module state, shared by both nn.Module and ScriptModule. """ torch._C._log_api_usage_once("python.nn_module") + + self.training = True self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() diff --git a/torch/nn/qat/modules/conv.py b/torch/nn/qat/modules/conv.py index d6f489d297375..0338738654b7c 100644 --- a/torch/nn/qat/modules/conv.py +++ b/torch/nn/qat/modules/conv.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, division, print_function, unicode_literals from torch.nn import Conv2d as NNConv2d -from torch.nn._intrinsic import ConvReLU2d +from torch.nn.intrinsic import ConvReLU2d class Conv2d(NNConv2d): r""" diff --git a/torch/nn/qat/modules/linear.py b/torch/nn/qat/modules/linear.py index 192f5cab0b33e..6ab87d6706781 100644 --- a/torch/nn/qat/modules/linear.py +++ b/torch/nn/qat/modules/linear.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import torch.nn as nn import torch.nn.functional as F -from torch.nn._intrinsic import LinearReLU +from torch.nn.intrinsic import LinearReLU class Linear(nn.Linear): r""" diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index 79c9d0c6ff9cb..462c70776fcd7 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -8,54 +8,58 @@ from torch._jit_internal import List as _List from torch.nn.modules.utils import _pair +# Although some of the functions and docstrings are mirrored from the torch.nn, +# we want to have them here for future changes. -def relu(input, inplace=False): - # type: (Tensor, bool) -> Tensor - r"""relu(input, inplace=False) -> Tensor +def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False, + count_include_pad=True, divisor_override=None): + r""" + Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size + :math:`sH \times sW` steps. The number of output features is equal to the number of + input planes. - Applies the rectified linear unit function element-wise. See - :class:`~torch.nn.ReLU` for more details. + .. note:: The input quantization parameters propagate to the output. + + See :class:`~torch.nn.quantized.AvgPool2d` for details and output shape. + + Args: + input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` + kernel_size: size of the pooling region. Can be a single number or a + tuple `(kH, kW)` + stride: stride of the pooling operation. Can be a single number or a + tuple `(sH, sW)`. Default: :attr:`kernel_size` + padding: implicit zero paddings on both sides of the input. Can be a + single number or a tuple `(padH, padW)`. Default: 0 + ceil_mode: when True, will use `ceil` instead of `floor` in the formula + to compute the output shape. 
Default: ``False`` + count_include_pad: when True, will include the zero-padding in the + averaging calculation. Default: ``True`` + divisor_override: if specified, it will be used as divisor, otherwise + size of the pooling region will be used. Default: None """ if not input.is_quantized: - raise ValueError("Input to 'quantized.relu' must be quantized!") - if inplace: - return torch.relu_(input) - else: - return torch.relu(input) + raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!") + return torch.nn.functional.avg_pool2d(input, kernel_size, stride, padding, + ceil_mode, count_include_pad, + divisor_override) -def linear(input, weight, bias=None, scale=None, zero_point=None): - # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor +def adaptive_avg_pool2d(input, output_size): + # type: (Tensor, BroadcastingList2[int]) -> Tensor r""" - Applies a linear transformation to the incoming quantized data: - :math:`y = xA^T + b`. - See :class:`~torch.nn.Linear` + Applies a 2D adaptive average pooling over a quantized input signal composed + of several quantized input planes. - .. note:: + .. note:: The input quantization paramteres propagate to the output. - Current implementation uses packed weights. This has penalty on performance. - If you want to avoid the overhead, use :class:`~torch.nn.quantized.Linear`. + See :class:`~torch.nn.quantized.AdaptiveAvgPool2d` for details and output shape. Args: - input (Tensor): Quantized input of type `torch.quint8` - weight (Tensor): Quantized weight of type `torch.qint8` - bias (Tensor): None or fp32 bias of type `torch.float` - scale (double): output scale. If None, derived from the input scale - zero_point (long): output zero point. If None, derived from the input zero_point - - Shape: - - Input: :math:`(N, *, in\_features)` where `*` means any number of - additional dimensions - - Weight: :math:`(out\_features, in\_features)` - - Bias: :math:`(out\_features)` - - Output: :math:`(N, *, out\_features)` + output_size: the target output size (single integer or + double-integer tuple) """ - if scale is None: - scale = input.q_scale() - if zero_point is None: - zero_point = input.q_zero_point() - _packed_params = torch.ops.quantized.linear_prepack(weight, bias) - return torch.ops.quantized.linear(input, _packed_params, scale, - zero_point) + if not input.is_quantized: + raise ValueError("Input to 'quantized.adaptive_avg_pool2d' must be quantized!") + return torch.nn.functional.adaptive_avg_pool2d(input, output_size) def conv2d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1, @@ -63,16 +67,10 @@ def conv2d(input, weight, bias, scale=1.0, zero_point=0, dtype=torch.quint8): r""" - conv2d(input, weight, bias, - stride=1, padding=0, dilation=1, groups=1, - padding_mode='zeros', - scale=1.0, zero_point=0, - dtype=torch.quint8) -> Tensor - Applies a 2D convolution over a quantized 2D input composed of several input planes. - See :class:`~torch.nn.Conv2d` for details and output shape. + See :class:`~torch.nn.quantized.Conv2d` for details and output shape. Args: input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` @@ -120,11 +118,87 @@ def conv2d(input, weight, bias, stride, padding, dilation, groups, scale, zero_point) +def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None): + r"""Down/up samples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + See :func:`torch.nn.functional.interpolate` for implementation details. 
+ + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): the input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.interpolate' must be quantized!") + return torch.nn.functional.interpolate(input, size, scale_factor, mode, + align_corners) + +def linear(input, weight, bias=None, scale=None, zero_point=None): + # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor + r""" + Applies a linear transformation to the incoming quantized data: + :math:`y = xA^T + b`. + See :class:`~torch.nn.quantized.Linear` + + .. note:: + + Current implementation packs weights on every call, which has penalty on performance. + If you want to avoid the overhead, use :class:`~torch.nn.quantized.Linear`. + + Args: + input (Tensor): Quantized input of type `torch.quint8` + weight (Tensor): Quantized weight of type `torch.qint8` + bias (Tensor): None or fp32 bias of type `torch.float` + scale (double): output scale. If None, derived from the input scale + zero_point (long): output zero point. If None, derived from the input zero_point + + Shape: + - Input: :math:`(N, *, in\_features)` where `*` means any number of + additional dimensions + - Weight: :math:`(out\_features, in\_features)` + - Bias: :math:`(out\_features)` + - Output: :math:`(N, *, out\_features)` + """ + if scale is None: + scale = input.q_scale() + if zero_point is None: + zero_point = input.q_zero_point() + _packed_params = torch.ops.quantized.linear_prepack(weight, bias) + return torch.ops.quantized.linear(input, _packed_params, scale, zero_point) + def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): r"""Applies a 2D max pooling over a quantized input signal composed of several quantized input planes. + .. note:: The input quantization parameters are propagated to the output. + See :class:`~torch.nn.quantized.MaxPool2d` for details. 
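A hedged usage sketch for the torch.nn.quantized.functional ops documented in this hunk: each one requires an already-quantized input and propagates its quantization parameters to the output. The scale/zero_point values are arbitrary, and torch.quantize_per_tensor is assumed to be the available quantization entry point.

import torch
import torch.nn.quantized.functional as qF

x = torch.randn(1, 3, 8, 8)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

pooled = qF.max_pool2d(qx, kernel_size=2)            # output keeps qx's scale/zero_point
resized = qF.interpolate(qx, scale_factor=2.0, mode='nearest')
# qF.avg_pool2d(x, kernel_size=2) would raise ValueError: the input must be quantized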
""" if return_indices: @@ -134,7 +208,117 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode, return_indices) -# TODO(zaf): Add documentation -adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d -interpolate = torch.nn.functional.interpolate -avg_pool2d = torch.nn.functional.avg_pool2d +def relu(input, inplace=False): + # type: (Tensor, bool) -> Tensor + r"""relu(input, inplace=False) -> Tensor + + Applies the rectified linear unit function element-wise. + See :class:`~torch.nn.quantized.ReLU` for more details. + + Args: + input: quantized input + inplace: perform the computation inplace + """ + if not input.is_quantized: + raise ValueError("Input to 'quantized.relu' must be quantized!") + if inplace: + return torch.relu_(input) + else: + return torch.relu(input) + +def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None): + r"""Upsamples the input to either the given :attr:`size` or the given + :attr:`scale_factor` + + .. warning:: + This function is deprecated in favor of + :func:`torch.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(...)``. + + See :func:`torch.nn.functional.interpolate` for implementation details. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D input is supported for quantized inputs + + .. note:: Only the following modes are supported for the quantized inputs: + + - `bilinear` + - `nearest` + + Args: + input (Tensor): quantized input tensor + size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): + output spatial size. + scale_factor (float or Tuple[float]): multiplier for spatial size. Has to be an integer. + mode (string): algorithm used for upsampling: + ``'nearest'`` | ``'bilinear'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values, making this operation *independent* of input size + when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode` + is ``'bilinear'``. + Default: ``False`` + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`bilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + """ + warnings.warn("nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.") + return interpolate(input, size, scale_factor, mode, align_corners) + +def upsample_bilinear(input, size=None, scale_factor=None): + r"""Upsamples the input, using bilinear upsampling. + + .. warning:: + This function is deprecated in favor of + :func:`torch.nn.quantized.functional.interpolate`. 
+ This is equivalent with + ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int]): output spatial size. + scale_factor (int or Tuple[int, int]): multiplier for spatial size + """ + # DeprecationWarning is ignored by default + warnings.warn("nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.") + return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True) + +def upsample_nearest(input, size=None, scale_factor=None): + r"""Upsamples the input, using nearest neighbours' pixel values. + + .. warning:: + This function is deprecated in favor of + :func:`torch.nn.quantized.functional.interpolate`. + This is equivalent with ``nn.quantized.functional.interpolate(..., mode='nearest')``. + + .. note:: The input quantization parameters propagate to the output. + + .. note:: Only 2D inputs are supported + + Args: + input (Tensor): quantized input + size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial + size. + scale_factor (int): multiplier for spatial size. Has to be an integer. + """ + # DeprecationWarning is ignored by default + warnings.warn("nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.") + return interpolate(input, size, scale_factor, mode='nearest') diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index b31c314e774fc..2ae5248c26a72 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn -import torch.nn._intrinsic as nni -import torch.nn._intrinsic.qat as nniqat +import torch.nn.intrinsic as nni +import torch.nn.intrinsic.qat as nniqat from torch.nn.utils import fuse_conv_bn_weights from torch._ops import ops from torch.nn.modules.utils import _pair @@ -166,7 +166,8 @@ def __getstate__(self): w, b, self.scale, - self.zero_point + self.zero_point, + self.training ) # ===== Deserialization methods ===== @@ -190,7 +191,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, @torch.jit.export def __setstate__(self, state): - # type: (Tuple[int, int, Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int], bool, int, int, str, Tensor, Optional[Tensor], float, int]) # noqa + # type: (Tuple[int, int, Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int], bool, int, int, str, Tensor, Optional[Tensor], float, int, bool]) # noqa self.in_channels = state[0] self.out_channels = state[1] self.kernel_size = state[2] @@ -204,6 +205,7 @@ def __setstate__(self, state): self.set_weight_bias(state[10], state[11]) self.scale = state[12] self.zero_point = state[13] + self.training = state[14] @classmethod def from_float(cls, mod): diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index 3ff3d6f5e79cd..449d19ef9a936 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -2,9 +2,9 @@ import torch -from torch._jit_internal import Optional +from torch._jit_internal import Optional # noqa: F401 import torch.nn as nn -import torch.nn._intrinsic as nni +import torch.nn.intrinsic as nni from torch.nn.modules import Module from torch.nn.quantized.modules.utils import _quantize_weight @@ -158,7 +158,8 @@ def 
__getstate__(self): b, w, self.scale, - self.zero_point + self.zero_point, + self.training ) # ===== Deserialization methods ===== @@ -181,12 +182,13 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, @torch.jit.export def __setstate__(self, state): - # type: (Tuple[int, int, Optional[torch.Tensor], torch.Tensor, float, int]) -> None + # type: (Tuple[int, int, Optional[torch.Tensor], torch.Tensor, float, int, bool]) -> None self.in_features = state[0] self.out_features = state[1] self.set_weight_bias(state[3], state[2]) self.scale = state[4] self.zero_point = state[5] + self.training = state[6] # Function rather than property to make sure that JIT serialization doesn't # register this as an attribute diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index bfb9768a65dff..7a4637c3232c7 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -100,95 +100,49 @@ def pin_memory(self): bind(self.unsorted_indices, lambda t: t.pin_memory())) def cuda(self, *args, **kwargs): - """Returns a GPU copy if `self.data` not already on the GPU""" - if self.is_cuda: - return self - else: - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.cuda(*args, **kwargs), self.batch_sizes, - bind(self.sorted_indices, lambda t: t.cuda(*args, **kwargs)), - bind(self.unsorted_indices, lambda t: t.cuda(*args, **kwargs))) - - def cpu(self): - """Returns a CPU copy if `self.data` not already on the CPU""" - if self.is_cuda: - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.cpu(), self.batch_sizes, - bind(self.sorted_indices, lambda t: t.cpu()), - bind(self.unsorted_indices, lambda t: t.cpu())) - else: - return self + # Tests to see if 'cuda' should be added to kwargs + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs) + if ex.is_cuda: + return self.to(*args, **kwargs) + return self.to(*args, device='cuda', **kwargs) - def double(self): - r"""Returns copy with `self.data` cast to double type""" + def cpu(self, *args, **kwargs): - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.double(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs) + if ex.device.type == 'cpu': + return self.to(*args, **kwargs) + return self.to(*args, device='cpu', **kwargs) - def float(self): - r"""Returns copy with `self.data` cast to float type""" + def double(self): + return self.to(dtype=torch.double) - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.float(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + def float(self): + return self.to(dtype=torch.float) def half(self): - r"""Returns copy with `self.data` cast to half type""" - - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.half(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + return self.to(dtype=torch.half) def long(self): - r"""Returns copy with `self.data` cast to long type""" - - # Why not convert `batch_sizes`? 
- # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.long(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + return self.to(dtype=torch.long) def int(self): - r"""Returns copy with `self.data` cast to int type""" - - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.int(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + return self.to(dtype=torch.int) def short(self): - r"""Returns copy with `self.data` cast to short type""" - - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.short(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + return self.to(dtype=torch.short) def char(self): - r"""Returns copy with `self.data` cast to char type""" - - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.char(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + return self.to(dtype=torch.int8) def byte(self): - r"""Returns copy with `self.data` cast to byte type""" - - # Why not convert `batch_sizes`? - # See NOTE [ device and dtype of a PackedSequence ] - return type(self)(self.data.byte(), self.batch_sizes, - self.sorted_indices, self.unsorted_indices) + return self.to(dtype=torch.uint8) def to(self, *args, **kwargs): r"""Performs dtype and/or device conversion on `self.data`. - It has similar signature as :meth:`torch.Tensor.to`. + It has similar signature as :meth:`torch.Tensor.to`, except optional + arguments like `non_blocking` and `copy` should be passed as kwargs, + not args, or they will not apply to the index tensors. .. note:: @@ -200,17 +154,14 @@ def to(self, *args, **kwargs): # Why not convert `batch_sizes`? 
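A hedged usage sketch of the consolidated PackedSequence conversion path introduced above: the per-dtype helpers (half(), long(), byte(), ...) and cuda()/cpu() all route through .to(), and the index tensors follow the data tensor's device automatically.

import torch
from torch.nn.utils.rnn import pack_sequence

packed = pack_sequence([torch.randn(3, 4), torch.randn(2, 4)])

as_half = packed.to(dtype=torch.half)        # equivalent to packed.half()
on_gpu = packed.to('cuda') if torch.cuda.is_available() else packed
assert as_half.data.dtype == torch.half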
# See NOTE [ device and dtype of a PackedSequence ] data = self.data.to(*args, **kwargs) - sorted_indices = self.sorted_indices - unsorted_indices = self.unsorted_indices - device_kw = 'device' - if device_kw in kwargs: - sorted_indices = bind(sorted_indices, lambda t: t.to(kwargs[device_kw])) - unsorted_indices = bind(unsorted_indices, lambda t: t.to(kwargs[device_kw])) if data is self.data: return self else: - return type(self)(data, self.batch_sizes, - sorted_indices, unsorted_indices) + # Does not forward device or dtype arg/kwargs, device is set from data.device + kwargs = {k : v for k, v in filter(lambda t: t[0] != 'device' and t[0] != 'dtype', kwargs.items())} + sorted_indices = bind(self.sorted_indices, lambda t: t.to(data.device, **kwargs)) + unsorted_indices = bind(self.unsorted_indices, lambda t: t.to(data.device, **kwargs)) + return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices) @property def is_cuda(self): diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index d1966bc5ff7c1..048e69c36c19d 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -194,6 +194,11 @@ def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False) return _slice(g, input, axes, starts, ends, steps, dynamic_slice) +def _is_fp(value): + type = value.type().scalarType() + return (type == 'Float') or (type == 'Double') or (type == 'Half') + + def _sort_helper(g, input, dim, decending=True, out=None): if out is not None: _unimplemented("Sort", "Out parameter is not supported") @@ -237,6 +242,7 @@ def _interpolate_warning(interpolate_mode): "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n" "We recommend using opset 11 and above for models using this operator. 
") + def _interpolate_size_to_scales(g, input, output_size, dim): output_size = _maybe_get_const(output_size, 'is') if _is_value(output_size): @@ -254,7 +260,6 @@ def _interpolate_size_to_scales(g, input, output_size, dim): scales = g.op("Constant", value_t=torch.tensor(scales_constant)) return scales - def _scatter_helper(g, self, dim, index, src): if _export_onnx_opset_version <= 10: from torch.onnx.symbolic_opset9 import scatter diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 241c9dbc4c587..d0ce9e21a4892 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -168,3 +168,7 @@ def flip(g, input, dims): starts=[-1] * len(dims), ends=[-9223372036854775807] * len(dims), steps=[-1] * len(dims)) + + +def fmod(g, input, other): + return g.op("Mod", input, other, fmod_i=1) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 7251d86095e75..19dd546ecc832 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -937,7 +937,21 @@ def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_s @parse_args('v', 'i', 'i', 'i') def unfold(g, input, dimension, size, step): - return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) + if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: + return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step) + if input.isCompleteTensor(): + sizedim = input.type().sizes()[dimension] + low_indices = range(0, sizedim, step) + hi_indices = range(size, sizedim + 1, step) + stack = [sym_help._slice_helper(g, input, axes=[dimension], starts=[low], ends=[hi]) + for low, hi in zip(low_indices, hi_indices)] + ndim = input.type().dim() + perm = list(range(0, ndim)) + perm.append(perm.pop(dimension)) + unsqueeze = [g.op("Unsqueeze", g.op("Transpose", t, perm_i=perm), axes_i=[dimension]) for t in stack] + return g.op("Concat", *unsqueeze, axis_i=dimension) + else: + return _unimplemented("Unfold", "input size not accessible") @parse_args('v', 'v', 'i') @@ -972,7 +986,7 @@ def index_select(g, self, dim, index): index = g.op("Constant", value_t=torch.LongTensor([index_const])) elif index_dim is not None: if index_dim == 0: - # Index is a scalar. Reshape it to a size 1 tensor. + # Index is a scalar. Reshape it to a size 1 tensor. 
index = g.op("Reshape", index, g.op("Constant", value_t=torch.LongTensor([1]))) return g.op("Gather", self, index, axis_i=dim) @@ -1042,7 +1056,7 @@ def cosine_similarity(g, x1, x2, dim, eps): # ignore clone operators that are inserted by PyTorch autograd -def clone(g, input): +def clone(g, input, unused_memory_format): return input @@ -1950,13 +1964,21 @@ def multinomial(g, input, num_samples, replacement=False, generator=None): dtype_i=sym_help.cast_pytorch_to_onnx['Long'], sample_size_i=num_samples) -def baddbmm(g, self, batch1, batch2, beta, alpha): - dtype = self.type().scalarType() - batch_mul = matmul(g, batch1, batch2) - mul_a = mul(g, batch_mul, g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[dtype])) +def baddbmm(g, self, batch1, batch2, beta, alpha): + dtype = self.type().scalarType() + batch_mul = matmul(g, batch1, batch2) + mul_a = mul(g, batch_mul, g.op("Cast", alpha, to_i=sym_help.cast_pytorch_to_onnx[dtype])) mul_b = mul(g, self, g.op("Cast", beta, to_i=sym_help.cast_pytorch_to_onnx[dtype])) return add(g, mul_a, mul_b) +def remainder(g, input, other): + div = g.op("Div", input, other) + if sym_help._is_fp(input): + div = g.op("Floor", div) + quo = g.op("Mul", div, other) + return g.op("Sub", input, quo) + + def gelu(g, self): _sqrt2 = 1.4142135623730951 erf = g.op('Erf', div(g, self, torch.tensor(_sqrt2))) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index b5091839a2093..3ee099b3e26ff 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -238,8 +238,8 @@ def _model_to_graph(model, args, verbose=False, training=False, method_graph, tuple(in_vars), False, propagate) except AttributeError: raise RuntimeError('\'forward\' method must be a script method') - elif isinstance(model, torch.jit.Function): - assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function" + elif isinstance(model, torch.jit.ScriptFunction): + assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript ScriptFunction" method = model params = () in_vars, in_desc = torch.jit._flatten(tuple(args)) @@ -261,7 +261,7 @@ def _model_to_graph(model, args, verbose=False, training=False, _disable_torch_constant_prop=_disable_torch_constant_prop, fixed_batch_size=fixed_batch_size) - if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.Function): + if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction): out_vars, _ = torch.jit._flatten(tuple(example_outputs)) graph = _assign_output_shapes(graph, out_vars) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 8e75a4b7667bb..64d342ddc7eb0 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -4,21 +4,36 @@ from functools import wraps import warnings import weakref +from collections import Counter from bisect import bisect_right from .optimizer import Optimizer +EPOCH_DEPRECATION_WARNING = ( + "The epoch parameter in `scheduler.step()` was not necessary and is being " + "deprecated where possible. Please use `scheduler.step()` to step the " + "scheduler. During the deprecation, if epoch is different from None, the " + "closed form is used instead of the new chainable form, where available. " + "Please open an issue if you are unable to replicate your use case: " + "https://github.com/pytorch/pytorch/issues/new/choose." 
+) + + class _LRScheduler(object): + def __init__(self, optimizer, last_epoch=-1): + + # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError('{} is not an Optimizer'.format( type(optimizer).__name__)) self.optimizer = optimizer + + # Initialize epoch and base learning rates if last_epoch == -1: for group in optimizer.param_groups: group.setdefault('initial_lr', group['lr']) - last_epoch = 0 else: for i, group in enumerate(optimizer.param_groups): if 'initial_lr' not in group: @@ -58,7 +73,8 @@ def wrapper(*args, **kwargs): self.optimizer.step = with_counter(self.optimizer.step) self.optimizer._step_count = 0 self._step_count = 0 - self.step(last_epoch) + + self.step() def state_dict(self): """Returns the state of the scheduler as a :class:`dict`. @@ -77,7 +93,13 @@ def load_state_dict(self, state_dict): """ self.__dict__.update(state_dict) + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + def get_lr(self): + # Compute learning rate using chainable form of the scheduler raise NotImplementedError def step(self, epoch=None): @@ -100,12 +122,36 @@ def step(self, epoch=None): "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) self._step_count += 1 - if epoch is None: - epoch = self.last_epoch + 1 - self.last_epoch = epoch - for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + return self + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, DeprecationWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for param_group, lr in zip(self.optimizer.param_groups, values): param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + class LambdaLR(_LRScheduler): """Sets the learning rate of each parameter group to the initial lr @@ -131,6 +177,7 @@ class LambdaLR(_LRScheduler): def __init__(self, optimizer, lr_lambda, last_epoch=-1): self.optimizer = optimizer + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) else: @@ -173,14 +220,19 @@ def load_state_dict(self, state_dict): self.lr_lambdas[idx].__dict__.update(fn) def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + return [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] class StepLR(_LRScheduler): - """Sets the learning rate of each parameter group to the initial lr - decayed by gamma every step_size epochs. When last_epoch=-1, sets - initial lr as lr. + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. When + last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. 
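A hedged usage sketch of the chainable scheduler API introduced above: step() is called with no epoch argument (passing one now triggers the deprecation warning), and the most recent learning rates are read back with get_last_lr() rather than get_lr(). The model, optimizer, and hyperparameters are placeholders.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

for epoch in range(5):
    optimizer.step()                 # training loop elided; step the optimizer first
    scheduler.step()                 # no epoch argument
    print(epoch, scheduler.get_last_lr())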
@@ -208,14 +260,25 @@ def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): super(StepLR, self).__init__(optimizer, last_epoch) def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): return [base_lr * self.gamma ** (self.last_epoch // self.step_size) for base_lr in self.base_lrs] class MultiStepLR(_LRScheduler): - """Set the learning rate of each parameter group to the initial lr decayed - by gamma once the number of epoch reaches one of the milestones. When - last_epoch=-1, sets initial lr as lr. + """Decays the learning rate of each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside + this scheduler. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. @@ -237,21 +300,28 @@ class MultiStepLR(_LRScheduler): """ def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): - if not list(milestones) == sorted(milestones): - raise ValueError('Milestones should be a list of' - ' increasing integers. Got {}', milestones) - self.milestones = milestones + self.milestones = Counter(milestones) self.gamma = gamma super(MultiStepLR, self).__init__(optimizer, last_epoch) def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) for base_lr in self.base_lrs] class ExponentialLR(_LRScheduler): - """Set the learning rate of each parameter group to the initial lr decayed - by gamma every epoch. When last_epoch=-1, sets initial lr as lr. + """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. @@ -264,6 +334,16 @@ def __init__(self, optimizer, gamma, last_epoch=-1): super(ExponentialLR, self).__init__(optimizer, last_epoch) def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + + if self.last_epoch == 0: + return self.base_lrs + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): return [base_lr * self.gamma ** self.last_epoch for base_lr in self.base_lrs] @@ -273,12 +353,23 @@ class CosineAnnealingLR(_LRScheduler): schedule, where :math:`\eta_{max}` is set to the initial lr and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + .. 
math:: + \eta_{t+1} = \eta_{min} + (\eta_t - \eta_{min})\frac{1 + + \cos(\frac{T_{cur}+1}{T_{max}}\pi)}{1 + \cos(\frac{T_{cur}}{T_{max}}\pi)}, + T_{cur} \neq (2k+1)T_{max};\\ + \eta_{t+1} = \eta_{t} + (\eta_{max} - \eta_{min})\frac{1 - + \cos(\frac{1}{T_{max}}\pi)}{2}, + T_{cur} = (2k+1)T_{max}.\\ + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + .. math:: \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + \cos(\frac{T_{cur}}{T_{max}}\pi)) - When last_epoch=-1, sets initial lr as lr. - It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only implements the cosine annealing part of SGDR, and not the restarts. @@ -299,6 +390,23 @@ def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): super(CosineAnnealingLR, self).__init__(optimizer, last_epoch) def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + + if self.last_epoch == 0: + return self.base_lrs + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [group['lr'] + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in + zip(self.base_lrs, self.optimizer.param_groups)] + return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / + (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for base_lr in self.base_lrs] @@ -361,6 +469,7 @@ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, raise ValueError('Factor should be < 1.0.') self.factor = factor + # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError('{} is not an Optimizer'.format( type(optimizer).__name__)) @@ -385,7 +494,7 @@ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, self.num_bad_epochs = None self.mode_worse = None # the worse value for the chosen mode self.eps = eps - self.last_epoch = -1 + self.last_epoch = 0 self._init_is_better(mode=mode, threshold=threshold, threshold_mode=threshold_mode) self._reset() @@ -400,7 +509,9 @@ def step(self, metrics, epoch=None): # convert `metrics` to float, in case it's a zero-dim Tensor current = float(metrics) if epoch is None: - epoch = self.last_epoch = self.last_epoch + 1 + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, DeprecationWarning) self.last_epoch = epoch if self.is_better(current, self.best): @@ -418,6 +529,8 @@ def step(self, metrics, epoch=None): self.cooldown_counter = self.cooldown self.num_bad_epochs = 0 + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + def _reduce_lr(self, epoch): for i, param_group in enumerate(self.optimizer.param_groups): old_lr = float(param_group['lr']) @@ -579,6 +692,7 @@ def __init__(self, max_momentum=0.9, last_epoch=-1): + # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError('{} is not an Optimizer'.format( type(optimizer).__name__)) @@ -658,6 +772,11 @@ def get_lr(self): If `self.cycle_momentum` is ``True``, this 
function has a side effect of updating the optimizer's momentum. """ + + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + cycle = math.floor(1 + self.last_epoch / self.total_size) x = 1. + self.last_epoch / self.total_size - cycle if x <= self.step_ratio: @@ -725,10 +844,16 @@ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): self.T_i = T_0 self.T_mult = T_mult self.eta_min = eta_min + super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) + self.T_cur = self.last_epoch def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 for base_lr in self.base_lrs] @@ -757,6 +882,10 @@ def step(self, epoch=None): >>> scheduler.step(26) >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) """ + + if epoch is None and self.last_epoch < 0: + epoch = 0 + if epoch is None: epoch = self.last_epoch + 1 self.T_cur = self.T_cur + 1 @@ -777,8 +906,26 @@ def step(self, epoch=None): self.T_i = self.T_0 self.T_cur = epoch self.last_epoch = math.floor(epoch) - for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): - param_group['lr'] = lr + + class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + return self + + with _enable_get_lr_call(self): + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + class OneCycleLR(_LRScheduler): r"""Sets the learning rate of each parameter group according to the @@ -970,6 +1117,10 @@ def _annealing_linear(self, start, end, pct): return (end - start) * pct + start def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", DeprecationWarning) + lrs = [] step_num = self.last_epoch diff --git a/torch/quantization/QConfig.py b/torch/quantization/QConfig.py index 0f52e0908ee7d..d1a9230703de7 100644 --- a/torch/quantization/QConfig.py +++ b/torch/quantization/QConfig.py @@ -86,13 +86,13 @@ def get_default_qconfig(backend='fbgemm'): def get_default_qat_qconfig(backend='fbgemm'): # Histogram observer is too slow for quantization aware training if backend == 'fbgemm': - qconfig = QConfig(activation=FakeQuantize.with_args(observer=MinMaxObserver, + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, reduce_range=True), weight=default_per_channel_weight_fake_quant) elif backend == 'qnnpack': - qconfig = QConfig(activation=FakeQuantize.with_args(observer=MinMaxObserver, + qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, reduce_range=False), diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index ad498381f2143..724e6e0ca90b6 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -4,6 +4,7 @@ from .QConfig import * # noqa: F401 from .fake_quantize import * # noqa: F401 from 
.fuse_modules import fuse_modules # noqa: F401 +from .stubs import * # noqa: F401 def default_eval_fn(model, calib_data): r""" @@ -14,7 +15,7 @@ def default_eval_fn(model, calib_data): model(data) _all__ = [ - 'QuantWrapper', 'QuantStub', 'DeQuantStub', 'DEFAULT_MODULE_MAPPING', + 'QuantWrapper', 'QuantStub', 'DeQuantStub', # Top level API for eager mode quantization 'quantize', # Sub functions used by eager mode quantization diff --git a/torch/quantization/_quantize_script.py b/torch/quantization/_quantize_script.py index aa83dd813d0b2..722f04723ae24 100644 --- a/torch/quantization/_quantize_script.py +++ b/torch/quantization/_quantize_script.py @@ -24,11 +24,13 @@ def forward(self, x): @torch.jit.export def __getstate__(self): - return self._weight_bias() + return self._weight_bias(), self.training @torch.jit.export def __setstate__(self, state): - self.set_weight_bias(state[0], state[1]) + # type: (Tuple[Tuple[Tensor, Optional[Tensor]], bool]) -> None + self.set_weight_bias(state[0][0], state[0][1]) + self.training = state[1] def _check_is_script_module(model): if not isinstance(model, torch.jit.ScriptModule): diff --git a/torch/quantization/default_mappings.py b/torch/quantization/default_mappings.py new file mode 100644 index 0000000000000..5b6149937da1b --- /dev/null +++ b/torch/quantization/default_mappings.py @@ -0,0 +1,65 @@ + +from torch import nn + +import torch.nn.intrinsic as nni +import torch.nn.intrinsic.quantized as nniq +import torch.nn.intrinsic.qat as nniqat +import torch.nn.quantized as nnq +import torch.nn.quantized.dynamic as nnqd +import torch.nn.qat as nnqat + +from .stubs import QuantStub, DeQuantStub + +# Map for swapping float module to quantized ones +DEFAULT_MODULE_MAPPING = { + nn.Linear: nnq.Linear, + nn.ReLU: nnq.ReLU, + nn.Conv2d: nnq.Conv2d, + QuantStub: nnq.Quantize, + DeQuantStub: nnq.DeQuantize, + # Wrapper Modules: + nnq.FloatFunctional: nnq.QFunctional, + # Intrinsic modules: + nni.ConvReLU2d: nniq.ConvReLU2d, + nni.LinearReLU: nniq.LinearReLU, + nniqat.ConvReLU2d: nniq.ConvReLU2d, + nniqat.LinearReLU: nniq.LinearReLU, + nniqat.ConvBn2d: nnq.Conv2d, + nniqat.ConvBnReLU2d: nniq.ConvReLU2d, + # QAT modules: + nnqat.Linear: nnq.Linear, + nnqat.Conv2d: nnq.Conv2d, +} + +# Map for swapping float module to qat modules +DEFAULT_QAT_MODULE_MAPPING = { + nn.Linear: nnqat.Linear, + nn.Conv2d: nnqat.Conv2d, + # Intrinsic modules: + nni.ConvBn2d: nniqat.ConvBn2d, + nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, + nni.ConvReLU2d: nniqat.ConvReLU2d, + nni.LinearReLU: nniqat.LinearReLU +} + +# Map for swapping dynamic modules +DEFAULT_DYNAMIC_MODULE_MAPPING = { + nn.Linear: nnqd.Linear, + nn.LSTM: nnqd.LSTM, +} + +# Whitelist for propagating the qconfig +_EXCLUDE_QCONFIG_PROPAGATE_LIST = { + DeQuantStub, +} +_INCLUDE_QCONFIG_PROPAGATE_LIST = { + nn.Sequential, +} + +DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST = ( + set(DEFAULT_MODULE_MAPPING.keys()) | + set(DEFAULT_QAT_MODULE_MAPPING.keys()) | + set(DEFAULT_DYNAMIC_MODULE_MAPPING.keys()) | + _INCLUDE_QCONFIG_PROPAGATE_LIST - + _EXCLUDE_QCONFIG_PROPAGATE_LIST +) diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index 2dc6d12ed8666..5c1bb274964b8 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals import torch from torch.nn import Module -from .observer import MinMaxObserver, HistogramObserver, PerChannelMinMaxObserver, _with_args +from .observer import 
MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args class FakeQuantize(Module): - ''' Simulate the quantize and dequantize operations in training time. + r""" Simulate the quantize and dequantize operations in training time. The output of this module is given by x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale @@ -25,13 +25,13 @@ class FakeQuantize(Module): * :attr:`observer_enable` controls statistics collection on tensors * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization, - allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should - be chosen to be consistent with the dtype + allowable values are torch.qint8 and torch.quint8. The values of quant_min and + quant_max should be chosen to be consistent with the dtype Args: - observer (module): Module for observing statistics on input tensors and calculating - scale and zero-point. + observer (module): Module for observing statistics on input tensors and calculating scale + and zero-point. quant_min (int): The minimum allowable quantized value. quant_max (int): The maximum allowable quantized value. observer_kwargs (optional): Arguments for the observer module @@ -41,14 +41,7 @@ class FakeQuantize(Module): provides a method to calculate scale and zero-point. """ - Args: - `observer`: Observer module that records stats of input tensor - `quant_min`: Tensors are fake-quantized corresponding to the - `quant_max`: A function that calculates quantization parameters - given the stats - `observer_kwargs` - ''' - def __init__(self, observer=MinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs): + def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs): super(FakeQuantize, self).__init__() assert quant_min <= quant_max, \ 'quant_min must be less than or equal to quant_max' @@ -119,12 +112,12 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) -default_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver, quant_min=0, quant_max=255, +default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True) -default_weight_fake_quant = FakeQuantize.with_args(observer=MinMaxObserver, quant_min=-128, quant_max=127, +default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False) -default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=PerChannelMinMaxObserver, +default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver, quant_min=-128, quant_max=127, dtype=torch.qint8, diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py index 6784994a2cb77..612af215ff07d 100644 --- a/torch/quantization/fuse_modules.py +++ b/torch/quantization/fuse_modules.py @@ -3,7 +3,7 @@ import torch import copy -import torch.nn._intrinsic.modules.fused as torch_fused +import torch.nn.intrinsic.modules.fused as torch_fused def fuse_conv_bn(conv, bn): r"""Given the conv and bn modules, fuses them and returns the fused module @@ -26,7 +26,7 @@ def fuse_conv_bn(conv, bn): 
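A hedged usage sketch for the fusion path touched here, now routed through torch.nn.intrinsic: in eval mode a Conv2d + BatchNorm2d + ReLU triple folds into a single ConvReLU2d. The child names '0', '1', '2' are simply the default names inside an nn.Sequential.

import torch
from torch.quantization import fuse_modules

m = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()

fused = fuse_modules(m, [['0', '1', '2']])
print(type(fused[0]).__name__)        # expected: ConvReLU2d, with the BatchNorm folded away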
assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d' assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True' assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True' - return torch.nn._intrinsic.ConvBn2d(conv, bn) + return torch.nn.intrinsic.ConvBn2d(conv, bn) else: return torch.nn.utils.fuse_conv_bn_eval(conv, bn) @@ -87,8 +87,8 @@ def fuse_known_modules(mod_list): OP_LIST_TO_FUSER_METHOD = { (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn, (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu, - (torch.nn.Conv2d, torch.nn.ReLU): torch.nn._intrinsic.ConvReLU2d, - (torch.nn.Linear, torch.nn.ReLU): torch.nn._intrinsic.LinearReLU + (torch.nn.Conv2d, torch.nn.ReLU): torch.nn.intrinsic.ConvReLU2d, + (torch.nn.Linear, torch.nn.ReLU): torch.nn.intrinsic.LinearReLU } types = tuple(type(m) for m in mod_list) diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 83552d132f6d8..6aebb046d7f75 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -174,7 +174,7 @@ class MinMaxObserver(_ObserverBase): r"""Default Observer Module A default implementation of the observer module, only works for `per_tensor_affine` quantization scheme. The module will record the - running average of max and min value of the observed Tensor and + running average of max and min value of the observed Tensor and calculate_qparams will calculate scale and zero_point """ @@ -215,7 +215,7 @@ def forward(self, x_orig): max_val = torch.max(torch.max(x), max_val) self.min_val = min_val self.max_val = max_val - return x + return x_orig @torch.jit.export def calculate_qparams(self): @@ -238,6 +238,25 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, super(MinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) +class MovingAverageMinMaxObserver(MinMaxObserver): + def __init__(self, averaging_constant=0.01, **kwargs): + self.averaging_constant = averaging_constant + super(MovingAverageMinMaxObserver, self).__init__(**kwargs) + + def forward(self, x_orig): + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + if min_val is None or max_val is None: + min_val = torch.min(x) + max_val = torch.max(x) + else: + min_val = min_val + self.averaging_constant * (torch.min(x) - min_val) + max_val = max_val + self.averaging_constant * (torch.max(x) - max_val) + self.min_val = min_val + self.max_val = max_val + return x_orig + class PerChannelMinMaxObserver(_ObserverBase): r"""Per Channel Observer Module @@ -260,26 +279,26 @@ def __init__(self, ch_axis=0, **kwargs): "Cannot reduce range for symmetric quantization for quint8" ) - def forward(self, x): - with torch.no_grad(): - min_vals = self.min_vals - max_vals = self.max_vals - x_dim = x.size() - - new_axis_list = list(range(len(x_dim))) - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - y = x.permute(tuple(new_axis_list)) - y = torch.flatten(y, start_dim=1) - if min_vals is None or max_vals is None: - min_vals = torch.min(y, 1)[0] - max_vals = torch.max(y, 1)[0] - else: - min_vals = torch.min(torch.min(y, 1)[0], min_vals) - max_vals = torch.max(torch.max(y, 1)[0], max_vals) - self.min_vals = min_vals - self.max_vals = max_vals - return x + def forward(self, x_orig): + x = x_orig.detach() # avoid keeping autograd tape + 
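# Illustrative sketch (not part of this patch): the lookup fuse_known_modules performs
# with the OP_LIST_TO_FUSER_METHOD table above. The tuple of module types either picks
# a fuser function (the conv+bn variants) or an intrinsic module class that is called
# directly. The two-entry dict below only mirrors the table for the example.
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization.fuse_modules import fuse_conv_bn

op_to_fuser = {
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,
}
conv, bn = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)
fused = op_to_fuser[(type(conv), type(bn))](conv, bn)  # nni.ConvBn2d here, since conv is in training mode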
min_vals = self.min_vals + max_vals = self.max_vals + x_dim = x.size() + + new_axis_list = list(range(len(x_dim))) + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(tuple(new_axis_list)) + y = torch.flatten(y, start_dim=1) + if min_vals is None or max_vals is None: + min_vals = torch.min(y, 1)[0] + max_vals = torch.max(y, 1)[0] + else: + min_vals = torch.min(torch.min(y, 1)[0], min_vals) + max_vals = torch.max(torch.max(y, 1)[0], max_vals) + self.min_vals = min_vals + self.max_vals = max_vals + return x_orig def calculate_qparams(self): return self._calculate_per_channel_qparams(self.min_vals, self.max_vals) @@ -296,6 +315,38 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, super(PerChannelMinMaxObserver, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) +class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver): + r"""Per Channel Observer Module + The module will record the running average of max and min value for each + channel of the observed Tensor and calculate_qparams will calculate + scales and zero_points for each channel + """ + + def __init__(self, averaging_constant=0.01, **kwargs): + self.averaging_constant = averaging_constant + super(MovingAveragePerChannelMinMaxObserver, self).__init__(**kwargs) + + def forward(self, x_orig): + x = x_orig.detach() # avoid keeping autograd tape + min_vals = self.min_vals + max_vals = self.max_vals + x_dim = x.size() + + new_axis_list = list(range(len(x_dim))) + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(tuple(new_axis_list)) + y = torch.flatten(y, start_dim=1) + if min_vals is None or max_vals is None: + min_vals = torch.min(y, 1)[0] + max_vals = torch.max(y, 1)[0] + else: + min_vals = min_vals + self.averaging_constant * (torch.min(y, 1)[0] - min_vals) + max_vals = max_vals + self.averaging_constant * (torch.max(y, 1)[0] - max_vals) + self.min_vals = min_vals + self.max_vals = max_vals + return x_orig + class HistogramObserver(_ObserverBase): r""" The module records the running histogram of tensor values along with diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index f1b1828820a7c..ebcdedb19e8f5 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -1,52 +1,30 @@ from __future__ import absolute_import, division, print_function, unicode_literals import copy +import itertools +import warnings + import torch import torch.nn as nn -import torch.nn._intrinsic as nni -import torch.nn._intrinsic.quantized as nniq -import torch.nn._intrinsic.qat as nniqat +import torch.nn.intrinsic as nni import torch.nn.quantized as nnq -import torch.nn.quantized.dynamic as nnqd -from .QConfig import default_dynamic_qconfig, float16_dynamic_qconfig -import torch.nn.qat as nnqat -import warnings - -class QuantStub(nn.Module): - r"""Quantize stub module, before calibration, this is same as an observer, - it will be swapped as `nnq.Quantize` in `convert`. - - Args: - qconfig: quantization configuration for the tensor, - if qconfig is not provided, we will get qconfig from parent modules - """ - def __init__(self, qconfig=None): - super(QuantStub, self).__init__() - if qconfig: - self.qconfig = qconfig - def forward(self, x): - return x - -class DeQuantStub(nn.Module): - r"""Dequantize stub module, before calibration, this is same as identity, - this will be swapped as `nnq.DeQuantize` in `convert`. 
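# Illustrative sketch (not part of this patch): the exponential-moving-average update
# shared by the new MovingAverageMinMaxObserver and MovingAveragePerChannelMinMaxObserver
# above. With averaging constant c the running statistics follow
#     min_val <- min_val + c * (min(x) - min_val)
#     max_val <- max_val + c * (max(x) - max_val)
# The helper below (ema_update, a made-up name) tracks a scalar running max over a few batches.
import torch

def ema_update(running, current, c=0.01):
    return current if running is None else running + c * (current - running)

running_max = None
for _ in range(3):
    batch = torch.randn(1024)
    running_max = ema_update(running_max, torch.max(batch))
# running_max lags the per-batch maxima, smoothing out occasional outlier batches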
- """ - def __init__(self): - super(DeQuantStub, self).__init__() - - def forward(self, x): - return x - -DEFAULT_SKIP_LIST = [nn.Dropout, nn.Identity, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, DeQuantStub] +from .default_mappings import (DEFAULT_DYNAMIC_MODULE_MAPPING, + DEFAULT_MODULE_MAPPING, + DEFAULT_QAT_MODULE_MAPPING, + DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST) +from .stubs import DeQuantStub, QuantWrapper +from .QConfig import default_dynamic_qconfig, float16_dynamic_qconfig -def _propagate_qconfig_helper(module, qconfig_dict, skip_list=DEFAULT_SKIP_LIST, qconfig_parent=None, prefix=''): +def _propagate_qconfig_helper(module, qconfig_dict, white_list=None, + qconfig_parent=None, prefix=''): r"""This is a helper function for `propagate_qconfig_` Args: module: input module qconfig_dict: dictionary that maps from name of submodule to quantization configuration + white_list: list of quantizable modules qconfig_parent: quantization config of parent module, we will fallback to this config when there is no specified config for current module @@ -56,39 +34,32 @@ def _propagate_qconfig_helper(module, qconfig_dict, skip_list=DEFAULT_SKIP_LIST, Return: None, module is modified inplace with qconfig attached """ - if type(module) in skip_list: - module.qconfig = None - if not hasattr(module, 'qconfig'): - module.qconfig = qconfig_parent - if qconfig_dict: - if prefix in qconfig_dict: - module.qconfig = qconfig_dict[prefix] - elif type(module) in qconfig_dict: - module.qconfig = qconfig_dict[type(module)] - - # Don't quantize empty Sequential, empty Sequential is same as - # Identity, but we can't put Sequential into skip list because - # we also have non-empty Sequential and the qconfig needs to - # be propagated to child in that case # TODO: Add test - if len(module._modules) == 0 and type(module) == nn.Sequential: - module.qconfig = None + if white_list is None: + white_list = DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST + + module_qconfig = qconfig_dict.get(type(module), qconfig_parent) + module_qconfig = qconfig_dict.get(prefix, module_qconfig) + module_qconfig = getattr(module, 'qconfig', module_qconfig) + if type(module) in white_list: + module.qconfig = module_qconfig for name, child in module.named_children(): module_prefix = prefix + '.' + name if prefix else name - _propagate_qconfig_helper(child, qconfig_dict, skip_list, module.qconfig, module_prefix) + _propagate_qconfig_helper(child, qconfig_dict, white_list, + module_qconfig, module_prefix) -# TODO(jerryzh): expose skip_list +# TODO(jerryzh): expose white_list def propagate_qconfig_(module, qconfig_dict=None): r"""Propagate qconfig through the module hierarchy and assign `qconfig` attribute on each leaf module Args: module: input module - qconfig_dict: dictionary that maps from name or type of submodule to quantization - configuration, qconfig applies to all submodules of a given - module unless qconfig for the submodules are specified (when the - submodule already has qconfig attribute) + qconfig_dict: dictionary that maps from name or type of submodule to + quantization configuration, qconfig applies to all submodules of a + given module unless qconfig for the submodules are specified (when + the submodule already has qconfig attribute) Return: None, module is modified inplace with qconfig attached @@ -109,12 +80,10 @@ def add_observer_(module): has a valid qconfig attribute. 
Args: - module: input module with qconfig attributes for all the leaf modules - that we want to quantize + module: input module with qconfig attributes for all the leaf modules that we want to quantize Return: - None, module is modified inplace with added observer modules and - forward_hooks + None, module is modified inplace with added observer modules and forward_hooks """ for child in module.children(): if type(child) == nnq.FloatFunctional: @@ -131,30 +100,6 @@ def add_observer_(module): module.add_module('observer', module.qconfig.activation()) module.register_forward_hook(_observer_forward_hook) -class QuantWrapper(nn.Module): - r"""A wrapper class that wraps the input module, adds QuantStub and - DeQuantStub and surround the call to module with call to quant and dequant - modules. - - This is used by the `quantization` utility functions to add the quant and - dequant modules, before `convert` function `QuantStub` will just be observer, - it observes the input tensor, after `convert`, `QuantStub` - will be swapped to `nnq.Quantize` which does actual quantization. Similarly - for `DeQuantStub`. - """ - def __init__(self, module): - super(QuantWrapper, self).__init__() - qconfig = module.qconfig if hasattr(module, 'qconfig') else None - self.add_module('quant', QuantStub(qconfig)) - self.add_module('dequant', DeQuantStub()) - self.add_module('module', module) - self.train(module.training) - - def forward(self, X): - X = self.quant(X) - X = self.module(X) - return self.dequant(X) - def add_quant_dequant(module): r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig Note that this function will modify the children of module inplace and it @@ -205,44 +150,7 @@ def prepare(model, qconfig_dict=None, inplace=False): add_observer_(model) return model -# Map for swapping float module to quantized ones -DEFAULT_MODULE_MAPPING = { - nn.Linear: nnq.Linear, - nn.ReLU: nnq.ReLU, - nn.Conv2d: nnq.Conv2d, - QuantStub: nnq.Quantize, - DeQuantStub: nnq.DeQuantize, - # Wrapper Modules: - nnq.FloatFunctional: nnq.QFunctional, - # Intrinsic modules: - nni.ConvReLU2d: nniq.ConvReLU2d, - nni.LinearReLU: nniq.LinearReLU, - nniqat.ConvReLU2d: nniq.ConvReLU2d, - nniqat.LinearReLU: nniq.LinearReLU, - nniqat.ConvBn2d: nnq.Conv2d, - nniqat.ConvBnReLU2d: nniq.ConvReLU2d, - # QAT modules: - nnqat.Linear: nnq.Linear, - nnqat.Conv2d: nnq.Conv2d, -} - -# Map for swapping float module to qat modules -DEFAULT_QAT_MODULE_MAPPING = { - nn.Linear: nnqat.Linear, - nn.Conv2d: nnqat.Conv2d, - # Intrinsic modules: - nni.ConvBn2d: nniqat.ConvBn2d, - nni.ConvBnReLU2d: nniqat.ConvBnReLU2d, - nni.ConvReLU2d: nniqat.ConvReLU2d, - nni.LinearReLU: nniqat.LinearReLU -} - -DEFAULT_DYNAMIC_MODULE_MAPPING = { - nn.Linear: nnqd.Linear, - nn.LSTM: nnqd.LSTM, -} - -def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING, inplace=False): +def quantize(model, run_fn, run_args, mapping=None, inplace=False): r"""Converts a float model to quantized model. First it will prepare the model for calibration or training, then it calls @@ -261,7 +169,8 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING, inplace=Fa Return: Quantized model. 
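# Illustrative sketch (not part of this patch): the eager-mode post-training flow that
# quantize() bundles, as its docstring above describes -- prepare() attaches observers,
# a calibration pass feeds representative data (the role of run_fn), and convert()
# swaps in quantized modules. default_qconfig is assumed to be exported by
# torch.quantization; it is not shown in this patch.
import torch
import torch.nn as nn
from torch.quantization import QuantWrapper, prepare, convert, default_qconfig

float_model = QuantWrapper(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
float_model.qconfig = default_qconfig
float_model.eval()

model = prepare(float_model)              # insert observers on the leaf modules
with torch.no_grad():                     # calibration step
    for _ in range(4):
        model(torch.randn(1, 3, 32, 32))
qmodel = convert(model)                   # swap float modules for their quantized counterparts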
""" - + if mapping is None: + mapping = DEFAULT_MODULE_MAPPING if not inplace: model = copy.deepcopy(model) model.eval() @@ -270,7 +179,8 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING, inplace=Fa convert(model, mapping, inplace=True) return model -def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, inplace=False): +def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, + mapping=None, inplace=False): r"""Converts a float model to dynamic (i.e. weights-only) quantized model. Replaces specified modules with dynamic weight-only quantized versions and output the quantized model. @@ -278,28 +188,31 @@ def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAUL For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization by default is performed for layers with large weights size - i.e. Linear and RNN variants. - Fine grained control is possible with `qconfig_dict` and `mapping` that act similarly to `quantize()`. - If `qconfig_dict` is provided, the `dtype` argument is ignored. + Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`. + If `qconfig` is provided, the `dtype` argument is ignored. Args: module: input model - qconfig_dict: dictionary that maps from name or type of submodule to quantization - configuration, qconfig applies to all submodules of a given - module unless qconfig for the submodules are specified (when the - submodule already has qconfig attribute). Entries in the dictionary - need to be QConfigDynamic instances. + qconfig_spec: Either: + * A dictionary that maps from name or type of submodule to quantization + configuration, qconfig applies to all submodules of a given + module unless qconfig for the submodules are specified (when the + submodule already has qconfig attribute). Entries in the dictionary + need to be QConfigDynamic instances. + * A set of types and/or submodule names to apply dynamic quantization to, + in which case the `dtype` argument is used to specifiy the bit-width inplace: carry out model transformations in-place, the original module is mutated mapping: maps type of a submodule to a type of corresponding dynamically quantized version with which the submodule needs to be replaced """ - if qconfig_dict is None: + if qconfig_spec is None: if dtype == torch.qint8: - qconfig_dict = { + qconfig_spec = { nn.Linear : default_dynamic_qconfig, nn.LSTM : default_dynamic_qconfig, } elif dtype == torch.float16: - qconfig_dict = { + qconfig_spec = { # TODO: uncomment when float16 Linear support is added # nn.Linear : default_dynamic_qconfig, nn.LSTM : float16_dynamic_qconfig, @@ -307,11 +220,22 @@ def quantize_dynamic(model, qconfig_dict=None, dtype=torch.qint8, mapping=DEFAUL else: raise ValueError( "Don't know how to quantize with default settings for {}. 
Provide full qconfig please".format(dtype)) + elif isinstance(qconfig_spec, set): + if dtype is torch.qint8: + default_qconfig = default_dynamic_qconfig + elif dtype is torch.float16: + default_qconfig = float16_dynamic_qconfig + else: + raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype)) + qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig))) + + if mapping is None: + mapping = DEFAULT_DYNAMIC_MODULE_MAPPING if not inplace: model = copy.deepcopy(model) model.eval() - propagate_qconfig_(model, qconfig_dict) + propagate_qconfig_(model, qconfig_spec) convert(model, mapping, inplace=True) return model @@ -326,7 +250,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False): Args: model: input model run_fn: a function for evaluating the prepared model, can be a - function that simply runs the prepared model or a training loop + function that simply runs the prepared model or a training + loop run_args: positional arguments for `run_fn` Return: @@ -340,15 +265,17 @@ def quantize_qat(model, run_fn, run_args, inplace=False): convert(model, inplace=True) return model -def convert(module, mapping=DEFAULT_MODULE_MAPPING, inplace=False): +def convert(module, mapping=None, inplace=False): r"""Converts the float module with observers (where we can get quantization parameters) to a quantized module. Args: module: calibrated module with observers - mapping: a dictionary that maps from float module type to quantized - module type, can be overwrritten to allow swapping user defined Modules + mapping: a dictionary that maps from float module type to quantized module type, can + be overwrritten to allow swapping user defined Modules inplace: carry out model transformations in-place, the original module is mutated """ + if mapping is None: + mapping = DEFAULT_MODULE_MAPPING if not inplace: module = copy.deepcopy(module) reassign = {} diff --git a/torch/quantization/stubs.py b/torch/quantization/stubs.py new file mode 100644 index 0000000000000..7018ef1b836da --- /dev/null +++ b/torch/quantization/stubs.py @@ -0,0 +1,54 @@ + +from torch import nn + +class QuantStub(nn.Module): + r"""Quantize stub module, before calibration, this is same as an observer, + it will be swapped as `nnq.Quantize` in `convert`. + + Args: + qconfig: quantization configuration for the tensor, + if qconfig is not provided, we will get qconfig from parent modules + """ + def __init__(self, qconfig=None): + super(QuantStub, self).__init__() + if qconfig: + self.qconfig = qconfig + + def forward(self, x): + return x + + +class DeQuantStub(nn.Module): + r"""Dequantize stub module, before calibration, this is same as identity, + this will be swapped as `nnq.DeQuantize` in `convert`. + """ + def __init__(self): + super(DeQuantStub, self).__init__() + + def forward(self, x): + return x + + +class QuantWrapper(nn.Module): + r"""A wrapper class that wraps the input module, adds QuantStub and + DeQuantStub and surround the call to module with call to quant and dequant + modules. + + This is used by the `quantization` utility functions to add the quant and + dequant modules, before `convert` function `QuantStub` will just be observer, + it observes the input tensor, after `convert`, `QuantStub` + will be swapped to `nnq.Quantize` which does actual quantization. Similarly + for `DeQuantStub`. 
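# Illustrative sketch (not part of this patch): the new qconfig_spec-as-a-set form of
# quantize_dynamic() described above. Passing {nn.Linear} with the default
# dtype=torch.qint8 is expanded internally to {nn.Linear: default_dynamic_qconfig}, so
# only Linear layers are replaced by their dynamically quantized counterparts.
import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
qmodel = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
# the Linear submodules of qmodel are now torch.nn.quantized.dynamic.Linear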
+ """ + def __init__(self, module): + super(QuantWrapper, self).__init__() + qconfig = module.qconfig if hasattr(module, 'qconfig') else None + self.add_module('quant', QuantStub(qconfig)) + self.add_module('dequant', DeQuantStub()) + self.add_module('module', module) + self.train(module.training) + + def forward(self, X): + X = self.quant(X) + X = self.module(X) + return self.dequant(X) diff --git a/torch/serialization.py b/torch/serialization.py index 22ff981a3909d..ca5a43c7bd599 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -8,6 +8,7 @@ import tarfile import tempfile import warnings +import copyreg from contextlib import closing, contextmanager from ._utils import _import_dotted_name from ._six import string_classes as _string_classes @@ -429,6 +430,23 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args): f.close() +# Register pickling support for layout instances such as +# torch.sparse_coo, etc +def _get_layout(name): + """Get layout extension object from its string representation. + """ + cache = _get_layout.cache + if not cache: + for v in torch.__dict__.values(): + if isinstance(v, torch.layout): + cache[str(v)] = v + return cache[name] + + +_get_layout.cache = {} +copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),))) + + def _load(f, map_location, pickle_module, **pickle_load_args): deserialized_objects = {} diff --git a/torch/tensor.py b/torch/tensor.py index 6d10ce59264f4..4c45299700de1 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -80,6 +80,16 @@ def __reduce_ex__(self, proto): self.requires_grad, OrderedDict()) return (torch._utils._rebuild_qtensor, args) + elif self.is_sparse: + if self.layout == torch.sparse_coo: + args = (self.layout, + (self._indices(), + self._values(), + self.size())) + else: + raise NotImplementedError( + 'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout)) + return (torch._utils._rebuild_sparse_tensor, args) else: args = (self.storage(), self.storage_offset(), @@ -414,6 +424,8 @@ def __hash__(self): return id(self) def __dir__(self): + if self.is_quantized: + warnings.warn('Only a small subset of methods are supported for quantized tensors.') tensor_methods = dir(self.__class__) tensor_methods.remove('volatile') # deprecated attrs = list(self.__dict__.keys()) diff --git a/torch/utils/__init__.py b/torch/utils/__init__.py index 14ce713dfccfa..8912b870e6923 100644 --- a/torch/utils/__init__.py +++ b/torch/utils/__init__.py @@ -1,3 +1,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals from .throughput_benchmark import ThroughputBenchmark # noqa: F401 + +# Set the module for a given object for nicer printing +def set_module(obj, mod): + if not isinstance(mod, str): + raise TypeError("The mod argument should be a string") + obj.__module__ = mod diff --git a/tools/amd_build/pyHIPIFY/__init__.py b/torch/utils/hipify/__init__.py similarity index 100% rename from tools/amd_build/pyHIPIFY/__init__.py rename to torch/utils/hipify/__init__.py diff --git a/tools/amd_build/pyHIPIFY/constants.py b/torch/utils/hipify/constants.py similarity index 96% rename from tools/amd_build/pyHIPIFY/constants.py rename to torch/utils/hipify/constants.py index 1384e6c891008..b0bd4f9f0313b 100644 --- a/tools/amd_build/pyHIPIFY/constants.py +++ b/torch/utils/hipify/constants.py @@ -52,8 +52,9 @@ API_LAST = 42 API_FFT = 43 API_RTC = 44 +API_ROCTX = 45 -HIP_UNSUPPORTED = 43 +HIP_UNSUPPORTED = 46 API_PYTORCH = 1337 API_CAFFE2 = 1338 API_C10 = 1339 diff --git 
a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py new file mode 100644 index 0000000000000..54848b8888402 --- /dev/null +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -0,0 +1,8083 @@ +#!/usr/bin/env python3 +import collections + +from .constants import * + +""" Mapping of CUDA functions, include files, constants, and types to ROCm/HIP equivalents +This closely follows the implementation in hipify-clang +https://github.com/ROCm-Developer-Tools/HIP/blob/master/hipify-clang/src/CUDA2HipMap.cpp +and its structure. +There are different maps for fundamental names, include files, identifies, sparse, and +PyTorch specific translations. +Each of the entries in these maps translates a CUDA string to a tuple containing the +ROCm/HIP string, a type and API annotation and - optionally - an annotation if it is not +supported in ROCm/HIP yet. +""" + +# List of math functions that should be replaced inside device code only. +MATH_TRANSPILATIONS = collections.OrderedDict( + [ + ("std::max", ("::max")), + ("std::min", ("::min")), + ("std::ceil", ("::ceil")), + ("std::floor", ("::floor")), + ("std::exp", ("::exp")), + ("std::log", ("::log")), + ("std::pow", ("::pow")), + ("std::fabs", ("::fabs")), + ("std::fmod", ("::fmod")), + ("std::remainder", ("::remainder")), + ] +) + +CUDA_TYPE_NAME_MAP = collections.OrderedDict( + [ + ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), + ("cudaError_t", ("hipError_t", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ARRAY3D_DESCRIPTOR", + ("HIP_ARRAY3D_DESCRIPTOR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_ARRAY_DESCRIPTOR", ("HIP_ARRAY_DESCRIPTOR", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY2D", ("hip_Memcpy2D", CONV_TYPE, API_DRIVER)), + ("CUDA_MEMCPY3D", ("HIP_MEMCPY3D", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_MEMCPY3D_PEER", + ("HIP_MEMCPY3D_PEER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_POINTER_ATTRIBUTE_P2P_TOKENS", + ( + "HIP_POINTER_ATTRIBUTE_P2P_TOKENS", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_RESOURCE_DESC", + ("HIP_RESOURCE_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_RESOURCE_VIEW_DESC", + ("HIP_RESOURCE_VIEW_DESC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUipcEventHandle", + ("hipIpcEventHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUipcMemHandle", ("hipIpcMemHandle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUaddress_mode", ("hipAddress_mode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUarray_cubemap_face", + ("hipArray_cubemap_face", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUarray_format", ("hipArray_format", CONV_TYPE, API_DRIVER)), + ("CUcomputemode", ("hipComputemode", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmem_advise", ("hipMemAdvise", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUmem_range_attribute", + ("hipMemRangeAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUctx_flags", ("hipCctx_flags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUdevice", ("hipDevice_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute_enum", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUdevice_attribute", ("hipDeviceAttribute_t", CONV_TYPE, API_DRIVER)), + ("CUdeviceptr", ("hipDeviceptr_t", CONV_TYPE, API_DRIVER)), + ("CUarray_st", ("hipArray", CONV_TYPE, API_DRIVER)), + ("CUarray", ("hipArray *", CONV_TYPE, API_DRIVER)), + ("CUdevprop_st", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUdevprop", ("hipDeviceProp_t", CONV_TYPE, API_DRIVER)), + ("CUfunction", 
("hipFunction_t", CONV_TYPE, API_DRIVER)), + ( + "CUgraphicsResource", + ("hipGraphicsResource_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmipmappedArray", + ("hipMipmappedArray_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUfunction_attribute_enum", + ("hipFuncAttribute_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsMapResourceFlags_enum", + ("hipGraphicsMapFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUgraphicsRegisterFlags_enum", + ("hipGraphicsRegisterFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUoccupancy_flags_enum", + ("hipOccupancyFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUfunc_cache_enum", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUfunc_cache", ("hipFuncCache", CONV_TYPE, API_DRIVER)), + ("CUipcMem_flags", ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUipcMem_flags_enum", + ("hipIpcMemFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_cacheMode", ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_cacheMode_enum", + ("hipJitCacheMode", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_fallback", ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjit_fallback_enum", + ("hipJitFallback", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUjit_option", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_option_enum", ("hipJitOption", CONV_JIT, API_DRIVER)), + ("CUjit_target", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjit_target_enum", ("hipJitTarget", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUjitInputType", ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUjitInputType_enum", + ("hipJitInputType", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUlimit", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ("CUlimit_enum", ("hipLimit_t", CONV_TYPE, API_DRIVER)), + ( + "CUmemAttach_flags", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUmemAttach_flags_enum", + ("hipMemAttachFlags_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUmemorytype", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUmemorytype_enum", ("hipMemType_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ("CUresourcetype", ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUresourcetype_enum", + ("hipResourceType", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUresourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUresourceViewFormat_enum", ("hipResourceViewFormat", CONV_TEX, API_DRIVER)), + ("CUsharedconfig", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUsharedconfig_enum", ("hipSharedMemConfig", CONV_TYPE, API_DRIVER)), + ("CUcontext", ("hipCtx_t", CONV_TYPE, API_DRIVER)), + ("CUmodule", ("hipModule_t", CONV_TYPE, API_DRIVER)), + ("CUstream", ("hipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstream_st", ("ihipStream_t", CONV_TYPE, API_DRIVER)), + ("CUstreamCallback", ("hipStreamCallback_t", CONV_TYPE, API_DRIVER)), + ("CUsurfObject", ("hipSurfaceObject", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), 
+ ( + "CUsurfref", + ("hipSurfaceReference_t", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUtexObject", ("hipTextureObject_t", CONV_TYPE, API_DRIVER)), + ("CUtexref", ("textureReference", CONV_TYPE, API_DRIVER)), + ("CUstream_flags", ("hipStreamFlags", CONV_TYPE, API_DRIVER)), + ( + "CUstreamWaitValue_flags", + ("hipStreamWaitValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamWriteValue_flags", + ("hipStreamWriteValueFlags", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUstreamBatchMemOpType", + ("hipStreamBatchMemOpType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUdevice_P2PAttribute", + ("hipDeviceP2PAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUevent", ("hipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_st", ("ihipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_flags", ("hipEventFlags", CONV_EVENT, API_DRIVER, HIP_UNSUPPORTED)), + ("CUfilter_mode", ("hipTextureFilterMode", CONV_TEX, API_DRIVER)), + ("CUGLDeviceList", ("hipGLDeviceList", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("CUGLmap_flags", ("hipGLMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUd3d9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9map_flags", + ("hipD3D9MapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d9register_flags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10map_flags", + ("hipD3D10MapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d10register_flags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUd3d11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection_st", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUeglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType_t", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "libraryPropertyType", + ("hipLibraryPropertyType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamCallback_t", ("hipStreamCallback_t", CONV_TYPE, API_RUNTIME)), + ("cudaArray", ("hipArray", CONV_MEM, API_RUNTIME)), + ("cudaArray_t", ("hipArray_t", CONV_MEM, API_RUNTIME)), + ("cudaArray_const_t", ("hipArray_const_t", CONV_MEM, API_RUNTIME)), + ("cudaMipmappedArray_t", ("hipMipmappedArray_t", CONV_MEM, API_RUNTIME)), + ( + "cudaMipmappedArray_const_t", + ("hipMipmappedArray_const_t", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayDefault", ("hipArrayDefault", CONV_MEM, API_RUNTIME)), + ("cudaArrayLayered", ("hipArrayLayered", CONV_MEM, API_RUNTIME)), + ( + "cudaArraySurfaceLoadStore", + ("hipArraySurfaceLoadStore", CONV_MEM, API_RUNTIME), + ), + ("cudaArrayCubemap", ("hipArrayCubemap", CONV_MEM, API_RUNTIME)), + ("cudaArrayTextureGather", ("hipArrayTextureGather", CONV_MEM, API_RUNTIME)), + ("cudaMemoryAdvise", ("hipMemAdvise", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeAttribute", + ("hipMemRangeAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyKind", ("hipMemcpyKind", CONV_MEM, API_RUNTIME)), + ("cudaMemoryType", ("hipMemoryType", CONV_MEM, API_RUNTIME)), + ("cudaExtent", ("hipExtent", CONV_MEM, API_RUNTIME)), + ("cudaPitchedPtr", ("hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("cudaPos", ("hipPos", CONV_MEM, 
API_RUNTIME)), + ("cudaEvent_t", ("hipEvent_t", CONV_TYPE, API_RUNTIME)), + ("cudaStream_t", ("hipStream_t", CONV_TYPE, API_RUNTIME)), + ("cudaPointerAttributes", ("hipPointerAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceAttr", ("hipDeviceAttribute_t", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceProp", ("hipDeviceProp_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceP2PAttr", + ("hipDeviceP2PAttribute", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeMode", + ("hipComputeMode", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaFuncCache", ("hipFuncCache_t", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncAttributes", + ("hipFuncAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSharedMemConfig", ("hipSharedMemConfig", CONV_TYPE, API_RUNTIME)), + ("cudaLimit", ("hipLimit_t", CONV_TYPE, API_RUNTIME)), + ("cudaOutputMode", ("hipOutputMode", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaTextureReadMode", ("hipTextureReadMode", CONV_TEX, API_RUNTIME)), + ("cudaTextureFilterMode", ("hipTextureFilterMode", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatKind", ("hipChannelFormatKind", CONV_TEX, API_RUNTIME)), + ("cudaChannelFormatDesc", ("hipChannelFormatDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceDesc", ("hipResourceDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewDesc", ("hipResourceViewDesc", CONV_TEX, API_RUNTIME)), + ("cudaTextureDesc", ("hipTextureDesc", CONV_TEX, API_RUNTIME)), + ( + "surfaceReference", + ("hipSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureObject_t", ("hipTextureObject_t", CONV_TEX, API_RUNTIME)), + ("cudaResourceType", ("hipResourceType", CONV_TEX, API_RUNTIME)), + ("cudaResourceViewFormat", ("hipResourceViewFormat", CONV_TEX, API_RUNTIME)), + ("cudaTextureAddressMode", ("hipTextureAddressMode", CONV_TEX, API_RUNTIME)), + ( + "cudaSurfaceBoundaryMode", + ("hipSurfaceBoundaryMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSurfaceFormatMode", + ("hipSurfaceFormatMode", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaTextureType1D", ("hipTextureType1D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType2D", ("hipTextureType2D", CONV_TEX, API_RUNTIME)), + ("cudaTextureType3D", ("hipTextureType3D", CONV_TEX, API_RUNTIME)), + ("cudaTextureTypeCubemap", ("hipTextureTypeCubemap", CONV_TEX, API_RUNTIME)), + ( + "cudaTextureType1DLayered", + ("hipTextureType1DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureType2DLayered", + ("hipTextureType2DLayered", CONV_TEX, API_RUNTIME), + ), + ( + "cudaTextureTypeCubemapLayered", + ("hipTextureTypeCubemapLayered", CONV_TEX, API_RUNTIME), + ), + ("cudaIpcEventHandle_t", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcEventHandle_st", ("hipIpcEventHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_t", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ("cudaIpcMemHandle_st", ("hipIpcMemHandle_t", CONV_TYPE, API_RUNTIME)), + ( + "cudaGraphicsCubeFace", + ("hipGraphicsCubeFace", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlags", + ("hipGraphicsMapFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsRegisterFlags", + ("hipGraphicsRegisterFlags", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceList", + ("hipGLDeviceList", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaGLMapFlags", ("hipGLMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaD3D9DeviceList", + ("hipD3D9DeviceList", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), 
+ ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlags", + ("hipD3D9RegisterFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceList", + ("hipd3d10DeviceList", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapFlags", + ("hipD3D10MapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlags", + ("hipD3D10RegisterFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceList", + ("hipd3d11DeviceList", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEglStreamConnection", + ("hipEglStreamConnection", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cublasHandle_t", ("rocblas_handle", CONV_TYPE, API_BLAS)), + ("cublasOperation_t", ("rocblas_operation", CONV_TYPE, API_BLAS)), + ("cublasStatus_t", ("rocblas_status", CONV_TYPE, API_BLAS)), + ("cublasFillMode_t", ("rocblas_fill", CONV_TYPE, API_BLAS)), + ("cublasDiagType_t", ("rocblas_diagonal", CONV_TYPE, API_BLAS)), + ("cublasSideMode_t", ("rocblas_side", CONV_TYPE, API_BLAS)), + ("cublasPointerMode_t", ("rocblas_pointer_mode", CONV_TYPE, API_BLAS)), + ( + "cublasAtomicsMode_t", + ("rocblas_atomics_mode", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDataType_t", + ("rocblas_data_type", CONV_TYPE, API_BLAS, HIP_UNSUPPORTED), + ), + ("curandStatus", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandStatus_t", ("hiprandStatus_t", CONV_TYPE, API_RAND)), + ("curandRngType", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandRngType_t", ("hiprandRngType_t", CONV_TYPE, API_RAND)), + ("curandGenerator_st", ("hiprandGenerator_st", CONV_TYPE, API_RAND)), + ("curandGenerator_t", ("hiprandGenerator_t", CONV_TYPE, API_RAND)), + ( + "curandDirectionVectorSet", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDirectionVectorSet_t", + ("hiprandDirectionVectorSet_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandOrdering", ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandOrdering_t", + ("hiprandOrdering_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_st", + ("hiprandDistribution_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistribution_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2V_t", + ("hiprandDistribution_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_st", + ("hiprandDistributionShift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionShift_t", + ("hiprandDistributionShift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_st", + ("hiprandDistributionM2Shift_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDistributionM2Shift_t", + ("hiprandDistributionM2Shift_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_st", + ("hiprandHistogramM2_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2_t", + ("hiprandHistogramM2_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_st", + ("hiprandHistogramM2K_st", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandHistogramM2K_t", + ("hiprandHistogramM2K_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandDiscreteDistribution_st", + ("hiprandDiscreteDistribution_st", CONV_TYPE, API_RAND), + 
), + ( + "curandDiscreteDistribution_t", + ("hiprandDiscreteDistribution_t", CONV_TYPE, API_RAND), + ), + ("curandMethod", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ("curandMethod_t", ("hiprandMethod_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED)), + ( + "curandDirectionVectors32_t", + ("hiprandDirectionVectors32_t", CONV_TYPE, API_RAND), + ), + ( + "curandDirectionVectors64_t", + ("hiprandDirectionVectors64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateMtgp32_t", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ("curandStateMtgp32", ("hiprandStateMtgp32_t", CONV_TYPE, API_RAND)), + ( + "curandStateScrambledSobol64_t", + ("hiprandStateScrambledSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateSobol64_t", + ("hiprandStateSobol64_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandStateScrambledSobol32_t", + ("hiprandStateScrambledSobol32_t", CONV_TYPE, API_RAND, HIP_UNSUPPORTED), + ), + ("curandStateSobol32_t", ("hiprandStateSobol32_t", CONV_TYPE, API_RAND)), + ("curandStateMRG32k3a_t", ("hiprandStateMRG32k3a_t", CONV_TYPE, API_RAND)), + ( + "curandStatePhilox4_32_10_t", + ("hiprandStatePhilox4_32_10_t", CONV_TYPE, API_RAND), + ), + ("curandStateXORWOW_t", ("hiprandStateXORWOW_t", CONV_TYPE, API_RAND)), + ("curandState_t", ("hiprandState_t", CONV_TYPE, API_RAND)), + ("curandState", ("hiprandState_t", CONV_TYPE, API_RAND)), + ] +) + +CUDA_INCLUDE_MAP = collections.OrderedDict( + [ + # since pytorch uses "\b{pattern}\b" as the actual re pattern, + # patterns listed here have to begin and end with alnum chars + ( + "include " to differentiate + ("", ("", CONV_INCLUDE, API_RUNTIME)), + ("nvrtc.h", ("hip/hiprtc.h", CONV_INCLUDE, API_RTC)), + ("thrust/system/cuda", ("thrust/system/hip", CONV_INCLUDE, API_BLAS)), + ("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/cub.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/block/block_load.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("cub/device/device_scan.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), + ("nvToolsExt.h", ("roctx.h", CONV_INCLUDE, API_ROCTX)), + ] +) + +CUDA_IDENTIFIER_MAP = collections.OrderedDict( + [ + ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), + ( + "CUDA_ERROR_INVALID_CONTEXT", + ("hipErrorInvalidContext", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", + ("hipErrorContextAlreadyCurrent", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_ARRAY_IS_MAPPED", + ("hipErrorArrayIsMapped", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_ALREADY_MAPPED", ("hipErrorAlreadyMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_ALREADY_ACQUIRED", + ("hipErrorAlreadyAcquired", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_NOT_MAPPED", ("hipErrorNotMapped", CONV_TYPE, API_DRIVER)), + ( + "CUDA_ERROR_NOT_MAPPED_AS_ARRAY", + ("hipErrorNotMappedAsArray", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_NOT_MAPPED_AS_POINTER", + ("hipErrorNotMappedAsPointer", CONV_TYPE, API_DRIVER), + ), + ( + "CUDA_ERROR_CONTEXT_ALREADY_IN_USE", + ("hipErrorContextAlreadyInUse", CONV_TYPE, API_DRIVER), + ), + ("CUDA_ERROR_INVALID_SOURCE", ("hipErrorInvalidSource", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_FILE_NOT_FOUND", ("hipErrorFileNotFound", CONV_TYPE, API_DRIVER)), + ("CUDA_ERROR_NOT_FOUND", ("hipErrorNotFound", CONV_TYPE, API_DRIVER)), + ( + 
"CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING", + ( + "hipErrorLaunchIncompatibleTexturing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE", + ("hipErrorPrimaryContextActive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_CONTEXT_IS_DESTROYED", + ("hipErrorContextIsDestroyed", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_PERMITTED", + ("hipErrorNotPermitted", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NOT_SUPPORTED", + ("hipErrorNotSupported", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMissingConfiguration", + ("hipErrorMissingConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorPriorLaunchFailure", + ("hipErrorPriorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDeviceFunction", + ("hipErrorInvalidDeviceFunction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidConfiguration", + ("hipErrorInvalidConfiguration", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPitchValue", + ("hipErrorInvalidPitchValue", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidSymbol", + ("hipErrorInvalidSymbol", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidHostPointer", + ("hipErrorInvalidHostPointer", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidDevicePointer", + ("hipErrorInvalidDevicePointer", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaErrorInvalidTexture", + ("hipErrorInvalidTexture", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidTextureBinding", + ("hipErrorInvalidTextureBinding", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidChannelDescriptor", + ( + "hipErrorInvalidChannelDescriptor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorInvalidMemcpyDirection", + ("hipErrorInvalidMemcpyDirection", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAddressOfConstant", + ("hipErrorAddressOfConstant", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureFetchFailed", + ("hipErrorTextureFetchFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTextureNotBound", + ("hipErrorTextureNotBound", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSynchronizationError", + ("hipErrorSynchronizationError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidFilterSetting", + ("hipErrorInvalidFilterSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidNormSetting", + ("hipErrorInvalidNormSetting", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMixedDeviceExecution", + ("hipErrorMixedDeviceExecution", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotYetImplemented", + ("hipErrorNotYetImplemented", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMemoryValueTooLarge", + ("hipErrorMemoryValueTooLarge", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInsufficientDriver", + ("hipErrorInsufficientDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSetOnActiveProcess", + ("hipErrorSetOnActiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidSurface", + ("hipErrorInvalidSurface", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateVariableName", + ("hipErrorDuplicateVariableName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( 
+ "cudaErrorDuplicateTextureName", + ("hipErrorDuplicateTextureName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDuplicateSurfaceName", + ("hipErrorDuplicateSurfaceName", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorDevicesUnavailable", + ("hipErrorDevicesUnavailable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIncompatibleDriverContext", + ( + "hipErrorIncompatibleDriverContext", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorDeviceAlreadyInUse", + ("hipErrorDeviceAlreadyInUse", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchMaxDepthExceeded", + ("hipErrorLaunchMaxDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedTex", + ("hipErrorLaunchFileScopedTex", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFileScopedSurf", + ("hipErrorLaunchFileScopedSurf", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorSyncDepthExceeded", + ("hipErrorSyncDepthExceeded", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchPendingCountExceeded", + ( + "hipErrorLaunchPendingCountExceeded", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaErrorNotPermitted", + ("hipErrorNotPermitted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNotSupported", + ("hipErrorNotSupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorStartupFailure", + ("hipErrorStartupFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaErrorApiFailureBase", + ("hipErrorApiFailureBase", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_SUCCESS", ("hipSuccess", CONV_TYPE, API_DRIVER)), + ("cudaSuccess", ("hipSuccess", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_VALUE", ("hipErrorInvalidValue", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidValue", ("hipErrorInvalidValue", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_OUT_OF_MEMORY", + ("hipErrorMemoryAllocation", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorMemoryAllocation", + ("hipErrorMemoryAllocation", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_NOT_INITIALIZED", + ("hipErrorNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInitializationError", + ("hipErrorInitializationError", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_DEINITIALIZED", ("hipErrorDeinitialized", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorCudartUnloading", + ("hipErrorDeinitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_DISABLED", + ("hipErrorProfilerDisabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerDisabled", + ("hipErrorProfilerDisabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_NOT_INITIALIZED", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerNotInitialized", + ("hipErrorProfilerNotInitialized", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STARTED", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStarted", + ("hipErrorProfilerAlreadyStarted", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PROFILER_ALREADY_STOPPED", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorProfilerAlreadyStopped", + ("hipErrorProfilerAlreadyStopped", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_NO_DEVICE", ("hipErrorNoDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorNoDevice", 
("hipErrorNoDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_DEVICE", ("hipErrorInvalidDevice", CONV_TYPE, API_DRIVER)), + ("cudaErrorInvalidDevice", ("hipErrorInvalidDevice", CONV_TYPE, API_RUNTIME)), + ("CUDA_ERROR_INVALID_IMAGE", ("hipErrorInvalidImage", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorInvalidKernelImage", + ("hipErrorInvalidImage", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_MAP_FAILED", ("hipErrorMapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorMapBufferObjectFailed", + ("hipErrorMapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("CUDA_ERROR_UNMAP_FAILED", ("hipErrorUnmapFailed", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorUnmapBufferObjectFailed", + ("hipErrorUnmapFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NO_BINARY_FOR_GPU", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorNoKernelImageForDevice", + ("hipErrorNoBinaryForGpu", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ECC_UNCORRECTABLE", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorECCUncorrectable", + ("hipErrorECCNotCorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNSUPPORTED_LIMIT", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorUnsupportedLimit", + ("hipErrorUnsupportedLimit", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_UNSUPPORTED", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessUnsupported", + ("hipErrorPeerAccessUnsupported", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PTX", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidPtx", + ("hipErrorInvalidKernelFile", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_GRAPHICS_CONTEXT", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidGraphicsContext", + ("hipErrorInvalidGraphicsContext", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_NVLINK_UNCORRECTABLE", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorNvlinkUncorrectable", + ("hipErrorNvlinkUncorrectable", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND", + ("hipErrorSharedObjectSymbolNotFound", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectSymbolNotFound", + ( + "hipErrorSharedObjectSymbolNotFound", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "CUDA_ERROR_SHARED_OBJECT_INIT_FAILED", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorSharedObjectInitFailed", + ("hipErrorSharedObjectInitFailed", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_OPERATING_SYSTEM", + ("hipErrorOperatingSystem", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorOperatingSystem", + ("hipErrorOperatingSystem", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_HANDLE", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorInvalidResourceHandle", + ("hipErrorInvalidResourceHandle", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_NOT_READY", ("hipErrorNotReady", CONV_TYPE, API_DRIVER)), + ("cudaErrorNotReady", ("hipErrorNotReady", CONV_TYPE, API_RUNTIME)), + ( + "CUDA_ERROR_ILLEGAL_ADDRESS", + ("hipErrorIllegalAddress", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorIllegalAddress", + ("hipErrorIllegalAddress", 
CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorLaunchOutOfResources", + ("hipErrorLaunchOutOfResources", CONV_TYPE, API_RUNTIME), + ), + ("CUDA_ERROR_LAUNCH_TIMEOUT", ("hipErrorLaunchTimeOut", CONV_TYPE, API_DRIVER)), + ( + "cudaErrorLaunchTimeout", + ("hipErrorLaunchTimeOut", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessAlreadyEnabled", + ("hipErrorPeerAccessAlreadyEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_PEER_ACCESS_NOT_ENABLED", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorPeerAccessNotEnabled", + ("hipErrorPeerAccessNotEnabled", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_ASSERT", + ("hipErrorAssert", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorAssert", + ("hipErrorAssert", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_TOO_MANY_PEERS", + ("hipErrorTooManyPeers", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorTooManyPeers", + ("hipErrorTooManyPeers", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryAlreadyRegistered", + ("hipErrorHostMemoryAlreadyRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_DRIVER), + ), + ( + "cudaErrorHostMemoryNotRegistered", + ("hipErrorHostMemoryNotRegistered", CONV_TYPE, API_RUNTIME), + ), + ( + "CUDA_ERROR_HARDWARE_STACK_ERROR", + ("hipErrorHardwareStackError", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorHardwareStackError", + ("hipErrorHardwareStackError", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_ILLEGAL_INSTRUCTION", + ("hipErrorIllegalInstruction", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorIllegalInstruction", + ("hipErrorIllegalInstruction", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_MISALIGNED_ADDRESS", + ("hipErrorMisalignedAddress", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorMisalignedAddress", + ("hipErrorMisalignedAddress", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_ADDRESS_SPACE", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidAddressSpace", + ("hipErrorInvalidAddressSpace", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_INVALID_PC", + ("hipErrorInvalidPc", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorInvalidPc", + ("hipErrorInvalidPc", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_LAUNCH_FAILED", + ("hipErrorLaunchFailure", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cudaErrorLaunchFailure", + ("hipErrorLaunchFailure", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "CUDA_ERROR_UNKNOWN", + ("hipErrorUnknown", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cudaErrorUnknown", ("hipErrorUnknown", CONV_TYPE, API_RUNTIME)), + ( + "CU_TR_ADDRESS_MODE_WRAP", + ("HIP_TR_ADDRESS_MODE_WRAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_CLAMP", + ("HIP_TR_ADDRESS_MODE_CLAMP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_MIRROR", + 
("HIP_TR_ADDRESS_MODE_MIRROR", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TR_ADDRESS_MODE_BORDER", + ("HIP_TR_ADDRESS_MODE_BORDER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_X", + ("HIP_CUBEMAP_FACE_POSITIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_X", + ("HIP_CUBEMAP_FACE_NEGATIVE_X", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Y", + ("HIP_CUBEMAP_FACE_POSITIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Y", + ("HIP_CUBEMAP_FACE_NEGATIVE_Y", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_POSITIVE_Z", + ("HIP_CUBEMAP_FACE_POSITIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CUBEMAP_FACE_NEGATIVE_Z", + ("HIP_CUBEMAP_FACE_NEGATIVE_Z", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT8", + ("HIP_AD_FORMAT_UNSIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT16", + ("HIP_AD_FORMAT_UNSIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_UNSIGNED_INT32", + ("HIP_AD_FORMAT_UNSIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT8", + ("HIP_AD_FORMAT_SIGNED_INT8", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT16", + ("HIP_AD_FORMAT_SIGNED_INT16", CONV_TYPE, API_DRIVER), + ), + ( + "CU_AD_FORMAT_SIGNED_INT32", + ("HIP_AD_FORMAT_SIGNED_INT32", CONV_TYPE, API_DRIVER), + ), + ("CU_AD_FORMAT_HALF", ("HIP_AD_FORMAT_HALF", CONV_TYPE, API_DRIVER)), + ("CU_AD_FORMAT_FLOAT", ("HIP_AD_FORMAT_FLOAT", CONV_TYPE, API_DRIVER)), + ( + "CU_COMPUTEMODE_DEFAULT", + ("hipComputeModeDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE", + ("hipComputeModeExclusive", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_PROHIBITED", + ("hipComputeModeProhibited", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_COMPUTEMODE_EXCLUSIVE_PROCESS", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_READ_MOSTLY", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_READ_MOSTLY", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_SET_PREFERRED_LOCATION", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_ADVISE_SET_ACCESSED_BY", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ADVISE_UNSET_ACCESSED_BY", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_CTX_SCHED_AUTO", + ("HIP_CTX_SCHED_AUTO", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_SPIN", + 
("HIP_CTX_SCHED_SPIN", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_YIELD", + ("HIP_CTX_SCHED_YIELD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_BLOCKING_SYNC", + ("HIP_CTX_SCHED_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_BLOCKING_SYNC", + ("HIP_CTX_BLOCKING_SYNC", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_SCHED_MASK", + ("HIP_CTX_SCHED_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_MAP_HOST", + ("HIP_CTX_MAP_HOST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_LMEM_RESIZE_TO_MAX", + ("HIP_CTX_LMEM_RESIZE_TO_MAX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_CTX_FLAGS_MASK", + ("HIP_CTX_FLAGS_MASK", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_POINTER", + ("HIP_LAUNCH_PARAM_BUFFER_POINTER", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LAUNCH_PARAM_BUFFER_SIZE", + ("HIP_LAUNCH_PARAM_BUFFER_SIZE", CONV_TYPE, API_DRIVER), + ), + ("CU_LAUNCH_PARAM_END", ("HIP_LAUNCH_PARAM_END", CONV_TYPE, API_DRIVER)), + ( + "CU_IPC_HANDLE_SIZE", + ("HIP_LAUNCH_PARAM_END", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_DEVICEMAP", + ("HIP_MEMHOSTALLOC_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_PORTABLE", + ("HIP_MEMHOSTALLOC_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTALLOC_WRITECOMBINED", + ("HIP_MEMHOSTALLOC_WRITECOMBINED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_DEVICEMAP", + ("HIP_MEMHOSTREGISTER_DEVICEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_IOMEMORY", + ("HIP_MEMHOSTREGISTER_IOMEMORY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMHOSTREGISTER_PORTABLE", + ("HIP_MEMHOSTREGISTER_PORTABLE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PARAM_TR_DEFAULT", + ("HIP_PARAM_TR_DEFAULT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_LEGACY", + ("HIP_STREAM_LEGACY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_PER_THREAD", + ("HIP_STREAM_PER_THREAD", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSA_OVERRIDE_FORMAT", + ("HIP_TRSA_OVERRIDE_FORMAT", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_NORMALIZED_COORDINATES", + ("HIP_TRSF_NORMALIZED_COORDINATES", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TRSF_READ_AS_INTEGER", + ("HIP_TRSF_READ_AS_INTEGER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TRSF_SRGB", ("HIP_TRSF_SRGB", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CUDA_ARRAY3D_2DARRAY", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_CUBEMAP", + ("HIP_ARRAY3D_CUBEMAP", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_DEPTH_TEXTURE", + ("HIP_ARRAY3D_DEPTH_TEXTURE", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_LAYERED", + ("HIP_ARRAY3D_LAYERED", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_SURFACE_LDST", + ("HIP_ARRAY3D_SURFACE_LDST", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CUDA_ARRAY3D_TEXTURE_GATHER", + ("HIP_ARRAY3D_TEXTURE_GATHER", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipDeviceAttributeMaxThreadsPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK", + ( + "hipDeviceAttributeMaxSharedMemoryPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY", + ( + "hipDeviceAttributeTotalConstantMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_WARP_SIZE", + ("hipDeviceAttributeWarpSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_PITCH", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK", + ( + "hipDeviceAttributeMaxRegistersPerBlock", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CLOCK_RATE", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GPU_OVERLAP", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT", + ( + "hipDeviceAttributeMultiprocessorCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_INTEGRATED", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_MODE", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + 
"CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ECC_ENABLED", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_BUS_ID", + ("hipDeviceAttributePciBusId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_TCC_DRIVER", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE", + ( + "hipDeviceAttributeMemoryClockRate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT", + ( + "hipDeviceAttributeAsyncEngineCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER", + ( + "hipDeviceAttributeCanTex2DGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + 
"CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS", + ( + "hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS", + ( + 
"hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY", + ("hipDeviceAttributeManagedMemory", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_DRIVER), + ), + ( + "CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID", + ( + "hipDeviceAttributeMultiGpuBoardGroupId", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED", + ( + 
"hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_ATTRIBUTE_MAX", + ("hipDeviceAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_CONTEXT", + ("hipPointerAttributeContext", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_MEMORY_TYPE", + ("hipPointerAttributeMemoryType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_DEVICE_POINTER", + ( + "hipPointerAttributeDevicePointer", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_POINTER_ATTRIBUTE_HOST_POINTER", + ("hipPointerAttributeHostPointer", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_P2P_TOKENS", + ("hipPointerAttributeP2pTokens", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_SYNC_MEMOPS", + ("hipPointerAttributeSyncMemops", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_BUFFER_ID", + ("hipPointerAttributeBufferId", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_POINTER_ATTRIBUTE_IS_MANAGED", + ("hipPointerAttributeIsManaged", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK", + ( + "hipFuncAttributeMaxThreadsPerBlocks", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES", + ("hipFuncAttributeSharedSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES", + ("hipFuncAttributeConstSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES", + ("hipFuncAttributeLocalSizeBytes", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_NUM_REGS", + ("hipFuncAttributeNumRegs", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_PTX_VERSION", + ("hipFuncAttributePtxVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_BINARY_VERSION", + ("hipFuncAttributeBinaryVersion", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_CACHE_MODE_CA", + ("hipFuncAttributeCacheModeCA", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_FUNC_ATTRIBUTE_MAX", + ("hipFuncAttributeMax", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE", + ("hipGraphicsMapFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY", + ("hipGraphicsMapFlagsReadOnly", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ("hipGraphicsMapFlagsWriteDiscard", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_NONE", + ("hipGraphicsRegisterFlagsNone", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER", + ( + 
"hipGraphicsRegisterFlagsTextureGather", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_OCCUPANCY_DEFAULT", + ("hipOccupancyDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_FUNC_CACHE_PREFER_NONE", + ("hipFuncCachePreferNone", CONV_CACHE, API_DRIVER), + ), + ( + "CU_FUNC_CACHE_PREFER_SHARED", + ("hipFuncCachePreferShared", CONV_CACHE, API_DRIVER), + ), + ("CU_FUNC_CACHE_PREFER_L1", ("hipFuncCachePreferL1", CONV_CACHE, API_DRIVER)), + ( + "CU_FUNC_CACHE_PREFER_EQUAL", + ("hipFuncCachePreferEqual", CONV_CACHE, API_DRIVER), + ), + ( + "CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CUDA_IPC_HANDLE_SIZE", ("HIP_IPC_HANDLE_SIZE", CONV_TYPE, API_DRIVER)), + ( + "CU_JIT_CACHE_OPTION_NONE", + ("hipJitCacheModeOptionNone", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CG", + ("hipJitCacheModeOptionCG", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_CACHE_OPTION_CA", + ("hipJitCacheModeOptionCA", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_PTX", + ("hipJitFallbackPreferPtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_PREFER_BINARY", + ("hipJitFallbackPreferBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_JIT_MAX_REGISTERS", ("hipJitOptionMaxRegisters", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_THREADS_PER_BLOCK", + ("hipJitOptionThreadsPerBlock", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_WALL_TIME", ("hipJitOptionWallTime", CONV_JIT, API_DRIVER)), + ("CU_JIT_INFO_LOG_BUFFER", ("hipJitOptionInfoLogBuffer", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionInfoLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER", + ("hipJitOptionErrorLogBuffer", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES", + ("hipJitOptionErrorLogBufferSizeBytes", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_OPTIMIZATION_LEVEL", + ("hipJitOptionOptimizationLevel", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_TARGET_FROM_CUCONTEXT", + ("hipJitOptionTargetFromContext", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_TARGET", ("hipJitOptionTarget", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_FALLBACK_STRATEGY", + ("hipJitOptionFallbackStrategy", CONV_JIT, API_DRIVER), + ), + ( + "CU_JIT_GENERATE_DEBUG_INFO", + ("hipJitOptionGenerateDebugInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_LOG_VERBOSE", ("hipJitOptionLogVerbose", CONV_JIT, API_DRIVER)), + ( + "CU_JIT_GENERATE_LINE_INFO", + ("hipJitOptionGenerateLineInfo", CONV_JIT, API_DRIVER), + ), + ("CU_JIT_CACHE_MODE", ("hipJitOptionCacheMode", CONV_JIT, API_DRIVER)), + ("CU_JIT_NEW_SM3X_OPT", ("hipJitOptionSm3xOpt", CONV_JIT, API_DRIVER)), + ("CU_JIT_FAST_COMPILE", ("hipJitOptionFastCompile", CONV_JIT, API_DRIVER)), + ("CU_JIT_NUM_OPTIONS", ("hipJitOptionNumOptions", CONV_JIT, API_DRIVER)), + ( + "CU_TARGET_COMPUTE_10", + ("hipJitTargetCompute10", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_11", + ("hipJitTargetCompute11", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_12", + ("hipJitTargetCompute12", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_13", + ("hipJitTargetCompute13", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_20", + ("hipJitTargetCompute20", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( 
+ "CU_TARGET_COMPUTE_21", + ("hipJitTargetCompute21", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_30", + ("hipJitTargetCompute30", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_32", + ("hipJitTargetCompute32", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_35", + ("hipJitTargetCompute35", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_37", + ("hipJitTargetCompute37", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_50", + ("hipJitTargetCompute50", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_52", + ("hipJitTargetCompute52", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_53", + ("hipJitTargetCompute53", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_60", + ("hipJitTargetCompute60", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_61", + ("hipJitTargetCompute61", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_TARGET_COMPUTE_62", + ("hipJitTargetCompute62", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_CUBIN", + ("hipJitInputTypeBin", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_PTX", + ("hipJitInputTypePtx", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_FATBINARY", + ("hipJitInputTypeFatBinary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_OBJECT", + ("hipJitInputTypeObject", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_INPUT_LIBRARY", + ("hipJitInputTypeLibrary", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_JIT_NUM_INPUT_TYPES", + ("hipJitInputTypeNumInputTypes", CONV_JIT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_PRINTF_FIFO_SIZE", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_MALLOC_HEAP_SIZE", + ("hipLimitMallocHeapSize", CONV_TYPE, API_DRIVER), + ), + ( + "CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_LIMIT_STACK_SIZE", + ("hipLimitStackSize", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_GLOBAL", + ("hipMemAttachGlobal", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_HOST", + ("hipMemAttachHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEM_ATTACH_SINGLE", + ("hipMemAttachSingle", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_HOST", + ("hipMemTypeHost", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_DEVICE", + ("hipMemTypeDevice", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_ARRAY", + ("hipMemTypeArray", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_MEMORYTYPE_UNIFIED", + ("hipMemTypeUnified", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_ARRAY", + ("hipResourceTypeArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_MIPMAPPED_ARRAY", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_LINEAR", + ("hipResourceTypeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_RESOURCE_TYPE_PITCH2D", + ("hipResourceTypePitch2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + 
), + ("CU_RES_VIEW_FORMAT_NONE", ("hipResViewFormatNone", CONV_TEX, API_DRIVER)), + ( + "CU_RES_VIEW_FORMAT_UINT_1X8", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X8", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X8", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X8", + ("hipResViewFormatSignedChar1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X8", + ("hipResViewFormatSignedChar2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X8", + ("hipResViewFormatSignedChar4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X16", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X16", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X16", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X16", + ("hipResViewFormatSignedShort1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X16", + ("hipResViewFormatSignedShort2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X16", + ("hipResViewFormatSignedShort4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_1X32", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_2X32", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UINT_4X32", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_1X32", + ("hipResViewFormatSignedInt1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_2X32", + ("hipResViewFormatSignedInt2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SINT_4X32", + ("hipResViewFormatSignedInt4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X16", + ("hipResViewFormatHalf1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X16", + ("hipResViewFormatHalf2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X16", + ("hipResViewFormatHalf4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_1X32", + ("hipResViewFormatFloat1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_2X32", + ("hipResViewFormatFloat2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_FLOAT_4X32", + ("hipResViewFormatFloat4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_SIGNED_BC6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_DRIVER), + ), + ( + "CU_RES_VIEW_FORMAT_UNSIGNED_BC7", + 
("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_DRIVER), + ), + ( + "CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_DRIVER), + ), + ("CU_STREAM_DEFAULT", ("hipStreamDefault", CONV_TYPE, API_DRIVER)), + ("CU_STREAM_NON_BLOCKING", ("hipStreamNonBlocking", CONV_TYPE, API_DRIVER)), + ( + "CU_STREAM_WAIT_VALUE_GEQ", + ("hipStreamWaitValueGeq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_EQ", + ("hipStreamWaitValueEq", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_AND", + ("hipStreamWaitValueAnd", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WAIT_VALUE_FLUSH", + ("hipStreamWaitValueFlush", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_DEFAULT", + ("hipStreamWriteValueDefault", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER", + ( + "hipStreamWriteValueNoMemoryBarrier", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_STREAM_MEM_OP_WAIT_VALUE_32", + ("hipStreamBatchMemOpWaitValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_WRITE_VALUE_32", + ("hipStreamBatchMemOpWriteValue32", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES", + ( + "hipStreamBatchMemOpFlushRemoteWrites", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGetErrorName", + ("hipGetErrorName___", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGetErrorString", + ("hipGetErrorString___", CONV_ERROR, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuInit", ("hipInit", CONV_INIT, API_DRIVER)), + ("cuDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_DRIVER)), + ("cuCtxCreate_v2", ("hipCtxCreate", CONV_CONTEXT, API_DRIVER)), + ("cuCtxDestroy_v2", ("hipCtxDestroy", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetApiVersion", ("hipCtxGetApiVersion", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCacheConfig", ("hipCtxGetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetCurrent", ("hipCtxGetCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetDevice", ("hipCtxGetDevice", CONV_CONTEXT, API_DRIVER)), + ("cuCtxGetFlags", ("hipCtxGetFlags", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxGetLimit", + ("hipCtxGetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxGetSharedMemConfig", + ("hipCtxGetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuCtxGetStreamPriorityRange", + ("hipCtxGetStreamPriorityRange", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuCtxPopCurrent_v2", ("hipCtxPopCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxPushCurrent_v2", ("hipCtxPushCurrent", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCacheConfig", ("hipCtxSetCacheConfig", CONV_CONTEXT, API_DRIVER)), + ("cuCtxSetCurrent", ("hipCtxSetCurrent", CONV_CONTEXT, API_DRIVER)), + ( + "cuCtxSetLimit", + ("hipCtxSetLimit", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuCtxSetSharedMemConfig", + ("hipCtxSetSharedMemConfig", CONV_CONTEXT, API_DRIVER), + ), + ("cuCtxSynchronize", ("hipCtxSynchronize", CONV_CONTEXT, API_DRIVER)), + ("cuCtxAttach", ("hipCtxAttach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxDetach", ("hipCtxDetach", CONV_CONTEXT, API_DRIVER, HIP_UNSUPPORTED)), + ("cuCtxEnablePeerAccess", 
("hipCtxEnablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuCtxDisablePeerAccess", ("hipCtxDisablePeerAccess", CONV_PEER, API_DRIVER)), + ("cuDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_DRIVER)), + ( + "cuDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_PEER, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuDevicePrimaryCtxGetState", + ("hipDevicePrimaryCtxGetState", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRelease", + ("hipDevicePrimaryCtxRelease", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxReset", + ("hipDevicePrimaryCtxReset", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxRetain", + ("hipDevicePrimaryCtxRetain", CONV_CONTEXT, API_DRIVER), + ), + ( + "cuDevicePrimaryCtxSetFlags", + ("hipDevicePrimaryCtxSetFlags", CONV_CONTEXT, API_DRIVER), + ), + ("cuDeviceGet", ("hipGetDevice", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetName", ("hipDeviceGetName", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetCount", ("hipGetDeviceCount", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceGetByPCIBusId", ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_DRIVER)), + ("cuDeviceTotalMem_v2", ("hipDeviceTotalMem", CONV_DEVICE, API_DRIVER)), + ( + "cuDeviceComputeCapability", + ("hipDeviceComputeCapability", CONV_DEVICE, API_DRIVER), + ), + ("cuDeviceGetProperties", ("hipGetDeviceProperties", CONV_DEVICE, API_DRIVER)), + ("cuLinkAddData", ("hipLinkAddData", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkAddFile", ("hipLinkAddFile", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLinkComplete", + ("hipLinkComplete", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLinkCreate", ("hipLinkCreate", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLinkDestroy", ("hipLinkDestroy", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuModuleGetFunction", ("hipModuleGetFunction", CONV_MODULE, API_DRIVER)), + ("cuModuleGetGlobal_v2", ("hipModuleGetGlobal", CONV_MODULE, API_DRIVER)), + ( + "cuModuleGetSurfRef", + ("hipModuleGetSurfRef", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleGetTexRef", ("hipModuleGetTexRef", CONV_MODULE, API_DRIVER)), + ("cuModuleLoad", ("hipModuleLoad", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadData", ("hipModuleLoadData", CONV_MODULE, API_DRIVER)), + ("cuModuleLoadDataEx", ("hipModuleLoadDataEx", CONV_MODULE, API_DRIVER)), + ( + "cuModuleLoadFatBinary", + ("hipModuleLoadFatBinary", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuModuleUnload", ("hipModuleUnload", CONV_MODULE, API_DRIVER)), + ( + "CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("CU_EVENT_DEFAULT", ("hipEventDefault", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_BLOCKING_SYNC", ("hipEventBlockingSync", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_DISABLE_TIMING", ("hipEventDisableTiming", CONV_EVENT, API_DRIVER)), + ("CU_EVENT_INTERPROCESS", ("hipEventInterprocess", CONV_EVENT, API_DRIVER)), + ("cuEventCreate", ("hipEventCreate", CONV_EVENT, API_DRIVER)), + ("cuEventDestroy_v2", ("hipEventDestroy", 
CONV_EVENT, API_DRIVER)), + ("cuEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_DRIVER)), + ("cuEventQuery", ("hipEventQuery", CONV_EVENT, API_DRIVER)), + ("cuEventRecord", ("hipEventRecord", CONV_EVENT, API_DRIVER)), + ("cuEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_DRIVER)), + ( + "cuFuncGetAttribute", + ("hipFuncGetAttribute", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunchKernel", ("hipModuleLaunchKernel", CONV_MODULE, API_DRIVER)), + ( + "cuFuncSetBlockShape", + ("hipFuncSetBlockShape", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuFuncSetSharedSize", + ("hipFuncSetSharedSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuLaunch", ("hipLaunch", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuLaunchGrid", ("hipLaunchGrid", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuLaunchGridAsync", + ("hipLaunchGridAsync", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetf", ("hipParamSetf", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ("cuParamSeti", ("hipParamSeti", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuParamSetSize", + ("hipParamSetSize", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuParamSetv", ("hipParamSetv", CONV_MODULE, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_DRIVER, + ), + ), + ( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuOccupancyMaxPotentialBlockSize", + ("hipOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_DRIVER), + ), + ( + "cuOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_DRIVER)), + ( + "cuStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreate", + ("hipStreamCreate__", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamDestroy_v2", ("hipStreamDestroy", CONV_STREAM, API_DRIVER)), + ("cuStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_DRIVER)), + ( + "cuStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuStreamQuery", ("hipStreamQuery", CONV_STREAM, API_DRIVER)), + ("cuStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_DRIVER)), + ("cuStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_DRIVER)), + ( + "cuStreamWaitValue32", + ("hipStreamWaitValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamWriteValue32", + ("hipStreamWriteValue32", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuStreamBatchMemOp", + ("hipStreamBatchMemOp", CONV_STREAM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArray3DCreate", ("hipArray3DCreate", CONV_MEM, API_DRIVER)), + ( + "cuArray3DGetDescriptor", + ("hipArray3DGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuArrayCreate", 
("hipArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuArrayDestroy", ("hipArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuArrayGetDescriptor", + ("hipArrayGetDescriptor", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcCloseMemHandle", + ("hipIpcCloseMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetEventHandle", + ("hipIpcGetEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcGetMemHandle", + ("hipIpcGetMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenEventHandle", + ("hipIpcOpenEventHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuIpcOpenMemHandle", + ("hipIpcOpenMemHandle", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAlloc_v2", ("hipMalloc", CONV_MEM, API_DRIVER)), + ("cuMemAllocHost", ("hipMemAllocHost", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemAllocManaged", + ("hipMemAllocManaged", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemAllocPitch", + ("hipMemAllocPitch__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy", ("hipMemcpy__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpy2D", ("hipMemcpy2D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy2DAsync", + ("hipMemcpy2DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy2DUnaligned", + ("hipMemcpy2DUnaligned", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpy3D", ("hipMemcpy3D__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpy3DAsync", + ("hipMemcpy3DAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeer", + ("hipMemcpy3DPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyAsync", ("hipMemcpyAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoA", ("hipMemcpyAtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoD", ("hipMemcpyAtoD", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyAtoH", ("hipMemcpyAtoH", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyAtoHAsync", + ("hipMemcpyAtoHAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyDtoA", ("hipMemcpyDtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemcpyDtoD_v2", ("hipMemcpyDtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoDAsync_v2", ("hipMemcpyDtoDAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoH_v2", ("hipMemcpyDtoH", CONV_MEM, API_DRIVER)), + ("cuMemcpyDtoHAsync_v2", ("hipMemcpyDtoHAsync", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoA", ("hipMemcpyHtoA", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemcpyHtoAAsync", + ("hipMemcpyHtoAAsync", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyHtoD_v2", ("hipMemcpyHtoD", CONV_MEM, API_DRIVER)), + ("cuMemcpyHtoDAsync_v2", ("hipMemcpyHtoDAsync", CONV_MEM, API_DRIVER)), + ( + "cuMemcpyPeerAsync", + ("hipMemcpyPeerAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemcpyPeer", ("hipMemcpyPeer__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ("cuMemFree_v2", ("hipFree", CONV_MEM, API_DRIVER)), + ("cuMemFreeHost", ("hipHostFree", CONV_MEM, API_DRIVER)), + ( + "cuMemGetAddressRange", + ("hipMemGetAddressRange", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemGetInfo_v2", ("hipMemGetInfo", CONV_MEM, API_DRIVER)), + ("cuMemHostAlloc", ("hipHostMalloc", CONV_MEM, API_DRIVER)), + ( + "cuMemHostGetDevicePointer", + ("hipMemHostGetDevicePointer", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemHostGetFlags", + 
("hipMemHostGetFlags", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemHostRegister_v2", ("hipHostRegister", CONV_MEM, API_DRIVER)), + ("cuMemHostUnregister", ("hipHostUnregister", CONV_MEM, API_DRIVER)), + ("cuMemsetD16_v2", ("hipMemsetD16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD16Async", + ("hipMemsetD16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D16_v2", ("hipMemsetD2D16", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D16Async", + ("hipMemsetD2D16Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D32_v2", ("hipMemsetD2D32", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D32Async", + ("hipMemsetD2D32Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD2D8_v2", ("hipMemsetD2D8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD2D8Async", + ("hipMemsetD2D8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemsetD32_v2", ("hipMemset", CONV_MEM, API_DRIVER)), + ("cuMemsetD32Async", ("hipMemsetAsync", CONV_MEM, API_DRIVER)), + ("cuMemsetD8_v2", ("hipMemsetD8", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemsetD8Async", + ("hipMemsetD8Async", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayCreate", + ("hipMipmappedArrayCreate", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayDestroy", + ("hipMipmappedArrayDestroy", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMipmappedArrayGetLevel", + ("hipMipmappedArrayGetLevel", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemPrefetchAsync", + ("hipMemPrefetchAsync__", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuMemAdvise", ("hipMemAdvise", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerGetAttribute", + ("hipPointerGetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuPointerSetAttribute", + ("hipPointerSetAttribute", CONV_MEM, API_DRIVER, HIP_UNSUPPORTED), + ), + ("CU_TR_FILTER_MODE_POINT", ("hipFilterModePoint", CONV_TEX, API_DRIVER)), + ( + "CU_TR_FILTER_MODE_LINEAR", + ("hipFilterModeLinear", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddress", + ("hipTexRefGetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetAddressMode", + ("hipTexRefGetAddressMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetArray", + ("hipTexRefGetArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetBorderColor", + ("hipTexRefGetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFilterMode", + ("hipTexRefGetFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFlags", + ("hipTexRefGetFlags", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetFormat", + ("hipTexRefGetFormat", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMaxAnisotropy", + ("hipTexRefGetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapFilterMode", + ("hipTexRefGetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelBias", + ("hipTexRefGetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmapLevelClamp", + 
("hipTexRefGetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefGetMipmappedArray", + ("hipTexRefGetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress", + ("hipTexRefSetAddress", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetAddress2D", + ("hipTexRefSetAddress2D", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetAddressMode", ("hipTexRefSetAddressMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetArray", ("hipTexRefSetArray", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetBorderColor", + ("hipTexRefSetBorderColor", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefSetFilterMode", ("hipTexRefSetFilterMode", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFlags", ("hipTexRefSetFlags", CONV_TEX, API_DRIVER)), + ("cuTexRefSetFormat", ("hipTexRefSetFormat", CONV_TEX, API_DRIVER)), + ( + "cuTexRefSetMaxAnisotropy", + ("hipTexRefSetMaxAnisotropy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapFilterMode", + ("hipTexRefSetMipmapFilterMode", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelBias", + ("hipTexRefSetMipmapLevelBias", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmapLevelClamp", + ("hipTexRefSetMipmapLevelClamp", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexRefSetMipmappedArray", + ("hipTexRefSetMipmappedArray", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuTexRefCreate", ("hipTexRefCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuTexRefDestroy", + ("hipTexRefDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefGetArray", + ("hipSurfRefGetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfRefSetArray", + ("hipSurfRefSetArray", CONV_SURFACE, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectCreate", + ("hipTexObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectDestroy", + ("hipTexObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceDesc", + ("hipTexObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetResourceViewDesc", + ("hipTexObjectGetResourceViewDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuTexObjectGetTextureDesc", + ("hipTexObjectGetTextureDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectCreate", + ("hipSurfObjectCreate", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectDestroy", + ("hipSurfObjectDestroy", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuSurfObjectGetResourceDesc", + ("hipSurfObjectGetResourceDesc", CONV_TEX, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsUnregisterResource", + ( 
+ "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuProfilerStart", ("hipProfilerStart", CONV_OTHER, API_DRIVER)), + ("cuProfilerStop", ("hipProfilerStop", CONV_OTHER, API_DRIVER)), + ( + "CU_GL_DEVICE_LIST_ALL", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_CURRENT_FRAME", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_DEVICE_LIST_NEXT_FRAME", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuGLGetDevices", ("hipGLGetDevices", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ("cuWGLGetDevice", ("hipWGLGetDevice", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "CU_GL_MAP_RESOURCE_FLAGS_NONE", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cuGLCtxCreate", ("hipGLCtxCreate", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ("cuGLInit", ("hipGLInit", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), + ( + "cuGLMapBufferObject", + ("hipGLMapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_ALL", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_DEVICE_LIST_NEXT_FRAME", + ("HIP_D3D9_DEVICE_LIST_NEXT_FRAME", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreate", + ("hipD3D9CtxCreate", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9CtxCreateOnDevice", + ("hipD3D9CtxCreateOnDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D9RegisterResource", + ("hipGraphicsD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_NONE", + ("HIP_D3D9_MAPRESOURCE_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"CU_D3D9_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D9_REGISTER_FLAGS_NONE", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D9_REGISTER_FLAGS_ARRAY", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedPointer", + ("hipD3D9ResourceGetMappedPointer", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_ALL", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_NONE", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_READONLY", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D10_REGISTER_FLAGS_NONE", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D10_REGISTER_FLAGS_ARRAY", + ("HIP_D3D10_REGISTER_FLAGS_ARRAY", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreate", + ("hipD3D10CtxCreate", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10CtxCreateOnDevice", + ("hipD3D10CtxCreateOnDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"cuD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedArray", + ("hipD3D10ResourceGetMappedArray", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPitch", + ("hipD3D10ResourceGetMappedPitch", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD310ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_ALL", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "CU_D3D11_DEVICE_LIST_CURRENT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "CU_D3D11_DEVICE_LIST_NEXT_FRAME", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuD3D11CtxCreate", + ("hipD3D11CtxCreate", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11CtxCreateOnDevice", + ("hipD3D11CtxCreateOnDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuD3D11GetDirect3DDevice", + ("hipD3D11GetDirect3DDevice", CONV_D3D11, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuVDPAUGetDevice", + ("hipVDPAUGetDevice", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuVDPAUCtxCreate", + ("hipVDPAUCtxCreate", CONV_VDPAU, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerAcquireFrame", + ("hipEGLStreamConsumerAcquireFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ( + "cuEGLStreamConsumerDisconnect", + ("hipEGLStreamConsumerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamConsumerReleaseFrame", + ("hipEGLStreamConsumerReleaseFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + 
"cuEGLStreamProducerPresentFrame", + ("hipEGLStreamProducerPresentFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_DRIVER, HIP_UNSUPPORTED), + ), + ( + "cuGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_DRIVER, + HIP_UNSUPPORTED, + ), + ), + ("cudaDataType_t", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaDataType", ("hipDataType_t", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_16F", ("hipR16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_16F", ("hipC16F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_32F", ("hipR32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_32F", ("hipC32F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_64F", ("hipR64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_64F", ("hipC64F", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_8I", ("hipR8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_8I", ("hipC8I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_8U", ("hipR8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_8U", ("hipC8U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_32I", ("hipR32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_32I", ("hipC32I", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_R_32U", ("hipR32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ("CUDA_C_32U", ("hipC32U", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "MAJOR_VERSION", + ("hipLibraryMajorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "MINOR_VERSION", + ("hipLibraryMinorVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "PATCH_LEVEL", + ("hipLibraryPatchVersion", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachGlobal", + ("hipMemAttachGlobal", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachHost", + ("hipMemAttachHost", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAttachSingle", + ("hipMemAttachSingle", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDefault", + ("hipOccupancyDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyDisableCachingOverride", + ( + "hipOccupancyDisableCachingOverride", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaGetLastError", ("hipGetLastError", CONV_ERROR, API_RUNTIME)), + ("cudaPeekAtLastError", ("hipPeekAtLastError", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorName", ("hipGetErrorName", CONV_ERROR, API_RUNTIME)), + ("cudaGetErrorString", ("hipGetErrorString", CONV_ERROR, API_RUNTIME)), + ("cudaMemcpy3DParms", ("hipMemcpy3DParms", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DPeerParms", + ("hipMemcpy3DPeerParms", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy", ("hipMemcpy", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToArray", ("hipMemcpyToArray", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbol", ("hipMemcpyToSymbol", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyToSymbolAsync", ("hipMemcpyToSymbolAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyAsync", ("hipMemcpyAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2D", ("hipMemcpy2D", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DAsync", ("hipMemcpy2DAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpy2DToArray", ("hipMemcpy2DToArray", CONV_MEM, API_RUNTIME)), + ( + 
"cudaMemcpy2DArrayToArray", + ("hipMemcpy2DArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArray", + ("hipMemcpy2DFromArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DFromArrayAsync", + ("hipMemcpy2DFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy2DToArrayAsync", + ("hipMemcpy2DToArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpy3D", ("hipMemcpy3D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpy3DAsync", + ("hipMemcpy3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeer", + ("hipMemcpy3DPeer", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpy3DPeerAsync", + ("hipMemcpy3DPeerAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyArrayToArray", + ("hipMemcpyArrayToArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemcpyFromArrayAsync", + ("hipMemcpyFromArrayAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemcpyFromSymbol", ("hipMemcpyFromSymbol", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyFromSymbolAsync", + ("hipMemcpyFromSymbolAsync", CONV_MEM, API_RUNTIME), + ), + ("cudaMemAdvise", ("hipMemAdvise", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaMemRangeGetAttribute", + ("hipMemRangeGetAttribute", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeGetAttributes", + ("hipMemRangeGetAttributes", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetReadMostly", + ("hipMemAdviseSetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetReadMostly", + ("hipMemAdviseUnsetReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseSetPreferredLocation", + ( + "hipMemAdviseSetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseUnsetPreferredLocation", + ( + "hipMemAdviseUnsetPreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemAdviseSetAccessedBy", + ("hipMemAdviseSetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemAdviseUnsetAccessedBy", + ("hipMemAdviseUnsetAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeReadMostly", + ("hipMemRangeAttributeReadMostly", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributePreferredLocation", + ( + "hipMemRangeAttributePreferredLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaMemRangeAttributeAccessedBy", + ("hipMemRangeAttributeAccessedBy", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemRangeAttributeLastPrefetchLocation", + ( + "hipMemRangeAttributeLastPrefetchLocation", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaMemcpyHostToHost", ("hipMemcpyHostToHost", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyHostToDevice", ("hipMemcpyHostToDevice", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyDeviceToHost", ("hipMemcpyDeviceToHost", CONV_MEM, API_RUNTIME)), + ( + "cudaMemcpyDeviceToDevice", + ("hipMemcpyDeviceToDevice", CONV_MEM, API_RUNTIME), + ), + ("cudaMemcpyDefault", ("hipMemcpyDefault", CONV_MEM, API_RUNTIME)), + ("cudaMemset", ("hipMemset", CONV_MEM, API_RUNTIME)), + ("cudaMemsetAsync", ("hipMemsetAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemset2D", ("hipMemset2D", CONV_MEM, API_RUNTIME)), + ( + "cudaMemset2DAsync", + ("hipMemset2DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemset3D", ("hipMemset3D", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED)), + ( + 
"cudaMemset3DAsync", + ("hipMemset3DAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMemGetInfo", ("hipMemGetInfo", CONV_MEM, API_RUNTIME)), + ( + "cudaArrayGetInfo", + ("hipArrayGetInfo", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFreeMipmappedArray", + ("hipFreeMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetMipmappedArrayLevel", + ("hipGetMipmappedArrayLevel", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolAddress", + ("hipGetSymbolAddress", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSymbolSize", + ("hipGetSymbolSize", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMemPrefetchAsync", + ("hipMemPrefetchAsync", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocHost", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMallocArray", ("hipMallocArray", CONV_MEM, API_RUNTIME)), + ("cudaMalloc", ("hipMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3D", ("hipMalloc3D", CONV_MEM, API_RUNTIME)), + ("cudaMalloc3DArray", ("hipMalloc3DArray", CONV_MEM, API_RUNTIME)), + ( + "cudaMallocManaged", + ("hipMallocManaged", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaMallocMipmappedArray", + ("hipMallocMipmappedArray", CONV_MEM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaMallocPitch", ("hipMallocPitch", CONV_MEM, API_RUNTIME)), + ("cudaFreeHost", ("hipHostFree", CONV_MEM, API_RUNTIME)), + ("cudaFreeArray", ("hipFreeArray", CONV_MEM, API_RUNTIME)), + ("cudaFree", ("hipFree", CONV_MEM, API_RUNTIME)), + ("cudaHostRegister", ("hipHostRegister", CONV_MEM, API_RUNTIME)), + ("cudaHostUnregister", ("hipHostUnregister", CONV_MEM, API_RUNTIME)), + ("cudaHostAlloc", ("hipHostMalloc", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeHost", ("hipMemoryTypeHost", CONV_MEM, API_RUNTIME)), + ("cudaMemoryTypeDevice", ("hipMemoryTypeDevice", CONV_MEM, API_RUNTIME)), + ("make_cudaExtent", ("make_hipExtent", CONV_MEM, API_RUNTIME)), + ("make_cudaPitchedPtr", ("make_hipPitchedPtr", CONV_MEM, API_RUNTIME)), + ("make_cudaPos", ("make_hipPos", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocDefault", ("hipHostMallocDefault", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocPortable", ("hipHostMallocPortable", CONV_MEM, API_RUNTIME)), + ("cudaHostAllocMapped", ("hipHostMallocMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostAllocWriteCombined", + ("hipHostMallocWriteCombined", CONV_MEM, API_RUNTIME), + ), + ("cudaHostGetFlags", ("hipHostGetFlags", CONV_MEM, API_RUNTIME)), + ("cudaHostRegisterDefault", ("hipHostRegisterDefault", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterPortable", + ("hipHostRegisterPortable", CONV_MEM, API_RUNTIME), + ), + ("cudaHostRegisterMapped", ("hipHostRegisterMapped", CONV_MEM, API_RUNTIME)), + ( + "cudaHostRegisterIoMemory", + ("hipHostRegisterIoMemory", CONV_MEM, API_RUNTIME), + ), + # ("warpSize", ("hipWarpSize", CONV_SPECIAL_FUNC, API_RUNTIME), (HIP actually uses warpSize...), + ("cudaEventCreate", ("hipEventCreate", CONV_EVENT, API_RUNTIME)), + ( + "cudaEventCreateWithFlags", + ("hipEventCreateWithFlags", CONV_EVENT, API_RUNTIME), + ), + ("cudaEventDestroy", ("hipEventDestroy", CONV_EVENT, API_RUNTIME)), + ("cudaEventRecord", ("hipEventRecord", CONV_EVENT, API_RUNTIME)), + ("cudaEventElapsedTime", ("hipEventElapsedTime", CONV_EVENT, API_RUNTIME)), + ("cudaEventSynchronize", ("hipEventSynchronize", CONV_EVENT, API_RUNTIME)), + ("cudaEventQuery", ("hipEventQuery", CONV_EVENT, API_RUNTIME)), + ("cudaEventDefault", ("hipEventDefault", CONV_EVENT, API_RUNTIME)), + ("cudaEventBlockingSync", 
("hipEventBlockingSync", CONV_EVENT, API_RUNTIME)), + ("cudaEventDisableTiming", ("hipEventDisableTiming", CONV_EVENT, API_RUNTIME)), + ("cudaEventInterprocess", ("hipEventInterprocess", CONV_EVENT, API_RUNTIME)), + ("cudaStreamCreate", ("hipStreamCreate", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamCreateWithFlags", + ("hipStreamCreateWithFlags", CONV_STREAM, API_RUNTIME), + ), + ( + "cudaStreamCreateWithPriority", + ("hipStreamCreateWithPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamDestroy", ("hipStreamDestroy", CONV_STREAM, API_RUNTIME)), + ("cudaStreamWaitEvent", ("hipStreamWaitEvent", CONV_STREAM, API_RUNTIME)), + ("cudaStreamSynchronize", ("hipStreamSynchronize", CONV_STREAM, API_RUNTIME)), + ("cudaStreamGetFlags", ("hipStreamGetFlags", CONV_STREAM, API_RUNTIME)), + ("cudaStreamQuery", ("hipStreamQuery", CONV_STREAM, API_RUNTIME)), + ("cudaStreamAddCallback", ("hipStreamAddCallback", CONV_STREAM, API_RUNTIME)), + ( + "cudaStreamAttachMemAsync", + ("hipStreamAttachMemAsync", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaStreamGetPriority", + ("hipStreamGetPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaStreamDefault", ("hipStreamDefault", CONV_TYPE, API_RUNTIME)), + ("cudaStreamNonBlocking", ("hipStreamNonBlocking", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceSynchronize", ("hipDeviceSynchronize", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceReset", ("hipDeviceReset", CONV_DEVICE, API_RUNTIME)), + ("cudaSetDevice", ("hipSetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDevice", ("hipGetDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaGetDeviceCount", ("hipGetDeviceCount", CONV_DEVICE, API_RUNTIME)), + ("cudaChooseDevice", ("hipChooseDevice", CONV_DEVICE, API_RUNTIME)), + ("cudaThreadExit", ("hipDeviceReset", CONV_THREAD, API_RUNTIME)), + ( + "cudaThreadGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadGetLimit", + ("hipThreadGetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaThreadSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_THREAD, API_RUNTIME), + ), + ( + "cudaThreadSetLimit", + ("hipThreadSetLimit", CONV_THREAD, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaThreadSynchronize", ("hipDeviceSynchronize", CONV_THREAD, API_RUNTIME)), + ("cudaDeviceGetAttribute", ("hipDeviceGetAttribute", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDevAttrMaxThreadsPerBlock", + ("hipDeviceAttributeMaxThreadsPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimX", + ("hipDeviceAttributeMaxBlockDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimY", + ("hipDeviceAttributeMaxBlockDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxBlockDimZ", + ("hipDeviceAttributeMaxBlockDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimX", + ("hipDeviceAttributeMaxGridDimX", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimY", + ("hipDeviceAttributeMaxGridDimY", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxGridDimZ", + ("hipDeviceAttributeMaxGridDimZ", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxSharedMemoryPerBlock", + ("hipDeviceAttributeMaxSharedMemoryPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTotalConstantMemory", + ("hipDeviceAttributeTotalConstantMemory", CONV_TYPE, API_RUNTIME), + ), + ("cudaDevAttrWarpSize", ("hipDeviceAttributeWarpSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrMaxPitch", + ("hipDeviceAttributeMaxPitch", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + 
"cudaDevAttrMaxRegistersPerBlock", + ("hipDeviceAttributeMaxRegistersPerBlock", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrClockRate", + ("hipDeviceAttributeClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTextureAlignment", + ( + "hipDeviceAttributeTextureAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGpuOverlap", + ("hipDeviceAttributeGpuOverlap", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMultiProcessorCount", + ("hipDeviceAttributeMultiprocessorCount", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrKernelExecTimeout", + ( + "hipDeviceAttributeKernelExecTimeout", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIntegrated", + ("hipDeviceAttributeIntegrated", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrCanMapHostMemory", + ( + "hipDeviceAttributeCanMapHostMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeMode", + ("hipDeviceAttributeComputeMode", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DWidth", + ( + "hipDeviceAttributeMaxTexture1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DWidth", + ( + "hipDeviceAttributeMaxTexture2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DHeight", + ( + "hipDeviceAttributeMaxTexture2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidth", + ( + "hipDeviceAttributeMaxTexture3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeight", + ( + "hipDeviceAttributeMaxTexture3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepth", + ( + "hipDeviceAttributeMaxTexture3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredHeight", + ( + "hipDeviceAttributeMaxTexture2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSurfaceAlignment", + ( + "hipDeviceAttributeSurfaceAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrConcurrentKernels", + ("hipDeviceAttributeConcurrentKernels", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrEccEnabled", + ("hipDeviceAttributeEccEnabled", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDevAttrPciBusId", ("hipDeviceAttributePciBusId", CONV_TYPE, API_RUNTIME)), + ( + "cudaDevAttrPciDeviceId", + ("hipDeviceAttributePciDeviceId", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrTccDriver", + ("hipDeviceAttributeTccDriver", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrMemoryClockRate", + ("hipDeviceAttributeMemoryClockRate", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrGlobalMemoryBusWidth", + ("hipDeviceAttributeMemoryBusWidth", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrL2CacheSize", + ("hipDeviceAttributeL2CacheSize", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxThreadsPerMultiProcessor", + ("hipDeviceAttributeMaxThreadsPerMultiProcessor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrAsyncEngineCount", + ( + "hipDeviceAttributeAsyncEngineCount", + 
CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrUnifiedAddressing", + ( + "hipDeviceAttributeUnifiedAddressing", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredWidth", + ( + "hipDeviceAttributeMaxTexture1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLayeredLayers", + ( + "hipDeviceAttributeMaxTexture1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherWidth", + ( + "hipDeviceAttributeMaxTexture2DGatherWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DGatherHeight", + ( + "hipDeviceAttributeMaxTexture2DGatherHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DWidthAlt", + ( + "hipDeviceAttributeMaxTexture3DWidthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DHeightAlt", + ( + "hipDeviceAttributeMaxTexture3DHeightAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture3DDepthAlt", + ( + "hipDeviceAttributeMaxTexture3DDepthAlternate", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPciDomainId", + ("hipDeviceAttributePciDomainId", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevAttrTexturePitchAlignment", + ( + "hipDeviceAttributeTexturePitchAlignment", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapWidth", + ( + "hipDeviceAttributeMaxTextureCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTextureCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxTextureCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DWidth", + ( + "hipDeviceAttributeMaxSurface1DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DWidth", + ( + "hipDeviceAttributeMaxSurface2DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DHeight", + ( + "hipDeviceAttributeMaxSurface2DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DWidth", + ( + "hipDeviceAttributeMaxSurface3DWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DHeight", + ( + "hipDeviceAttributeMaxSurface3DHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface3DDepth", + ( + "hipDeviceAttributeMaxSurface3DDepth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface1DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface1DLayeredLayers", + ( + "hipDeviceAttributeMaxSurface1DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredWidth", + ( + "hipDeviceAttributeMaxSurface2DLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredHeight", + ( + "hipDeviceAttributeMaxSurface2DLayeredHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurface2DLayeredLayers", + ( + 
"hipDeviceAttributeMaxSurface2DLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredWidth", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSurfaceCubemapLayeredLayers", + ( + "hipDeviceAttributeMaxSurfaceCubemapLayeredLayers", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture1DLinearWidth", + ( + "hipDeviceAttributeMaxTexture1DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearWidth", + ( + "hipDeviceAttributeMaxTexture2DLinearWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearHeight", + ( + "hipDeviceAttributeMaxTexture2DLinearHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DLinearPitch", + ( + "hipDeviceAttributeMaxTexture2DLinearPitch", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture2DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxTexture2DMipmappedHeight", + ( + "hipDeviceAttributeMaxTexture2DMipmappedHeight", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputeCapabilityMajor", + ("hipDeviceAttributeComputeCapabilityMajor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrComputeCapabilityMinor", + ("hipDeviceAttributeComputeCapabilityMinor", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMaxTexture1DMipmappedWidth", + ( + "hipDeviceAttributeMaxTexture1DMipmappedWidth", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrStreamPrioritiesSupported", + ( + "hipDeviceAttributeStreamPrioritiesSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrGlobalL1CacheSupported", + ( + "hipDeviceAttributeGlobalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrLocalL1CacheSupported", + ( + "hipDeviceAttributeLocalL1CacheSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrMaxSharedMemoryPerMultiprocessor", + ( + "hipDeviceAttributeMaxSharedMemoryPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + ), + ), + ( + "cudaDevAttrMaxRegistersPerMultiprocessor", + ( + "hipDeviceAttributeMaxRegistersPerMultiprocessor", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrManagedMemory", + ( + "hipDeviceAttributeManagedMemory", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrIsMultiGpuBoard", + ("hipDeviceAttributeIsMultiGpuBoard", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDevAttrMultiGpuBoardGroupID", + ( + "hipDeviceAttributeMultiGpuBoardGroupID", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrHostNativeAtomicSupported", + ( + "hipDeviceAttributeHostNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrSingleToDoublePrecisionPerfRatio", + ( + "hipDeviceAttributeSingleToDoublePrecisionPerfRatio", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrPageableMemoryAccess", + ( + "hipDeviceAttributePageableMemoryAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + 
"cudaDevAttrConcurrentManagedAccess", + ( + "hipDeviceAttributeConcurrentManagedAccess", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrComputePreemptionSupported", + ( + "hipDeviceAttributeComputePreemptionSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevAttrCanUseHostPointerForRegisteredMem", + ( + "hipDeviceAttributeCanUseHostPointerForRegisteredMem", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaPointerGetAttributes", + ("hipPointerGetAttributes", CONV_MEM, API_RUNTIME), + ), + ( + "cudaHostGetDevicePointer", + ("hipHostGetDevicePointer", CONV_MEM, API_RUNTIME), + ), + ( + "cudaGetDeviceProperties", + ("hipGetDeviceProperties", CONV_DEVICE, API_RUNTIME), + ), + ("cudaDeviceGetPCIBusId", ("hipDeviceGetPCIBusId", CONV_DEVICE, API_RUNTIME)), + ( + "cudaDeviceGetByPCIBusId", + ("hipDeviceGetByPCIBusId", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetStreamPriorityRange", + ( + "hipDeviceGetStreamPriorityRange", + CONV_DEVICE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaSetValidDevices", + ("hipSetValidDevices", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDevP2PAttrPerformanceRank", + ( + "hipDeviceP2PAttributePerformanceRank", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrAccessSupported", + ( + "hipDeviceP2PAttributeAccessSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDevP2PAttrNativeAtomicSupported", + ( + "hipDeviceP2PAttributeNativeAtomicSupported", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaDeviceGetP2PAttribute", + ("hipDeviceGetP2PAttribute", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeDefault", + ("hipComputeModeDefault", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusive", + ("hipComputeModeExclusive", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeProhibited", + ("hipComputeModeProhibited", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaComputeModeExclusiveProcess", + ("hipComputeModeExclusiveProcess", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetDeviceFlags", + ("hipGetDeviceFlags", CONV_DEVICE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaSetDeviceFlags", ("hipSetDeviceFlags", CONV_DEVICE, API_RUNTIME)), + ("cudaDeviceScheduleAuto", ("hipDeviceScheduleAuto", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleSpin", ("hipDeviceScheduleSpin", CONV_TYPE, API_RUNTIME)), + ("cudaDeviceScheduleYield", ("hipDeviceScheduleYield", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleBlockingSync", + ("hipDeviceScheduleBlockingSync", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceScheduleMask", + ("hipDeviceScheduleMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMapHost", ("hipDeviceMapHost", CONV_TYPE, API_RUNTIME)), + ( + "cudaDeviceLmemResizeToMax", + ("hipDeviceLmemResizeToMax", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDeviceMask", ("hipDeviceMask", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaDeviceSetCacheConfig", + ("hipDeviceSetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaDeviceGetCacheConfig", + ("hipDeviceGetCacheConfig", CONV_CACHE, API_RUNTIME), + ), + ("cudaFuncSetCacheConfig", ("hipFuncSetCacheConfig", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferNone", + ("hipFuncCachePreferNone", CONV_CACHE, API_RUNTIME), + ), + ( + 
"cudaFuncCachePreferShared", + ("hipFuncCachePreferShared", CONV_CACHE, API_RUNTIME), + ), + ("cudaFuncCachePreferL1", ("hipFuncCachePreferL1", CONV_CACHE, API_RUNTIME)), + ( + "cudaFuncCachePreferEqual", + ("hipFuncCachePreferEqual", CONV_CACHE, API_RUNTIME), + ), + ( + "cudaFuncGetAttributes", + ("hipFuncGetAttributes", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFuncSetSharedMemConfig", + ("hipFuncSetSharedMemConfig", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetParameterBuffer", + ("hipGetParameterBuffer", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForDevice", + ("hipSetDoubleForDevice", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaSetDoubleForHost", + ("hipSetDoubleForHost", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaConfigureCall", + ("hipConfigureCall", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLaunch", ("hipLaunch", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)), + ( + "cudaSetupArgument", + ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaDriverGetVersion", ("hipDriverGetVersion", CONV_VERSION, API_RUNTIME)), + ( + "cudaRuntimeGetVersion", + ("hipRuntimeGetVersion", CONV_VERSION, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaOccupancyMaxPotentialBlockSize", + ("hipOccupancyMaxPotentialBlockSize", CONV_OCCUPANCY, API_RUNTIME), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessor", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessor", + CONV_OCCUPANCY, + API_RUNTIME, + ), + ), + ( + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + ( + "hipOccupancyMaxActiveBlocksPerMultiprocessorWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMem", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMem", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + ( + "hipOccupancyMaxPotentialBlockSizeVariableSMemWithFlags", + CONV_OCCUPANCY, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceCanAccessPeer", ("hipDeviceCanAccessPeer", CONV_PEER, API_RUNTIME)), + ( + "cudaDeviceDisablePeerAccess", + ("hipDeviceDisablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ( + "cudaDeviceEnablePeerAccess", + ("hipDeviceEnablePeerAccess", CONV_PEER, API_RUNTIME), + ), + ("cudaMemcpyPeerAsync", ("hipMemcpyPeerAsync", CONV_MEM, API_RUNTIME)), + ("cudaMemcpyPeer", ("hipMemcpyPeer", CONV_MEM, API_RUNTIME)), + ( + "cudaIpcMemLazyEnablePeerAccess", + ("hipIpcMemLazyEnablePeerAccess", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaDeviceSetSharedMemConfig", + ("hipDeviceSetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaDeviceGetSharedMemConfig", + ("hipDeviceGetSharedMemConfig", CONV_DEVICE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeDefault", + ("hipSharedMemBankSizeDefault", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeFourByte", + ("hipSharedMemBankSizeFourByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaSharedMemBankSizeEightByte", + ("hipSharedMemBankSizeEightByte", CONV_TYPE, API_RUNTIME), + ), + ( + "cudaLimitStackSize", + ("hipLimitStackSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitPrintfFifoSize", + ("hipLimitPrintfFifoSize", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaLimitMallocHeapSize", 
("hipLimitMallocHeapSize", CONV_TYPE, API_RUNTIME)), + ( + "cudaLimitDevRuntimeSyncDepth", + ("hipLimitDevRuntimeSyncDepth", CONV_TYPE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaLimitDevRuntimePendingLaunchCount", + ( + "hipLimitDevRuntimePendingLaunchCount", + CONV_TYPE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaDeviceGetLimit", ("hipDeviceGetLimit", CONV_DEVICE, API_RUNTIME)), + ( + "cudaProfilerInitialize", + ("hipProfilerInitialize", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaProfilerStart", ("hipProfilerStart", CONV_OTHER, API_RUNTIME)), + ("cudaProfilerStop", ("hipProfilerStop", CONV_OTHER, API_RUNTIME)), + ( + "cudaKeyValuePair", + ("hipKeyValuePair", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED), + ), + ("cudaCSV", ("hipCSV", CONV_OTHER, API_RUNTIME, HIP_UNSUPPORTED)), + ("cudaReadModeElementType", ("hipReadModeElementType", CONV_TEX, API_RUNTIME)), + ( + "cudaReadModeNormalizedFloat", + ("hipReadModeNormalizedFloat", CONV_TEX, API_RUNTIME), + ), + ("cudaFilterModePoint", ("hipFilterModePoint", CONV_TEX, API_RUNTIME)), + ("cudaFilterModeLinear", ("hipFilterModeLinear", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture", ("hipBindTexture", CONV_TEX, API_RUNTIME)), + ("cudaUnbindTexture", ("hipUnbindTexture", CONV_TEX, API_RUNTIME)), + ("cudaBindTexture2D", ("hipBindTexture2D", CONV_TEX, API_RUNTIME)), + ("cudaBindTextureToArray", ("hipBindTextureToArray", CONV_TEX, API_RUNTIME)), + ( + "cudaBindTextureToMipmappedArray", + ("hipBindTextureToMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureAlignmentOffset", + ("hipGetTextureAlignmentOffset", CONV_TEX, API_RUNTIME), + ), + ("cudaGetTextureReference", ("hipGetTextureReference", CONV_TEX, API_RUNTIME)), + ( + "cudaChannelFormatKindSigned", + ("hipChannelFormatKindSigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindUnsigned", + ("hipChannelFormatKindUnsigned", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindFloat", + ("hipChannelFormatKindFloat", CONV_TEX, API_RUNTIME), + ), + ( + "cudaChannelFormatKindNone", + ("hipChannelFormatKindNone", CONV_TEX, API_RUNTIME), + ), + ("cudaCreateChannelDesc", ("hipCreateChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaGetChannelDesc", ("hipGetChannelDesc", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypeArray", ("hipResourceTypeArray", CONV_TEX, API_RUNTIME)), + ( + "cudaResourceTypeMipmappedArray", + ("hipResourceTypeMipmappedArray", CONV_TEX, API_RUNTIME), + ), + ("cudaResourceTypeLinear", ("hipResourceTypeLinear", CONV_TEX, API_RUNTIME)), + ("cudaResourceTypePitch2D", ("hipResourceTypePitch2D", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatNone", ("hipResViewFormatNone", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedChar1", + ("hipResViewFormatUnsignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar2", + ("hipResViewFormatUnsignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedChar4", + ("hipResViewFormatUnsignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar1", + ("hipResViewFormatSignedChar1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar2", + ("hipResViewFormatSignedChar2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedChar4", + ("hipResViewFormatSignedChar4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort1", + ("hipResViewFormatUnsignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedShort2", + ("hipResViewFormatUnsignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + 
"cudaResViewFormatUnsignedShort4", + ("hipResViewFormatUnsignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort1", + ("hipResViewFormatSignedShort1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort2", + ("hipResViewFormatSignedShort2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedShort4", + ("hipResViewFormatSignedShort4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt1", + ("hipResViewFormatUnsignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt2", + ("hipResViewFormatUnsignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedInt4", + ("hipResViewFormatUnsignedInt4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt1", + ("hipResViewFormatSignedInt1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt2", + ("hipResViewFormatSignedInt2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedInt4", + ("hipResViewFormatSignedInt4", CONV_TEX, API_RUNTIME), + ), + ("cudaResViewFormatHalf1", ("hipResViewFormatHalf1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf2", ("hipResViewFormatHalf2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatHalf4", ("hipResViewFormatHalf4", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat1", ("hipResViewFormatFloat1", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat2", ("hipResViewFormatFloat2", CONV_TEX, API_RUNTIME)), + ("cudaResViewFormatFloat4", ("hipResViewFormatFloat4", CONV_TEX, API_RUNTIME)), + ( + "cudaResViewFormatUnsignedBlockCompressed1", + ("hipResViewFormatUnsignedBlockCompressed1", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed2", + ("hipResViewFormatUnsignedBlockCompressed2", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed3", + ("hipResViewFormatUnsignedBlockCompressed3", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed4", + ("hipResViewFormatUnsignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed4", + ("hipResViewFormatSignedBlockCompressed4", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed5", + ("hipResViewFormatUnsignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed5", + ("hipResViewFormatSignedBlockCompressed5", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed6H", + ("hipResViewFormatUnsignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatSignedBlockCompressed6H", + ("hipResViewFormatSignedBlockCompressed6H", CONV_TEX, API_RUNTIME), + ), + ( + "cudaResViewFormatUnsignedBlockCompressed7", + ("hipResViewFormatUnsignedBlockCompressed7", CONV_TEX, API_RUNTIME), + ), + ("cudaAddressModeWrap", ("hipAddressModeWrap", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeClamp", ("hipAddressModeClamp", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeMirror", ("hipAddressModeMirror", CONV_TEX, API_RUNTIME)), + ("cudaAddressModeBorder", ("hipAddressModeBorder", CONV_TEX, API_RUNTIME)), + ("cudaCreateTextureObject", ("hipCreateTextureObject", CONV_TEX, API_RUNTIME)), + ( + "cudaDestroyTextureObject", + ("hipDestroyTextureObject", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceDesc", + ("hipGetTextureObjectResourceDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectResourceViewDesc", + ("hipGetTextureObjectResourceViewDesc", CONV_TEX, API_RUNTIME), + ), + ( + "cudaGetTextureObjectTextureDesc", + ("hipGetTextureObjectTextureDesc", 
CONV_TEX, API_RUNTIME), + ), + ( + "cudaBindSurfaceToArray", + ("hipBindSurfaceToArray", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceReference", + ("hipGetSurfaceReference", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeZero", + ("hipBoundaryModeZero", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeClamp", + ("hipBoundaryModeClamp", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaBoundaryModeTrap", + ("hipBoundaryModeTrap", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeForced", + ("hipFormatModeForced", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaFormatModeAuto", + ("hipFormatModeAuto", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaCreateSurfaceObject", + ("hipCreateSurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaDestroySurfaceObject", + ("hipDestroySurfaceObject", CONV_SURFACE, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGetSurfaceObjectResourceDesc", + ( + "hipGetSurfaceObjectResourceDesc", + CONV_SURFACE, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cudaIpcCloseMemHandle", ("hipIpcCloseMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetEventHandle", ("hipIpcGetEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcGetMemHandle", ("hipIpcGetMemHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenEventHandle", ("hipIpcOpenEventHandle", CONV_DEVICE, API_RUNTIME)), + ("cudaIpcOpenMemHandle", ("hipIpcOpenMemHandle", CONV_DEVICE, API_RUNTIME)), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapResources", + ("hipGraphicsMapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedMipmappedArray", + ( + "hipGraphicsResourceGetMappedMipmappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceGetMappedPointer", + ( + "hipGraphicsResourceGetMappedPointer", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsResourceSetMapFlags", + ( + "hipGraphicsResourceSetMapFlags", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsSubResourceGetMappedArray", + ( + "hipGraphicsSubResourceGetMappedArray", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsUnmapResources", + ("hipGraphicsUnmapResources", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsUnregisterResource", + ( + "hipGraphicsUnregisterResource", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveX", + ( + "hipGraphicsCubeFacePositiveX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeX", + ( + "hipGraphicsCubeFaceNegativeX", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveY", + ( + "hipGraphicsCubeFacePositiveY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeY", + ( + "hipGraphicsCubeFaceNegativeY", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFacePositiveZ", + ( + 
"hipGraphicsCubeFacePositiveZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsCubeFaceNegativeZ", + ( + "hipGraphicsCubeFaceNegativeZ", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsNone", + ("hipGraphicsMapFlagsNone", CONV_GRAPHICS, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsMapFlagsReadOnly", + ( + "hipGraphicsMapFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsMapFlagsWriteDiscard", + ( + "hipGraphicsMapFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsNone", + ( + "hipGraphicsRegisterFlagsNone", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsReadOnly", + ( + "hipGraphicsRegisterFlagsReadOnly", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsWriteDiscard", + ( + "hipGraphicsRegisterFlagsWriteDiscard", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsSurfaceLoadStore", + ( + "hipGraphicsRegisterFlagsSurfaceLoadStore", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsRegisterFlagsTextureGather", + ( + "hipGraphicsRegisterFlagsTextureGather", + CONV_GRAPHICS, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLDeviceListAll", + ("HIP_GL_DEVICE_LIST_ALL", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListCurrentFrame", + ("HIP_GL_DEVICE_LIST_CURRENT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLDeviceListNextFrame", + ("HIP_GL_DEVICE_LIST_NEXT_FRAME", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLGetDevices", + ("hipGLGetDevices", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterBuffer", + ("hipGraphicsGLRegisterBuffer", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsGLRegisterImage", + ("hipGraphicsGLRegisterImage", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaWGLGetDevice", + ("hipWGLGetDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsNone", + ("HIP_GL_MAP_RESOURCE_FLAGS_NONE", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapFlagsReadOnly", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_READ_ONLY", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapFlagsWriteDiscard", + ( + "HIP_GL_MAP_RESOURCE_FLAGS_WRITE_DISCARD", + CONV_GL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGLMapBufferObject", + ("hipGLMapBufferObject__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLMapBufferObjectAsync", + ("hipGLMapBufferObjectAsync__", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLRegisterBufferObject", + ("hipGLRegisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetBufferObjectMapFlags", + ("hipGLSetBufferObjectMapFlags", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLSetGLDevice", + ("hipGLSetGLDevice", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObject", + ("hipGLUnmapBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnmapBufferObjectAsync", + ("hipGLUnmapBufferObjectAsync", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGLUnregisterBufferObject", + ("hipGLUnregisterBufferObject", CONV_GL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9DeviceListAll", + ("HIP_D3D9_DEVICE_LIST_ALL", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + 
"cudaD3D9DeviceListCurrentFrame", + ( + "HIP_D3D9_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9DeviceListNextFrame", + ( + "HIP_D3D9_DEVICE_LIST_NEXT_FRAME", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9GetDevice", + ("hipD3D9GetDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDevices", + ("hipD3D9GetDevices", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9GetDirect3DDevice", + ("hipD3D9GetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9SetDirect3DDevice", + ("hipD3D9SetDirect3DDevice", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D9RegisterResource", + ( + "hipGraphicsD3D9RegisterResource", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlags", + ("hipD3D9MapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapFlagsNone", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_NONE", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsReadOnly", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9MapFlagsWriteDiscard", + ( + "HIP_D3D9_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9RegisterFlagsNone", + ("HIP_D3D9_REGISTER_FLAGS_NONE", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterFlagsArray", + ("HIP_D3D9_REGISTER_FLAGS_ARRAY", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9MapResources", + ("hipD3D9MapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9RegisterResource", + ("hipD3D9RegisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedArray", + ("hipD3D9ResourceGetMappedArray", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPitch", + ("hipD3D9ResourceGetMappedPitch", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetMappedPointer", + ( + "hipD3D9ResourceGetMappedPointer", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceGetMappedSize", + ("hipD3D9ResourceGetMappedSize", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9ResourceGetSurfaceDimensions", + ( + "hipD3D9ResourceGetSurfaceDimensions", + CONV_D3D9, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D9ResourceSetMapFlags", + ("hipD3D9ResourceSetMapFlags", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnmapResources", + ("hipD3D9UnmapResources", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D9UnregisterResource", + ("hipD3D9UnregisterResource", CONV_D3D9, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListAll", + ("HIP_D3D10_DEVICE_LIST_ALL", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10DeviceListCurrentFrame", + ( + "HIP_D3D10_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10DeviceListNextFrame", + ( + "HIP_D3D10_DEVICE_LIST_NEXT_FRAME", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDevice", + ("hipD3D10GetDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10GetDevices", + ("hipD3D10GetDevices", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D10RegisterResource", + ( + "hipGraphicsD3D10RegisterResource", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsNone", + ( + 
"HIP_D3D10_MAPRESOURCE_FLAGS_NONE", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsReadOnly", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_READONLY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10MapFlagsWriteDiscard", + ( + "HIP_D3D10_MAPRESOURCE_FLAGS_WRITEDISCARD", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10RegisterFlagsNone", + ("HIP_D3D10_REGISTER_FLAGS_NONE", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterFlagsArray", + ( + "HIP_D3D10_REGISTER_FLAGS_ARRAY", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10GetDirect3DDevice", + ("hipD3D10GetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10MapResources", + ("hipD3D10MapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10RegisterResource", + ("hipD3D10RegisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetMappedArray", + ( + "hipD3D10ResourceGetMappedArray", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPitch", + ( + "hipD3D10ResourceGetMappedPitch", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedPointer", + ( + "hipD3D10ResourceGetMappedPointer", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceGetMappedSize", + ("hipD3D10ResourceGetMappedSize", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10ResourceGetSurfaceDimensions", + ( + "hipD3D10ResourceGetSurfaceDimensions", + CONV_D3D10, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D10ResourceSetMapFlags", + ("hipD3D10ResourceSetMapFlags", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10SetDirect3DDevice", + ("hipD3D10SetDirect3DDevice", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnmapResources", + ("hipD3D10UnmapResources", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D10UnregisterResource", + ("hipD3D10UnregisterResource", CONV_D3D10, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListAll", + ("HIP_D3D11_DEVICE_LIST_ALL", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11DeviceListCurrentFrame", + ( + "HIP_D3D11_DEVICE_LIST_CURRENT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11DeviceListNextFrame", + ( + "HIP_D3D11_DEVICE_LIST_NEXT_FRAME", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaD3D11GetDevice", + ("hipD3D11GetDevice", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaD3D11GetDevices", + ("hipD3D11GetDevices", CONV_D3D11, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsD3D11RegisterResource", + ( + "hipGraphicsD3D11RegisterResource", + CONV_D3D11, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterOutputSurface", + ( + "hipGraphicsVDPAURegisterOutputSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaGraphicsVDPAURegisterVideoSurface", + ( + "hipGraphicsVDPAURegisterVideoSurface", + CONV_VDPAU, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaVDPAUGetDevice", + 
("hipVDPAUGetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaVDPAUSetVDPAUDevice", + ("hipVDPAUSetDevice", CONV_VDPAU, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerAcquireFrame", + ( + "hipEGLStreamConsumerAcquireFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerConnect", + ("hipEGLStreamConsumerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamConsumerConnectWithFlags", + ( + "hipEGLStreamConsumerConnectWithFlags", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamConsumerReleaseFrame", + ( + "hipEGLStreamConsumerReleaseFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerConnect", + ("hipEGLStreamProducerConnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerDisconnect", + ("hipEGLStreamProducerDisconnect", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaEGLStreamProducerPresentFrame", + ( + "hipEGLStreamProducerPresentFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ( + "cudaEGLStreamProducerReturnFrame", + ("hipEGLStreamProducerReturnFrame", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsEGLRegisterImage", + ("hipGraphicsEGLRegisterImage", CONV_EGL, API_RUNTIME, HIP_UNSUPPORTED), + ), + ( + "cudaGraphicsResourceGetMappedEglFrame", + ( + "hipGraphicsResourceGetMappedEglFrame", + CONV_EGL, + API_RUNTIME, + HIP_UNSUPPORTED, + ), + ), + ("cublasInit", ("rocblas_init", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasShutdown", + ("rocblas_shutdown", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVersion", + ("rocblas_get_version", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetError", + ("rocblas_get_error", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAlloc", ("rocblas_alloc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasFree", ("rocblas_free", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSetKernelStream", + ("rocblas_set_kernel_stream", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetAtomicsMode", + ("rocblas_get_atomics_mode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetAtomicsMode", + ("rocblas_set_atomics_mode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetMathMode", + ("rocblas_get_math_mode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMathMode", + ("rocblas_set_math_mode", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("CUBLAS_OP_N", ("rocblas_operation_none", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_OP_T", + ("rocblas_operation_transpose", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_OP_C", + ("rocblas_operation_conjugate_transpose", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_SUCCESS", + ("rocblas_status_success", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_INITIALIZED", + ("rocblas_status_invalid_handle", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ALLOC_FAILED", + ("rocblas_status_memory_error", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INVALID_VALUE", + ("rocblas_status_invalid_pointer", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_MAPPING_ERROR", + ("rocblas_status_internal_error", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_EXECUTION_FAILED", + ("rocblas_status_internal_error", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_INTERNAL_ERROR", + 
("rocblas_status_internal_error", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_NOT_SUPPORTED", + ("rocblas_status_not_implemented", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_STATUS_ARCH_MISMATCH", + ("rocblas_status_not_implemented", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_LOWER", + ("rocblas_fill_lower", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_FILL_MODE_UPPER", + ("rocblas_fill_upper", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_DIAG_NON_UNIT", + ("rocblas_diagonal_non_unit", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ("CUBLAS_DIAG_UNIT", ("rocblas_diagonal_unit", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_LEFT", ("rocblas_side_left", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUBLAS_SIDE_RIGHT", ("rocblas_side_right", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUBLAS_POINTER_MODE_HOST", + ("rocblas_pointer_mode_host", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_POINTER_MODE_DEVICE", + ("rocblas_pointer_mode_device", CONV_NUMERIC_LITERAL, API_BLAS), + ), + ( + "CUBLAS_ATOMICS_NOT_ALLOWED", + ( + "rocblas_atomics_not_allowed", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_ATOMICS_ALLOWED", + ( + "rocblas_atomics_allowed", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_FLOAT", + ( + "rocblas_precision_float", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_DOUBLE", + ( + "rocblas_precision_double", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "CUBLAS_DATA_HALF", + ("rocblas_precision_half", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "CUBLAS_DATA_INT8", + ("rocblas_precision_int8", CONV_NUMERIC_LITERAL, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCreate", ("rocblas_create_handle", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy", ("rocblas_destroy_handle", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetVector", ("rocblas_set_vector", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetVector", ("rocblas_get_vector", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSetVectorAsync", + ("rocblas_set_vector_async", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasGetVectorAsync", + ("rocblas_get_vector_async", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetMatrix", ("rocblas_set_matrix", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetMatrix", ("rocblas_get_matrix", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetMatrixAsync", + ("rocblas_get_matrix_async", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSetMatrixAsync", + ("rocblas_set_matrix_async", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasXerbla", ("rocblas_xerbla", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSnrm2", ("rocblas_snrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2", ("rocblas_dnrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasScnrm2", ("rocblas_scnrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDznrm2", ("rocblas_dznrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasNrm2Ex", + ("rocblas_nrm2_ex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdot", ("rocblas_sdot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSdotBatched", + ("rocblas_sdot_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDdot", ("rocblas_ddot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDdotBatched", + ("rocblas_ddot_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCdotu", ("rocblas_cdotu", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + 
("cublasCdotc", ("rocblas_cdotc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdotu", ("rocblas_zdotu", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdotc", ("rocblas_zdotc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSscal", ("rocblas_sscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSscalBatched", + ("rocblas_sscal_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDscal", ("rocblas_dscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDscalBatched", + ("rocblas_dscal_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCscal", ("rocblas_cscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsscal", ("rocblas_csscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZscal", ("rocblas_zscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdscal", ("rocblas_zdscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy", ("rocblas_saxpy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSaxpyBatched", + ("rocblas_saxpy_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDaxpy", ("rocblas_daxpy", CONV_MATH_FUNC, API_BLAS)), + ("cublasCaxpy", ("rocblas_caxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZaxpy", ("rocblas_zaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasScopy", ("rocblas_scopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScopyBatched", + ("rocblas_scopy_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDcopy", ("rocblas_dcopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDcopyBatched", + ("rocblas_dcopy_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasCcopy", ("rocblas_ccopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZcopy", ("rocblas_zcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSswap", ("rocblas_sswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap", ("rocblas_dswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasCswap", ("rocblas_cswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZswap", ("rocblas_zswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamax", ("rocblas_isamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax", ("rocblas_idamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamax", ("rocblas_icamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamax", ("rocblas_izamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIsamin", ("rocblas_isamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin", ("rocblas_idamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIcamin", ("rocblas_icamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasIzamin", ("rocblas_izamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSasum", ("rocblas_sasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSasumBatched", + ("rocblas_sasum_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDasum", ("rocblas_dasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasDasumBatched", + ("rocblas_dasum_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScasum", ("rocblas_scasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDzasum", ("rocblas_dzasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrot", ("rocblas_srot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot", ("rocblas_drot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot", ("rocblas_crot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsrot", ("rocblas_csrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrot", ("rocblas_zrot", 
CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdrot", ("rocblas_zdrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotg", ("rocblas_srotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotg", ("rocblas_drotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrotg", ("rocblas_crotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZrotg", ("rocblas_zrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotm", ("rocblas_srotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotm", ("rocblas_drotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSrotmg", ("rocblas_srotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrotmg", ("rocblas_drotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgemv", ("rocblas_sgemv", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasSgemvBatched", + ("rocblas_sgemv_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDgemv", ("rocblas_dgemv", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemv", ("rocblas_cgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgemv", ("rocblas_zgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgbmv", ("rocblas_sgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDgbmv", ("rocblas_dgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgbmv", ("rocblas_cgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgbmv", ("rocblas_zgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrmv", ("rocblas_strmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmv", ("rocblas_dtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmv", ("rocblas_ctrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmv", ("rocblas_ztrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbmv", ("rocblas_stbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbmv", ("rocblas_dtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbmv", ("rocblas_ctbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbmv", ("rocblas_ztbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpmv", ("rocblas_stpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpmv", ("rocblas_dtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpmv", ("rocblas_ctpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpmv", ("rocblas_ztpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsv", ("rocblas_strsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrsv", ("rocblas_dtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrsv", ("rocblas_ctrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsv", ("rocblas_ztrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpsv", ("rocblas_stpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpsv", ("rocblas_dtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpsv", ("rocblas_ctpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpsv", ("rocblas_ztpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStbsv", ("rocblas_stbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtbsv", ("rocblas_dtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtbsv", ("rocblas_ctbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtbsv", ("rocblas_ztbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymv", ("rocblas_ssymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymv", 
("rocblas_dsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymv", ("rocblas_csymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymv", ("rocblas_zsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemv", ("rocblas_chemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemv", ("rocblas_zhemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsbmv", ("rocblas_ssbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsbmv", ("rocblas_dsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChbmv", ("rocblas_chbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhbmv", ("rocblas_zhbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspmv", ("rocblas_sspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspmv", ("rocblas_dspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpmv", ("rocblas_chpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpmv", ("rocblas_zhpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSger", ("rocblas_sger", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger", ("rocblas_dger", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeru", ("rocblas_cgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCgerc", ("rocblas_cgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeru", ("rocblas_zgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgerc", ("rocblas_zgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr", ("rocblas_ssyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasDsyr", ("rocblas_dsyr", CONV_MATH_FUNC, API_BLAS)), + ("cublasCher", ("rocblas_cher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher", ("rocblas_zher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr", ("rocblas_sspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr", ("rocblas_dspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr", ("rocblas_chpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr", ("rocblas_zhpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2", ("rocblas_ssyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2", ("rocblas_dsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2", ("rocblas_cher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2", ("rocblas_zher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr2", ("rocblas_sspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr2", ("rocblas_dspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr2", ("rocblas_chpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr2", ("rocblas_zhpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgemmBatched", + ("rocblas_sgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgemmBatched", + ("rocblas_dgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasHgemmBatched", + ("rocblas_hgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgemmStridedBatched", + ("rocblas_sgemm_strided_batched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasDgemmStridedBatched", + ("rocblas_dgemm_strided_batched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasHgemmStridedBatched", + ("rocblas_hgemm_strided_batched", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasCgemmBatched", + ("rocblas_cgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mBatched", + ("rocblas_cgemm_3m_batched", CONV_MATH_FUNC, API_BLAS, 
HIP_UNSUPPORTED), + ), + ( + "cublasZgemmBatched", + ("rocblas_zgemm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemmStridedBatched", + ( + "rocblas_cgemm_strided_batched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasCgemm3mStridedBatched", + ( + "rocblas_cgemm_3m_strided_batched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasZgemmStridedBatched", + ( + "rocblas_zgemm_strided_batched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ( + "cublasHgemmStridedBatched", + ( + "rocblas_hgemm_strided_batched", + CONV_MATH_FUNC, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cublasSgemm", ("rocblas_sgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm", ("rocblas_dgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgemm", ("rocblas_cgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasZgemm", ("rocblas_zgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasHgemm", ("rocblas_hgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasSsyrk", ("rocblas_ssyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrk", ("rocblas_dsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrk", ("rocblas_csyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrk", ("rocblas_zsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherk", ("rocblas_cherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherk", ("rocblas_zherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyr2k", ("rocblas_ssyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr2k", ("rocblas_dsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr2k", ("rocblas_csyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr2k", ("rocblas_zyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsyrkx", ("rocblas_ssyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyrkx", ("rocblas_dsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyrkx", ("rocblas_csyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyrkx", ("rocblas_zsyrkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher2k", ("rocblas_cher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher2k", ("rocblas_zher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCherkx", ("rocblas_cherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZherkx", ("rocblas_zherkx", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSsymm", ("rocblas_ssymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsymm", ("rocblas_dsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsymm", ("rocblas_csymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsymm", ("rocblas_zsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChemm", ("rocblas_chemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhemm", ("rocblas_zhemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrsm", ("rocblas_strsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDtrsm", ("rocblas_dtrsm", CONV_MATH_FUNC, API_BLAS)), + ("cublasCtrsm", ("rocblas_ctrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrsm", ("rocblas_ztrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasStrsmBatched", + ("rocblas_strsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("rocblas_dtrsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("rocblas_ctrsm_batched", CONV_MATH_FUNC, API_BLAS, 
HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("rocblas_ztrsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasStrmm", ("rocblas_strmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrmm", ("rocblas_dtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrmm", ("rocblas_ctrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrmm", ("rocblas_ztrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSgeam", ("rocblas_sgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgeam", ("rocblas_dgeam", CONV_MATH_FUNC, API_BLAS)), + ("cublasCgeam", ("rocblas_cgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZgeam", ("rocblas_zgeam", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSgetrfBatched", + ("rocblas_sgetrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrfBatched", + ("rocblas_dgetrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrfBatched", + ("rocblas_cgetrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrfBatched", + ("rocblas_zgetrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetriBatched", + ("rocblas_sgetri_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetriBatched", + ("rocblas_dgetri_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetriBatched", + ("rocblas_cgetri_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetriBatched", + ("rocblas_zgetri_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgetrsBatched", + ("rocblas_sgetrs_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgetrsBatched", + ("rocblas_dgetrs_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgetrsBatched", + ("rocblas_cgetrs_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgetrsBatched", + ("rocblas_zgetrs_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsmBatched", + ("rocblas_strsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsmBatched", + ("rocblas_dtrsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsmBatched", + ("rocblas_ctrsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsmBatched", + ("rocblas_dtrsm_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSmatinvBatched", + ("rocblas_smatinv_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDmatinvBatched", + ("rocblas_dmatinv_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCmatinvBatched", + ("rocblas_cmatinv_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZmatinvBatched", + ("rocblas_zmatinv_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgeqrfBatched", + ("rocblas_sgeqrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgeqrfBatched", + ("rocblas_dgeqrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgeqrfBatched", + ("rocblas_cgeqrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeqrfBatched", + ("rocblas_zgeqrf_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgelsBatched", + ("rocblas_sgels_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgelsBatched", + ("rocblas_dgels_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgelsBatched", + 
("rocblas_cgels_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgelsBatched", + ("rocblas_zgels_batched", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSdgmm", ("rocblas_sdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDdgmm", ("rocblas_ddgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCdgmm", ("rocblas_cdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZdgmm", ("rocblas_zdgmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStpttr", ("rocblas_stpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtpttr", ("rocblas_dtpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtpttr", ("rocblas_ctpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtpttr", ("rocblas_ztpttr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasStrttp", ("rocblas_strttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDtrttp", ("rocblas_dtrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCtrttp", ("rocblas_ctrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZtrttp", ("rocblas_ztrttp", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCreate_v2", ("rocblas_create_handle", CONV_MATH_FUNC, API_BLAS)), + ("cublasDestroy_v2", ("rocblas_destroy_handle", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetVersion_v2", + ("rocblas_get_version", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSetStream", ("rocblas_set_stream", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream", ("rocblas_get_stream", CONV_MATH_FUNC, API_BLAS)), + ("cublasSetStream_v2", ("rocblas_set_stream", CONV_MATH_FUNC, API_BLAS)), + ("cublasGetStream_v2", ("rocblas_get_stream", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasGetPointerMode", + ("rocblas_get_pointer_mode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode", + ("rocblas_set_pointer_mode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasGetPointerMode_v2", + ("rocblas_get_pointer_mode", CONV_MATH_FUNC, API_BLAS), + ), + ( + "cublasSetPointerMode_v2", + ("rocblas_set_pointer_mode", CONV_MATH_FUNC, API_BLAS), + ), + ("cublasSgemv_v2", ("rocblas_sgemv", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemv_v2", ("rocblas_dgemv", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemv_v2", + ("rocblas_cgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemv_v2", + ("rocblas_zgemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSgbmv_v2", + ("rocblas_sgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDgbmv_v2", + ("rocblas_dgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgbmv_v2", + ("rocblas_cgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgbmv_v2", + ("rocblas_zgbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmv_v2", + ("rocblas_strmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmv_v2", + ("rocblas_dtrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmv_v2", + ("rocblas_ctrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmv_v2", + ("rocblas_ztrmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbmv_v2", + ("rocblas_stbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbmv_v2", + ("rocblas_dtbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbmv_v2", + ("rocblas_ctbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbmv_v2", + ("rocblas_ztbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + 
"cublasStpmv_v2", + ("rocblas_stpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpmv_v2", + ("rocblas_dtpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpmv_v2", + ("rocblas_ctpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpmv_v2", + ("rocblas_ztpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsv_v2", + ("rocblas_strsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsv_v2", + ("rocblas_dtrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsv_v2", + ("rocblas_ctrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsv_v2", + ("rocblas_ztrsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStpsv_v2", + ("rocblas_stpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtpsv_v2", + ("rocblas_dtpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtpsv_v2", + ("rocblas_ctpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtpsv_v2", + ("rocblas_ztpsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStbsv_v2", + ("rocblas_stbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtbsv_v2", + ("rocblas_dtbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtbsv_v2", + ("rocblas_ctbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtbsv_v2", + ("rocblas_ztbsv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymv_v2", + ("rocblas_ssymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymv_v2", + ("rocblas_dsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymv_v2", + ("rocblas_csymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymv_v2", + ("rocblas_zsymv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemv_v2", + ("rocblas_chemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemv_v2", + ("rocblas_zhemv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsbmv_v2", + ("rocblas_ssbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsbmv_v2", + ("rocblas_dsbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChbmv_v2", + ("rocblas_chbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhbmv_v2", + ("rocblas_zhbmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspmv_v2", + ("rocblas_sspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspmv_v2", + ("rocblas_dspmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpmv_v2", + ("rocblas_chpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpmv_v2", + ("rocblas_zhpmv", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSger_v2", ("rocblas_sger", CONV_MATH_FUNC, API_BLAS)), + ("cublasDger_v2", ("rocblas_dger", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgeru_v2", + ("rocblas_cgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgerc_v2", + ("rocblas_cergc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgeru_v2", + ("rocblas_zgeru", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgerc_v2", + ("rocblas_zgerc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSsyr_v2", ("rocblas_ssyr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDsyr_v2", ("rocblas_dsyr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCsyr_v2", ("rocblas_csyr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZsyr_v2", 
("rocblas_zsyr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCher_v2", ("rocblas_cher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZher_v2", ("rocblas_zher", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSspr_v2", ("rocblas_sspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDspr_v2", ("rocblas_dspr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasChpr_v2", ("rocblas_chpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasZhpr_v2", ("rocblas_zhpr", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasSsyr2_v2", + ("rocblas_ssyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2_v2", + ("rocblas_dsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2_v2", + ("rocblas_csyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2_v2", + ("rocblas_zsyr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2_v2", + ("rocblas_cher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2_v2", + ("rocblas_zher2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSspr2_v2", + ("rocblas_sspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDspr2_v2", + ("rocblas_dspr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChpr2_v2", + ("rocblas_chpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhpr2_v2", + ("rocblas_zhpr2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSgemm_v2", ("rocblas_sgemm", CONV_MATH_FUNC, API_BLAS)), + ("cublasDgemm_v2", ("rocblas_dgemm", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCgemm_v2", + ("rocblas_cgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3m", + ("rocblas_cgemm_3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCgemm3mEx", + ("rocblas_cgemm_3mex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm_v2", + ("rocblas_zgemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZgemm3m", + ("rocblas_zgemm_3m", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + # NB: The function rocblas_sgemmex doesn't actually exist in + # rocblas, as of 2018-12-05 + ( + "cublasSgemmEx", + ("rocblas_sgemmex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasGemmEx", ("rocblas_gemmex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasCgemmEx", + ("rocblas_cgemmex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasUint8gemmBias", + ("rocblas_uint8gemmbias", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyrk_v2", + ("rocblas_ssyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyrk_v2", + ("rocblas_dsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk_v2", + ("rocblas_csyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyrk_v2", + ("rocblas_zsyrk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrkEx", + ("rocblas_csyrkex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyrk3mEx", + ("rocblas_csyrk3mex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk_v2", + ("rocblas_cherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherkEx", + ("rocblas_cherkex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCherk3mEx", + ("rocblas_cherk3mex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZherk_v2", + ("rocblas_zherk", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsyr2k_v2", + ("rocblas_ssyr2k", 
CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsyr2k_v2", + ("rocblas_dsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsyr2k_v2", + ("rocblas_csyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsyr2k_v2", + ("rocblas_zsyr2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCher2k_v2", + ("rocblas_cher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZher2k_v2", + ("rocblas_zher2k", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSsymm_v2", + ("rocblas_ssymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDsymm_v2", + ("rocblas_dsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsymm_v2", + ("rocblas_csymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZsymm_v2", + ("rocblas_zsymm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasChemm_v2", + ("rocblas_chemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZhemm_v2", + ("rocblas_zhemm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrsm_v2", + ("rocblas_strsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrsm_v2", + ("rocblas_dtrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrsm_v2", + ("rocblas_ctrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrsm_v2", + ("rocblas_ztrsm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasStrmm_v2", + ("rocblas_strmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDtrmm_v2", + ("rocblas_dtrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCtrmm_v2", + ("rocblas_ctrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZtrmm_v2", + ("rocblas_ztrmm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSnrm2_v2", ("rocblas_snrm2", CONV_MATH_FUNC, API_BLAS)), + ("cublasDnrm2_v2", ("rocblas_dnrm2", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScnrm2_v2", + ("rocblas_scnrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDznrm2_v2", + ("rocblas_dznrm2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasDotEx", ("rocblas_dotex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDotcEx", ("rocblas_dotcex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSdot_v2", ("rocblas_sdot", CONV_MATH_FUNC, API_BLAS)), + ("cublasDdot_v2", ("rocblas_ddot", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCdotu_v2", + ("rocblas_cdotu", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCdotc_v2", + ("rocblas_cdotc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotu_v2", + ("rocblas_zdotu", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdotc_v2", + ("rocblas_zdotc", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScalEx", ("rocblas_scalex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSscal_v2", ("rocblas_sscal", CONV_MATH_FUNC, API_BLAS)), + ("cublasDscal_v2", ("rocblas_dscal", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCscal_v2", + ("rocblas_cscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCsscal_v2", + ("rocblas_csscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZscal_v2", + ("rocblas_zcsal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZdscal_v2", + ("rocblas_zdscal", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasAxpyEx", ("rocblas_axpyex", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasSaxpy_v2", ("rocblas_saxpy", CONV_MATH_FUNC, 
API_BLAS)), + ("cublasDaxpy_v2", ("rocblas_daxpy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCaxpy_v2", + ("rocblas_caxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZaxpy_v2", + ("rocblas_zaxpy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasScopy_v2", ("rocblas_scopy", CONV_MATH_FUNC, API_BLAS)), + ("cublasDcopy_v2", ("rocblas_dcopy", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCcopy_v2", + ("rocblas_ccopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZcopy_v2", + ("rocblas_zcopy", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSswap_v2", ("rocblas_sswap", CONV_MATH_FUNC, API_BLAS)), + ("cublasDswap_v2", ("rocblas_dswap", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasCswap_v2", + ("rocblas_cswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZswap_v2", + ("rocblas_zswap", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamax_v2", ("rocblas_isamax", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamax_v2", ("rocblas_idamax", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamax_v2", + ("rocblas_icamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamax_v2", + ("rocblas_izamax", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasIsamin_v2", ("rocblas_isamin", CONV_MATH_FUNC, API_BLAS)), + ("cublasIdamin_v2", ("rocblas_idamin", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasIcamin_v2", + ("rocblas_icamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasIzamin_v2", + ("rocblas_izamin", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSasum_v2", ("rocblas_sasum", CONV_MATH_FUNC, API_BLAS)), + ("cublasDasum_v2", ("rocblas_dasum", CONV_MATH_FUNC, API_BLAS)), + ( + "cublasScasum_v2", + ("rocblas_scasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDzasum_v2", + ("rocblas_dzasum", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasSrot_v2", ("rocblas_srot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasDrot_v2", ("rocblas_drot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ("cublasCrot_v2", ("rocblas_crot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasCsrot_v2", + ("rocblas_csrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ("cublasZrot_v2", ("rocblas_zrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED)), + ( + "cublasZdrot_v2", + ("rocblas_zdrot", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotg_v2", + ("rocblas_srotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotg_v2", + ("rocblas_drotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasCrotg_v2", + ("rocblas_crotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasZrotg_v2", + ("rocblas_zrotg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotm_v2", + ("rocblas_srotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotm_v2", + ("rocblas_drotm", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasSrotmg_v2", + ("rocblas_srotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "cublasDrotmg_v2", + ("rocblas_drotmg", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED), + ), + ( + "CURAND_STATUS_SUCCESS", + ("HIPRAND_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_VERSION_MISMATCH", + ("HIPRAND_STATUS_VERSION_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_NOT_INITIALIZED", + ("HIPRAND_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ALLOCATION_FAILED", + ("HIPRAND_STATUS_ALLOCATION_FAILED", 
CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_TYPE_ERROR", + ("HIPRAND_STATUS_TYPE_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_OUT_OF_RANGE", + ("HIPRAND_STATUS_OUT_OF_RANGE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_LENGTH_NOT_MULTIPLE", + ("HIPRAND_STATUS_LENGTH_NOT_MULTIPLE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED", + ( + "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED", + CONV_NUMERIC_LITERAL, + API_RAND, + ), + ), + ( + "CURAND_STATUS_LAUNCH_FAILURE", + ("HIPRAND_STATUS_LAUNCH_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_PREEXISTING_FAILURE", + ("HIPRAND_STATUS_PREEXISTING_FAILURE", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INITIALIZATION_FAILED", + ("HIPRAND_STATUS_INITIALIZATION_FAILED", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_ARCH_MISMATCH", + ("HIPRAND_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_STATUS_INTERNAL_ERROR", + ("HIPRAND_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_RAND), + ), + ("CURAND_RNG_TEST", ("HIPRAND_RNG_TEST", CONV_NUMERIC_LITERAL, API_RAND)), + ( + "mtgp32dc_params_fast_11213", + ("mtgp32dc_params_fast_11213", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_DEFAULT", + ("HIPRAND_RNG_PSEUDO_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_XORWOW", + ("HIPRAND_RNG_PSEUDO_XORWOW", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MRG32K3A", + ("HIPRAND_RNG_PSEUDO_MRG32K3A", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MTGP32", + ("HIPRAND_RNG_PSEUDO_MTGP32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_MT19937", + ("HIPRAND_RNG_PSEUDO_MT19937", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_PSEUDO_PHILOX4_32_10", + ("HIPRAND_RNG_PSEUDO_PHILOX4_32_10", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_DEFAULT", + ("HIPRAND_RNG_QUASI_DEFAULT", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL32", + ("HIPRAND_RNG_QUASI_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL32", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL32", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SOBOL64", + ("HIPRAND_RNG_QUASI_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "CURAND_RNG_QUASI_SCRAMBLED_SOBOL64", + ("HIPRAND_RNG_QUASI_SCRAMBLED_SOBOL64", CONV_NUMERIC_LITERAL, API_RAND), + ), + ( + "curand_ORDERING_PSEUDO_BEST", + ( + "HIPRAND_ORDERING_PSEUDO_BEST", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_DEFAULT", + ( + "HIPRAND_ORDERING_PSEUDO_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_PSEUDO_SEEDED", + ( + "HIPRAND_ORDERING_PSEUDO_SEEDED", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_ORDERING_QUASI_DEFAULT", + ( + "HIPRAND_ORDERING_QUASI_DEFAULT", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_32_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( 
+ "curand_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + ( + "HIPRAND_SCRAMBLED_DIRECTION_VECTORS_64_JOEKUO6", + CONV_NUMERIC_LITERAL, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_CHOOSE_BEST", + ("HIPRAND_CHOOSE_BEST", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_ITR", + ("HIPRAND_ITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_KNUTH", + ("HIPRAND_KNUTH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_HITR", + ("HIPRAND_HITR", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_M1", ("HIPRAND_M1", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ("curand_M2", ("HIPRAND_M2", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED)), + ( + "curand_BINARY_SEARCH", + ("HIPRAND_BINARY_SEARCH", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DISCRETE_GAUSS", + ("HIPRAND_DISCRETE_GAUSS", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_REJECTION", + ("HIPRAND_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEVICE_API", + ("HIPRAND_DEVICE_API", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_FAST_REJECTION", + ("HIPRAND_FAST_REJECTION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_3RD", + ("HIPRAND_3RD", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_DEFINITION", + ("HIPRAND_DEFINITION", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_POISSON", + ("HIPRAND_POISSON", CONV_NUMERIC_LITERAL, API_RAND, HIP_UNSUPPORTED), + ), + ("curandCreateGenerator", ("hiprandCreateGenerator", CONV_MATH_FUNC, API_RAND)), + ( + "curandCreateGeneratorHost", + ("hiprandCreateGeneratorHost", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandCreatePoissonDistribution", + ("hiprandCreatePoissonDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyDistribution", + ("hiprandDestroyDistribution", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandDestroyGenerator", + ("hiprandDestroyGenerator", CONV_MATH_FUNC, API_RAND), + ), + ("curandGenerate", ("hiprandGenerate", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateLogNormal", + ("hiprandGenerateLogNormal", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLogNormalDouble", + ("hiprandGenerateLogNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGenerateLongLong", + ("hiprandGenerateLongLong", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curandGenerateNormal", ("hiprandGenerateNormal", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateNormalDouble", + ("hiprandGenerateNormalDouble", CONV_MATH_FUNC, API_RAND), + ), + ("curandGeneratePoisson", ("hiprandGeneratePoisson", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateSeeds", ("hiprandGenerateSeeds", CONV_MATH_FUNC, API_RAND)), + ("curandGenerateUniform", ("hiprandGenerateUniform", CONV_MATH_FUNC, API_RAND)), + ( + "curandGenerateUniformDouble", + ("hiprandGenerateUniformDouble", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandGetDirectionVectors32", + ("hiprandGetDirectionVectors32", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetDirectionVectors64", + ("hiprandGetDirectionVectors64", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetProperty", + ("hiprandGetProperty", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandGetScrambleConstants32", + ( + "hiprandGetScrambleConstants32", + CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curandGetScrambleConstants64", + ( + "hiprandGetScrambleConstants64", + 
CONV_MATH_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ("curandGetVersion", ("hiprandGetVersion", CONV_MATH_FUNC, API_RAND)), + ( + "curandSetGeneratorOffset", + ("hiprandSetGeneratorOffset", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetGeneratorOrdering", + ("hiprandSetGeneratorOrdering", CONV_MATH_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curandSetPseudoRandomGeneratorSeed", + ("hiprandSetPseudoRandomGeneratorSeed", CONV_MATH_FUNC, API_RAND), + ), + ( + "curandSetQuasiRandomGeneratorDimensions", + ("hiprandSetQuasiRandomGeneratorDimensions", CONV_MATH_FUNC, API_RAND), + ), + ("curandSetStream", ("hiprandSetStream", CONV_MATH_FUNC, API_RAND)), + ("curand", ("hiprand", CONV_DEVICE_FUNC, API_RAND)), + ("curand4", ("hiprand4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_init", ("hiprand_init", CONV_DEVICE_FUNC, API_RAND)), + ("curand_log_normal", ("hiprand_log_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal_double", + ("hiprand_log_normal_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal2", ("hiprand_log_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal2_double", + ("hiprand_log_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_log_normal4", ("hiprand_log_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_log_normal4_double", + ("hiprand_log_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_mtgp32_single", + ("hiprand_mtgp32_single", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ( + "curand_mtgp32_single_specific", + ( + "hiprand_mtgp32_single_specific", + CONV_DEVICE_FUNC, + API_RAND, + HIP_UNSUPPORTED, + ), + ), + ( + "curand_mtgp32_specific", + ("hiprand_mtgp32_specific", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("curand_normal", ("hiprand_normal", CONV_DEVICE_FUNC, API_RAND)), + ( + "curandMakeMTGP32Constants", + ("hiprandMakeMTGP32Constants", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curandMakeMTGP32KernelState", + ("hiprandMakeMTGP32KernelState", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal_double", ("hiprand_normal_double", CONV_DEVICE_FUNC, API_RAND)), + ("curand_normal2", ("hiprand_normal2", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal2_double", + ("hiprand_normal2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_normal4", ("hiprand_normal4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_normal4_double", + ("hiprand_normal4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform", ("hiprand_uniform", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform_double", + ("hiprand_uniform_double", CONV_DEVICE_FUNC, API_RAND), + ), + ( + "curand_uniform2_double", + ("hiprand_uniform2_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_uniform4", ("hiprand_uniform4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_uniform4_double", + ("hiprand_uniform4_double", CONV_DEVICE_FUNC, API_RAND), + ), + ("curand_discrete", ("hiprand_discrete", CONV_DEVICE_FUNC, API_RAND)), + ("curand_discrete4", ("hiprand_discrete4", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson", ("hiprand_poisson", CONV_DEVICE_FUNC, API_RAND)), + ("curand_poisson4", ("hiprand_poisson4", CONV_DEVICE_FUNC, API_RAND)), + ( + "curand_Philox4x32_10", + ("hiprand_Philox4x32_10", CONV_DEVICE_FUNC, API_RAND, HIP_UNSUPPORTED), + ), + ("mtgp32_kernel_params", ("mtgp32_kernel_params_t", CONV_MATH_FUNC, API_RAND)), + ("CUFFT_FORWARD", ("HIPFFT_FORWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ("CUFFT_INVERSE", ("HIPFFT_BACKWARD", CONV_NUMERIC_LITERAL, API_BLAS)), + ( + "CUFFT_COMPATIBILITY_DEFAULT", + ( + 
"HIPFFT_COMPATIBILITY_DEFAULT", + CONV_NUMERIC_LITERAL, + API_BLAS, + HIP_UNSUPPORTED, + ), + ), + ("cufftResult_t", ("hipfftResult_t", CONV_TYPE, API_FFT)), + ("cufftResult", ("hipfftResult", CONV_TYPE, API_FFT)), + ("CUFFT_SUCCESS", ("HIPFFT_SUCCESS", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_PLAN", ("HIPFFT_INVALID_PLAN", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_ALLOC_FAILED", ("HIPFFT_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_TYPE", ("HIPFFT_INVALID_TYPE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_INVALID_VALUE", + ("HIPFFT_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INTERNAL_ERROR", + ("HIPFFT_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_EXEC_FAILED", ("HIPFFT_EXEC_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_SETUP_FAILED", ("HIPFFT_SETUP_FAILED", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_INVALID_SIZE", ("HIPFFT_INVALID_SIZE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_UNALIGNED_DATA", + ("HIPFFT_UNALIGNED_DATA", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INCOMPLETE_PARAMETER_LIST", + ("HIPFFT_INCOMPLETE_PARAMETER_LIST", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_INVALID_DEVICE", + ("HIPFFT_INVALID_DEVICE", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("CUFFT_PARSE_ERROR", ("HIPFFT_PARSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_NO_WORKSPACE", ("HIPFFT_NO_WORKSPACE", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "CUFFT_NOT_IMPLEMENTED", + ("HIPFFT_NOT_IMPLEMENTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ( + "CUFFT_LICENSE_ERROR", + ("HIPFFT_LICENSE_ERROR", CONV_NUMERIC_LITERAL, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_NOT_SUPPORTED", + ("HIPFFT_NOT_SUPPORTED", CONV_NUMERIC_LITERAL, API_FFT), + ), + ("cufftType_t", ("hipfftType_t", CONV_TYPE, API_FFT)), + ("cufftType", ("hipfftType", CONV_TYPE, API_FFT)), + ("CUFFT_R2C", ("HIPFFT_R2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2R", ("HIPFFT_C2R", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_C2C", ("HIPFFT_C2C", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_D2Z", ("HIPFFT_D2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2D", ("HIPFFT_Z2D", CONV_NUMERIC_LITERAL, API_FFT)), + ("CUFFT_Z2Z", ("HIPFFT_Z2Z", CONV_NUMERIC_LITERAL, API_FFT)), + ( + "cufftCompatibility_t", + ("hipfftCompatibility_t", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "cufftCompatibility", + ("hipfftCompatibility", CONV_TYPE, API_FFT, HIP_UNSUPPORTED), + ), + ( + "CUFFT_COMPATIBILITY_FFTW_PADDING", + ( + "HIPFFT_COMPATIBILITY_FFTW_PADDING", + CONV_NUMERIC_LITERAL, + API_FFT, + HIP_UNSUPPORTED, + ), + ), + ("cufftReal", ("hipfftReal", CONV_TYPE, API_FFT)), + ("cufftDoubleReal", ("hipfftDoubleReal", CONV_TYPE, API_FFT)), + ("cufftComplex", ("hipfftComplex", CONV_TYPE, API_FFT)), + ("cufftDoubleComplex", ("hipfftDoubleComplex", CONV_TYPE, API_FFT)), + ("cufftHandle", ("hipfftHandle", CONV_TYPE, API_FFT)), + ("cufftPlan1d", ("hipfftPlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan2d", ("hipfftPlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlan3d", ("hipfftPlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftPlanMany", ("hipfftPlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan1d", ("hipfftMakePlan1d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan2d", ("hipfftMakePlan2d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlan3d", ("hipfftMakePlan3d", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany", ("hipfftMakePlanMany", CONV_MATH_FUNC, API_FFT)), + ("cufftMakePlanMany64", ("hipfftMakePlanMany64", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany64", ("hipfftGetSizeMany64", 
CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate1d", ("hipfftEstimate1d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate2d", ("hipfftEstimate2d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimate3d", ("hipfftEstimate3d", CONV_MATH_FUNC, API_FFT)), + ("cufftEstimateMany", ("hipfftEstimateMany", CONV_MATH_FUNC, API_FFT)), + ("cufftCreate", ("hipfftCreate", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize1d", ("hipfftGetSize1d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize2d", ("hipfftGetSize2d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize3d", ("hipfftGetSize3d", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSizeMany", ("hipfftGetSizeMany", CONV_MATH_FUNC, API_FFT)), + ("cufftGetSize", ("hipfftGetSize", CONV_MATH_FUNC, API_FFT)), + ("cufftSetWorkArea", ("hipfftSetWorkArea", CONV_MATH_FUNC, API_FFT)), + ( + "cufftSetAutoAllocation", + ("hipfftSetAutoAllocation", CONV_MATH_FUNC, API_FFT), + ), + ("cufftExecC2C", ("hipfftExecC2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecR2C", ("hipfftExecR2C", CONV_MATH_FUNC, API_FFT)), + ("cufftExecC2R", ("hipfftExecC2R", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2Z", ("hipfftExecZ2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecD2Z", ("hipfftExecD2Z", CONV_MATH_FUNC, API_FFT)), + ("cufftExecZ2D", ("hipfftExecZ2D", CONV_MATH_FUNC, API_FFT)), + ("cufftSetStream", ("hipfftSetStream", CONV_MATH_FUNC, API_FFT)), + ("cufftDestroy", ("hipfftDestroy", CONV_MATH_FUNC, API_FFT)), + ("cufftGetVersion", ("hipfftGetVersion", CONV_MATH_FUNC, API_FFT)), + ( + "cufftGetProperty", + ("hipfftGetProperty", CONV_MATH_FUNC, API_FFT, HIP_UNSUPPORTED), + ), + ("nvrtcResult", ("hiprtcResult", CONV_TYPE, API_RTC)), + ("NVRTC_SUCCESS", ("HIPRTC_SUCCESS", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_OUT_OF_MEMORY", + ("HIPRTC_ERROR_OUT_OF_MEMORY", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_PROGRAM_CREATION_FAILURE", + ("HIPRTC_ERROR_PROGRAM_CREATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_INPUT", + ("HIPRTC_ERROR_INVALID_INPUT", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INVALID_PROGRAM", + ("HIPRTC_ERROR_INVALID_PROGRAM", CONV_TYPE, API_RTC), + ), + ("NVRTC_ERROR_COMPILATION", ("HIPRTC_ERROR_COMPILATION", CONV_TYPE, API_RTC)), + ( + "NVRTC_ERROR_BUILTIN_OPERATION_FAILURE", + ("HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", + ("HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID", + ("HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID", CONV_TYPE, API_RTC), + ), + ( + "NVRTC_ERROR_INTERNAL_ERROR", + ("HIPRTC_ERROR_INTERNAL_ERROR", CONV_TYPE, API_RTC), + ), + ("nvrtcGetErrorString", ("hiprtcGetErrorString", CONV_JIT, API_RTC)), + ("nvrtcVersion", ("hiprtcVersion", CONV_JIT, API_RTC)), + ("nvrtcProgram", ("hiprtcProgram", CONV_TYPE, API_RTC)), + ("nvrtcAddNameExpression", ("hiprtcAddNameExpression", CONV_JIT, API_RTC)), + ("nvrtcCompileProgram", ("hiprtcCompileProgram", CONV_JIT, API_RTC)), + ("nvrtcCreateProgram", ("hiprtcCreateProgram", CONV_JIT, API_RTC)), + ("nvrtcDestroyProgram", ("hiprtcDestroyProgram", CONV_JIT, API_RTC)), + ("nvrtcGetLoweredName", ("hiprtcGetLoweredName", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLog", ("hiprtcGetProgramLog", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLogSize", ("hiprtcGetProgramLogSize", CONV_JIT, API_RTC)), + ("nvrtcGetPTX", ("hiprtcGetCode", CONV_JIT, API_RTC)), + ("nvrtcGetPTXSize", ("hiprtcGetCodeSize", CONV_JIT, API_RTC)), + ("thrust::cuda", ("thrust::hip", CONV_MATH_FUNC, API_BLAS)), + ("cub::", ("hipcub::", 
CONV_MATH_FUNC, API_BLAS)), + ("nvtxMark", ("roctxMark", CONV_OTHER, API_ROCTX)), + ("nvtxMarkA", ("roctxMarkA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePushA", ("roctxRangePushA", CONV_OTHER, API_ROCTX)), + ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)), + ] +) + +CUDA_SPARSE_MAP = collections.OrderedDict( + [ + ("cusparseStatus_t", ("hipsparseStatus_t", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseHandle_t", ("hipsparseHandle_t", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseOperation_t", ("hipsparseOperation_t", CONV_TYPE, API_SPARSE)), + ( + "cusparseCreateMatDescr", + ("hipsparseCreateMatDescr", CONV_MATH_FUNC, API_SPARSE), + ), + ("cusparseCreate", ("hipsparseCreate", CONV_MATH_FUNC, API_SPARSE)), + ( + "cusparseDestroyMatDescr", + ("hipsparseDestroyMatDescr", CONV_MATH_FUNC, API_SPARSE), + ), + ("cusparseDestroy", ("hipsparseDestroy", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseXcoo2csr", ("hipsparseXcoo2csr", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseMatDescr_t", ("hipsparseMatDescr_t", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseScsrmm2", ("hipsparseScsrmm2", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseDcsrmm2", ("hipsparseDcsrmm2", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseScsrmm", ("hipsparseScsrmm", CONV_MATH_FUNC, API_SPARSE)), + ("cusparseDcsrmm", ("hipsparseDcsrmm", CONV_MATH_FUNC, API_SPARSE)), + ( + "cusparseXcsrsort_bufferSizeExt", + ("hipsparseXcsrsort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE), + ), + ("cusparseXcsrsort", ("hipsparseXcsrsort", CONV_MATH_FUNC, API_SPARSE)), + ( + "cusparseXcoosort_bufferSizeExt", + ("hipsparseXcoosort_bufferSizeExt", CONV_MATH_FUNC, API_SPARSE), + ), + ( + "cusparseXcoosortByRow", + ("hipsparseXcoosortByRow", CONV_MATH_FUNC, API_SPARSE), + ), + ("cusparseSetStream", ("hipsparseSetStream", CONV_MATH_FUNC, API_SPARSE)), + ( + "cusparseCreateIdentityPermutation", + ("hipsparseCreateIdentityPermutation", CONV_MATH_FUNC, API_SPARSE), + ), + ( + "cusparseSetMatIndexBase", + ("hipsparseSetMatIndexBase", CONV_MATH_FUNC, API_SPARSE), + ), + ("cusparseSetMatType", ("hipsparseSetMatType", CONV_MATH_FUNC, API_SPARSE)), + ( + "CUSPARSE_STATUS_SUCCESS", + ("HIPSPARSE_STATUS_SUCCESS", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_NOT_INITIALIZED", + ("HIPSPARSE_STATUS_NOT_INITIALIZED", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_ALLOC_FAILED", + ("HIPSPARSE_STATUS_ALLOC_FAILED", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_INVALID_VALUE", + ("HIPSPARSE_STATUS_INVALID_VALUE", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_MAPPING_ERROR", + ("HIPSPARSE_STATUS_MAPPING_ERROR", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_EXECUTION_FAILED", + ("HIPSPARSE_STATUS_EXECUTION_FAILED", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_INTERNAL_ERROR", + ("HIPSPARSE_STATUS_INTERNAL_ERROR", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + ( + "HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED", + CONV_NUMERIC_LITERAL, + API_SPARSE, + ), + ), + ( + "CUSPARSE_STATUS_ARCH_MISMATCH", + ("HIPSPARSE_STATUS_ARCH_MISMATCH", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_STATUS_ZERO_PIVOT", + ("HIPSPARSE_STATUS_ZERO_PIVOT", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_OPERATION_TRANSPOSE", + ("HIPSPARSE_OPERATION_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_OPERATION_NON_TRANSPOSE", + ("HIPSPARSE_OPERATION_NON_TRANSPOSE", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + 
"CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + ( + "HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE", + CONV_NUMERIC_LITERAL, + API_SPARSE, + ), + ), + ( + "CUSPARSE_INDEX_BASE_ZERO", + ("HIPSPARSE_INDEX_BASE_ZERO", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_INDEX_BASE_ONE", + ("HIPSPARSE_INDEX_BASE_ONE", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ( + "CUSPARSE_MATRIX_TYPE_GENERAL", + ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPARSE), + ), + ] +) + +PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("USE_CUDA", ("USE_ROCM", API_PYTORCH)), + ("CUDA_VERSION", ("HIP_VERSION", API_PYTORCH)), + ("cudaHostAllocator", ("hipHostAllocator", API_PYTORCH)), + ("cudaDeviceAllocator", ("hipDeviceAllocator", API_PYTORCH)), + ("define MAX_NUM_BLOCKS 200", ("define MAX_NUM_BLOCKS 64", API_PYTORCH)), + ("cuda::CUDAGuard", ("hip::HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAGuard", ("HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAGuard", + ("hip::OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("OptionalCUDAGuard", ("OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::CUDAStreamGuard", + ("hip::HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ("CUDAStreamGuard", ("HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::OptionalCUDAStreamGuard", + ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "OptionalCUDAStreamGuard", + ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH), + ), + # Only get needs to be transformed this way; all the other ones can go + # straight to the normal versions hip::HIPCachingAllocator + ( + "cuda::CUDACachingAllocator::get", + ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "CUDACachingAllocator::get", + ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH), + ), + ( + "cuda::CUDACachingAllocator::recordStream", + ( + "hip::HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ( + "CUDACachingAllocator::recordStream", + ( + "HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA", + API_PYTORCH, + ), + ), + ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getStreamFromPool", + ("hip::getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH), + ), + ("getStreamFromPool", ("getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH)), + ( + "cuda::getDefaultCUDAStream", + ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getDefaultCUDAStream", + ("getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::getCurrentCUDAStream", + ("hip::getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "getCurrentCUDAStream", + ("getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "cuda::setCurrentCUDAStream", + ("hip::setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + ( + "setCurrentCUDAStream", + ("setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH), + ), + # TODO: Undo this special-case; see the header for motivation behind this + # hack. It's VERY important this is only applied to PyTorch HIPify. 
+ ( + "c10/cuda/CUDAGuard.h", + ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDACachingAllocator.h", + ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDAStream.h", + ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), + ), + ("gloo/cuda.h", ("gloo/hip.h", API_PYTORCH)), + ( + "gloo/cuda_allreduce_halving_doubling.h", + ("gloo/hip_allreduce_halving_doubling.h", API_PYTORCH), + ), + ( + "gloo/cuda_allreduce_halving_doubling_pipelined.h", + ("gloo/hip_allreduce_halving_doubling_pipelined.h", API_PYTORCH), + ), + ("gloo/cuda_allreduce_ring.h", ("gloo/hip_allreduce_ring.h", API_PYTORCH)), + ( + "gloo/cuda_broadcast_one_to_all.h", + ("gloo/hip_broadcast_one_to_all.h", API_PYTORCH), + ), + ( + "gloo::CudaAllreduceHalvingDoublingPipelined", + ("gloo::HipAllreduceHalvingDoublingPipelined", API_PYTORCH), + ), + ("gloo::CudaBroadcastOneToAll", ("gloo::HipBroadcastOneToAll", API_PYTORCH)), + ("gloo::CudaHostWorkspace", ("gloo::HipHostWorkspace", API_PYTORCH)), + ("gloo::CudaDeviceWorkspace", ("gloo::HipDeviceWorkspace", API_PYTORCH)), + ] +) + +CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( + [ + ("cuda_stream", ("hip_stream", API_CAFFE2)), + # if the header is a native hip folder (under hip directory), + # there is no need to add a hip path to it; the trie in hipify script + # takes this mapping order to forbid further replacement + ("/hip/", ("/hip/", API_CAFFE2)), + ("/context_gpu", ("/hip/context_gpu", API_CAFFE2)), + ("/common_gpu", ("/hip/common_gpu", API_CAFFE2)), + ("/mixed_utils", ("/hip/mixed_utils", API_CAFFE2)), + ("/operator_fallback_gpu", ("/hip/operator_fallback_gpu", API_CAFFE2)), + ( + "/spatial_batch_norm_op_impl", + ("/hip/spatial_batch_norm_op_impl", API_CAFFE2), + ), + ( + "/recurrent_network_executor_gpu", + ("/hip/recurrent_network_executor_gpu", API_CAFFE2), + ), + ( + "/generate_proposals_op_util_nms_gpu", + ("/hip/generate_proposals_op_util_nms_gpu", API_CAFFE2), + ), + ("/max_pool_with_index_gpu", ("/hip/max_pool_with_index_gpu", API_CAFFE2)), + ("/THCCachingAllocator_gpu", ("/hip/THCCachingAllocator_gpu", API_CAFFE2)), + ("/top_k_heap_selection", ("/hip/top_k_heap_selection", API_CAFFE2)), + ("/top_k_radix_selection", ("/hip/top_k_radix_selection", API_CAFFE2)), + ("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)), + ("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)), + ("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)), + ("/math/reduce.cuh", ("/math/hip/reduce.cuh", API_CAFFE2)), + ("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)), + ("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)), + ("REGISTER_CUDA_OPERATOR", ("REGISTER_HIP_OPERATOR", API_CAFFE2)), + ("CUDA_1D_KERNEL_LOOP", ("HIP_1D_KERNEL_LOOP", API_CAFFE2)), + ("CUDAContext", ("HIPContext", API_CAFFE2)), + ("CAFFE_CUDA_NUM_THREADS", ("CAFFE_HIP_NUM_THREADS", API_CAFFE2)), + ("HasCudaGPU", ("HasHipGPU", API_CAFFE2)), + ("__expf", ("expf", API_CAFFE2)), + ("CUBLAS_ENFORCE", ("ROCBLAS_ENFORCE", API_CAFFE2)), + ("CUBLAS_CHECK", ("ROCBLAS_CHECK", API_CAFFE2)), + ("cublas_handle", ("rocblashandle", API_CAFFE2)), + ("CURAND_ENFORCE", ("HIPRAND_ENFORCE", API_CAFFE2)), + ("CURAND_CHECK", ("HIPRAND_CHECK", API_CAFFE2)), + ("curandGenerateUniform", ("hiprandGenerateUniform", API_CAFFE2)), + ("curand_generator", ("hiprand_generator", API_CAFFE2)), + ("CaffeCudaGetDevice", ("CaffeHipGetDevice", API_CAFFE2)), + ("CUDA", ("HIP", API_CAFFE2)), + ("Cuda", ("Hip", API_CAFFE2)), + ("cuda_", ("hip_", 
API_CAFFE2)), + ("_cuda", ("_hip", API_CAFFE2)), + ("CUDNN", ("MIOPEN", API_CAFFE2)), + ("CuDNN", ("MIOPEN", API_CAFFE2)), + ("cudnn", ("miopen", API_CAFFE2)), + ("namespace cuda", ("namespace hip", API_CAFFE2)), + ("cuda::CUDAGuard", ("hip::HIPGuard", API_CAFFE2)), + ("cuda::OptionalCUDAGuard", ("hip::OptionalHIPGuard", API_CAFFE2)), + ("cuda::CUDAStreamGuard", ("hip::HIPStreamGuard", API_CAFFE2)), + ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuard", API_CAFFE2)), + ("c10/cuda/CUDAGuard.h", ("c10/hip/HIPGuard.h", API_CAFFE2)), + ("gloo/cuda", ("gloo/hip", API_CAFFE2)), + ] +) + +# We must tread very carefully here. Blanket conversions like are done +# in CAFFE2_SPECIFIC_MAPPINGS are not presently supported on PyTorch, +# because a regex for CUDA will also match a filename like CUDAGuard.h, +# but the HIPIFY script doesn't presently move the file and so the substitution +# will be invalid. Instead, we specifically list out every identifier +# and file from c10/cuda which may be used externally, and do substitutions this +# way. +# +# NB: if you want a transformation to ONLY apply to the c10/ directory, +# put it as API_CAFFE2 +C10_MAPPINGS = collections.OrderedDict( + [ + ("cuda::compat::", ("hip::compat::", API_C10)), + ("c10/cuda/CUDAException.h", ("c10/hip/HIPException.h", API_C10)), + ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), + ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), + ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), + ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), + ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)), + ("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)), + ("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)), + ( + "c10/cuda/impl/cuda_cmake_macros.h", + ("c10/hip/impl/hip_cmake_macros.h", API_C10), + ), + ("C10_CUDA_CHECK", ("C10_HIP_CHECK", API_C10)), + ("C10_CUDA_CHECK_WARN", ("C10_HIP_CHECK_WARN", API_C10)), + ("c10::cuda", ("c10::hip", API_C10)), + ("cuda::CUDAStream", ("hip::HIPStream", API_C10)), + ("CUDAStream", ("HIPStream", API_C10)), + # This substitution is not permissible, because there's another copy of this + # function in torch/cuda.h + # ("cuda::device_count", ("hip::device_count", API_C10)), + ("cuda::current_device", ("hip::current_device", API_C10)), + ("cuda::set_device", ("hip::set_device", API_C10)), + ("cuda::getStreamFromPool", ("hip::getStreamFromPool", API_C10)), + ("getStreamFromPool", ("getStreamFromPool", API_C10)), + ("cuda::getDefaultCUDAStream", ("hip::getDefaultHIPStream", API_C10)), + ("getDefaultCUDAStream", ("getDefaultHIPStream", API_C10)), + ("cuda::getCurrentCUDAStream", ("hip::getCurrentHIPStream", API_C10)), + ("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)), + ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)), + ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)), + ("cuda::CUDACachingAllocator", ("hip::HIPCachingAllocator", API_C10)), + ("CUDACachingAllocator", ("HIPCachingAllocator", API_C10)), + ] +) + +# NB: C10 mappings are more specific than Caffe2 mappings, so run them +# first +CUDA_TO_HIP_MAPPINGS = [ + CUDA_IDENTIFIER_MAP, + CUDA_TYPE_NAME_MAP, + CUDA_INCLUDE_MAP, + CUDA_SPARSE_MAP, + C10_MAPPINGS, + PYTORCH_SPECIFIC_MAPPINGS, + CAFFE2_SPECIFIC_MAPPINGS, +] diff --git a/tools/amd_build/pyHIPIFY/hipify_python.py b/torch/utils/hipify/hipify_python.py similarity index 97% rename from tools/amd_build/pyHIPIFY/hipify_python.py 
rename to torch/utils/hipify/hipify_python.py index dace0c70f0c95..fe42e807e6efa 100755 --- a/tools/amd_build/pyHIPIFY/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -32,9 +32,9 @@ import sys import os -from pyHIPIFY import constants -from pyHIPIFY.cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS -from pyHIPIFY.cuda_to_hip_mappings import MATH_TRANSPILATIONS +from . import constants +from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS +from .cuda_to_hip_mappings import MATH_TRANSPILATIONS # Hardcode the PyTorch template map """This dictionary provides the mapping from PyTorch kernel template types @@ -366,9 +366,13 @@ def find_parentheses_group(input_string, start): def replace_math_functions(input_string): - """ FIXME: Temporarily replace std:: invocations of math functions with non-std:: versions to prevent linker errors - NOTE: This can lead to correctness issues when running tests, since the correct version of the math function (exp/expf) might not get called. - Plan is to remove this function once HIP supports std:: math function calls inside device code + """FIXME: Temporarily replace std:: invocations of math functions + with non-std:: versions to prevent linker errors NOTE: This + can lead to correctness issues when running tests, since the + correct version of the math function (exp/expf) might not get + called. Plan is to remove this function once HIP supports + std:: math function calls inside device code + """ output_string = input_string for func in MATH_TRANSPILATIONS: @@ -594,7 +598,7 @@ def pattern(self): CAFFE2_TRIE.add(src) CAFFE2_MAP[src] = dst RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern()) -RE_PYTORCH_PREPROCESSOR = re.compile(r'\b{0}\b'.format(PYTORCH_TRIE.pattern())) +RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.pattern())) RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"') RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>') diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py index 73524f8a52998..958f409505af1 100644 --- a/torch/utils/mkldnn.py +++ b/torch/utils/mkldnn.py @@ -17,13 +17,14 @@ def __init__(self, dense_module): @torch.jit.script_method def __getstate__(self): - return (self.weight.to_dense(), self.bias.to_dense()) + return (self.weight.to_dense(), self.bias.to_dense(), self.training) @torch.jit.script_method def __setstate__(self, state): - # type: (Tuple[Tensor, Tensor]) -> None + # type: (Tuple[Tensor, Tensor, bool]) -> None self.weight = state[0].to_mkldnn() self.bias = state[1].to_mkldnn() + self.training = state[2] @torch.jit.script_method def forward(self, x): @@ -55,11 +56,11 @@ def __init__(self, dense_module): @torch.jit.script_method def __getstate__(self): - return (self.weight.to_dense(), self.bias.to_dense()) + return (self.weight.to_dense(), self.bias.to_dense(), self.training) @torch.jit.script_method def __setstate__(self, state): - # type: (Tuple[Tensor, Tensor]) -> None + # type: (Tuple[Tensor, Tensor, bool]) -> None self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight( state[0].to_mkldnn(), self.padding, @@ -67,6 +68,7 @@ def __setstate__(self, state): self.dilation, self.groups) self.bias = state[1].to_mkldnn() + self.training = state[2] @torch.jit.script_method def forward(self, x): @@ -107,15 +109,16 @@ def __getstate__(self): bias = self.bias.to_dense() running_mean = self.running_mean.to_dense() running_var = self.running_var.to_dense() - return (weight, bias, running_mean, running_var) + return (weight, bias, running_mean, running_var, self.training) 
@torch.jit.script_method def __setstate__(self, state): - # type: (Tuple[Tensor, Tensor, Tensor, Tensor]) -> None + # type: (Tuple[Tensor, Tensor, Tensor, Tensor, bool]) -> None self.weight = state[0].to_mkldnn() self.bias = state[1].to_mkldnn() self.running_mean = state[2].to_mkldnn() self.running_var = state[3].to_mkldnn() + self.training = state[4] @torch.jit.script_method def forward(self, x): diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 8d3065881f1ea..f08b367aae961 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -287,6 +287,7 @@ def add_hparams(self, hparam_dict=None, metric_dict=None): .. image:: _static/img/tensorboard/add_hparam.png :scale: 50 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_hparams") if type(hparam_dict) is not dict or type(metric_dict) is not dict: raise TypeError('hparam_dict and metric_dict should be dictionary.') exp, ssi, sei = hparams(hparam_dict, metric_dict) @@ -323,6 +324,7 @@ def add_scalar(self, tag, scalar_value, global_step=None, walltime=None): :scale: 50 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_scalar") if self._check_caffe2_blob(scalar_value): scalar_value = workspace.FetchBlob(scalar_value) self._get_file_writer().add_summary( @@ -359,6 +361,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None :scale: 50 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_scalars") walltime = time.time() if walltime is None else walltime fw_logdir = self._get_file_writer().get_logdir() for tag, scalar_value in tag_scalar_dict.items(): @@ -402,6 +405,7 @@ def add_histogram(self, tag, values, global_step=None, bins='tensorflow', wallti :scale: 50 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_histogram") if self._check_caffe2_blob(values): values = workspace.FetchBlob(values) if isinstance(bins, six.string_types) and bins == 'tensorflow': @@ -461,6 +465,7 @@ def add_histogram_raw(self, tag, min, max, num, sum, sum_squares, :scale: 50 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw") if len(bucket_limits) != len(bucket_counts): raise ValueError('len(bucket_limits) != len(bucket_counts), see the document.') self._get_file_writer().add_summary( @@ -517,6 +522,7 @@ def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformat :scale: 50 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_image") if self._check_caffe2_blob(img_tensor): img_tensor = workspace.FetchBlob(img_tensor) self._get_file_writer().add_summary( @@ -559,6 +565,7 @@ def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataforma :scale: 30 % """ + torch._C._log_api_usage_once("tensorboard.logging.add_images") if self._check_caffe2_blob(img_tensor): img_tensor = workspace.FetchBlob(img_tensor) self._get_file_writer().add_summary( @@ -579,12 +586,13 @@ def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None, dataformats (string): Image data format specification of the form NCHW, NHWC, CHW, HWC, HW, WH, etc. Shape: - img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformat`` agrument. + img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument. e.g. CHW or HWC box_tensor: (torch.Tensor, numpy.array, or string/blobname): NX4, where N is the number of boxes and each 4 elememts in a row represents (xmin, ymin, xmax, ymax). 
""" + torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes") if self._check_caffe2_blob(img_tensor): img_tensor = workspace.FetchBlob(img_tensor) if self._check_caffe2_blob(box_tensor): @@ -605,6 +613,7 @@ def add_figure(self, tag, figure, global_step=None, close=True, walltime=None): walltime (float): Optional override default walltime (time.time()) seconds after epoch of event """ + torch._C._log_api_usage_once("tensorboard.logging.add_figure") if isinstance(figure, list): self.add_image(tag, figure_to_image(figure, close), global_step, walltime, dataformats='NCHW') else: @@ -625,6 +634,7 @@ def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None): Shape: vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`. """ + torch._C._log_api_usage_once("tensorboard.logging.add_video") self._get_file_writer().add_summary( video(tag, vid_tensor, fps), global_step, walltime) @@ -641,6 +651,7 @@ def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, wallti Shape: snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1]. """ + torch._C._log_api_usage_once("tensorboard.logging.add_audio") if self._check_caffe2_blob(snd_tensor): snd_tensor = workspace.FetchBlob(snd_tensor) self._get_file_writer().add_summary( @@ -660,15 +671,18 @@ def add_text(self, tag, text_string, global_step=None, walltime=None): writer.add_text('lstm', 'This is an lstm', 0) writer.add_text('rnn', 'This is an rnn', 10) """ + torch._C._log_api_usage_once("tensorboard.logging.add_text") self._get_file_writer().add_summary( text(tag, text_string), global_step, walltime) def add_onnx_graph(self, prototxt): + torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph") self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt)) def add_graph(self, model, input_to_model=None, verbose=False): # prohibit second call? # no, let tensorboard handle it and show its warning message. + torch._C._log_api_usage_once("tensorboard.logging.add_graph") """Add graph data to summary. 
Args: @@ -742,6 +756,7 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta writer.add_embedding(torch.randn(100, 5), label_img=label_img) writer.add_embedding(torch.randn(100, 5), metadata=meta) """ + torch._C._log_api_usage_once("tensorboard.logging.add_embedding") mat = make_np(mat) if global_step is None: global_step = 0 @@ -800,6 +815,7 @@ def add_pr_curve(self, tag, labels, predictions, global_step=None, writer.close() """ + torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve") labels, predictions = make_np(labels), make_np(predictions) self._get_file_writer().add_summary( pr_curve(tag, labels, predictions, num_thresholds, weights), @@ -831,6 +847,7 @@ def add_pr_curve_raw(self, tag, true_positive_counts, seconds after epoch of event see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md """ + torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw") self._get_file_writer().add_summary( pr_curve_raw(tag, true_positive_counts, @@ -855,6 +872,7 @@ def add_custom_scalars_multilinechart(self, tags, category='default', title='unt writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330']) """ + torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars_multilinechart") layout = {category: {title: ['Multiline', tags]}} self._get_file_writer().add_summary(custom_scalars(layout)) @@ -869,6 +887,7 @@ def add_custom_scalars_marginchart(self, tags, category='default', title='untitl writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006']) """ + torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars_marginchart") assert len(tags) == 3 layout = {category: {title: ['Margin', tags]}} self._get_file_writer().add_summary(custom_scalars(layout)) @@ -892,6 +911,7 @@ def add_custom_scalars(self, layout): writer.add_custom_scalars(layout) """ + torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars") self._get_file_writer().add_summary(custom_scalars(layout)) def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None): @@ -945,6 +965,7 @@ def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None, glo writer.close() """ + torch._C._log_api_usage_once("tensorboard.logging.add_mesh") self._get_file_writer().add_summary(mesh(tag, vertices, colors, faces, config_dict), global_step, walltime) def flush(self): diff --git a/version.txt b/version.txt new file mode 100644 index 0000000000000..895d424d5d4b2 --- /dev/null +++ b/version.txt @@ -0,0 +1 @@ +1.4.0a0
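
The ordering comment in cuda_to_hip_mappings.py ("NB: C10 mappings are more specific than Caffe2 mappings, so run them first") can be made concrete with a first-match-wins sketch. hipify's real replacement engine builds a trie over the tables, so the hipify_line helper and the miniature tables below are hypothetical illustrations only, not the actual implementation.

    import collections

    # Hypothetical miniature tables; the real ones are the OrderedDicts in the diff.
    C10 = collections.OrderedDict(
        [("c10/cuda/CUDAStream.h", "c10/hip/HIPStream.h")]
    )
    CAFFE2 = collections.OrderedDict(
        [("CUDA", "HIP"), ("cuda", "hip")]
    )

    def hipify_line(line, tables):
        # First table whose key occurs in the line wins, which is why the
        # specific C10 entries must be consulted before the blanket
        # Caffe2 substitutions.
        for table in tables:
            for src, dst in table.items():
                if src in line:
                    return line.replace(src, dst)
        return line

    print(hipify_line('#include "c10/cuda/CUDAStream.h"', [C10, CAFFE2]))
    # -> #include "c10/hip/HIPStream.h"
    # With the Caffe2 table first, CUDA -> HIP alone would leave the broken
    # half-converted path c10/cuda/HIPStream.h.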
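
The RE_PYTORCH_PREPROCESSOR change in hipify_python.py swaps r'\b{0}\b' for r'(?<=\W)({0})(?=\W)'. Below is a toy comparison of the two forms using a hand-written stand-in pattern (the real one comes from PYTORCH_TRIE.pattern() and may already be grouped differently); it only demonstrates how an explicit group plus lookaround tightens matching around an alternation, and does not claim to capture the maintainers' full motivation.

    import re

    # Toy stand-in for PYTORCH_TRIE.pattern().
    pattern = "cudaStream_t|CUDAStream"

    old = re.compile(r"\b{0}\b".format(pattern))             # \b binds per alternative
    new = re.compile(r"(?<=\W)({0})(?=\W)".format(pattern))  # explicit group plus lookaround

    text = "auto s = myCUDAStream; CUDAStream t;"
    print(old.findall(text))  # ['CUDAStream', 'CUDAStream'] - also hits inside myCUDAStream
    print(new.findall(text))  # ['CUDAStream']               - only the standalone identifier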
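
The torch/utils/mkldnn.py hunks add self.training to each module's __getstate__/__setstate__ pair so the flag survives serialization. The following is a minimal, torch-free sketch of that contract under a plain pickle round-trip; FrozenModuleSketch is a hypothetical stand-in for the script modules in the diff.

    import pickle

    class FrozenModuleSketch:
        """Hypothetical stand-in for the mkldnn script modules in the diff."""

        def __init__(self, weight, bias, training=False):
            self.weight, self.bias, self.training = weight, bias, training

        def __getstate__(self):
            # Before the patch only (weight, bias) round-tripped, so the
            # restored copy no longer carried its training flag.
            return (self.weight, self.bias, self.training)

        def __setstate__(self, state):
            self.weight, self.bias = state[0], state[1]
            self.training = state[2]

    m = pickle.loads(pickle.dumps(FrozenModuleSketch([1.0], [0.0], training=True)))
    assert m.training is True  # the flag survives the round-trip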