This repository was archived by the owner on Aug 15, 2025. It is now read-only.

Nightly pip wheels incompatible with pytorch-triton workflow #1318

@ptrblck

Description

Based on pytorch/pytorch#94818 (comment), ptxas should be bundled with "triton" (I assume it should ship in the pytorch-triton wheel), but this does not seem to be the case with the latest nightly binary.
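For reference, a minimal check of the installed package (a sketch, assuming the expected layout is triton/third_party/cuda/bin/ptxas, matching the compiler.py lookup linked below):

import os
import triton

# Look for ptxas at the location the wheel is expected to ship it
# (assumption based on the upstream lookup path referenced below).
ptxas = os.path.join(os.path.dirname(triton.__file__),
                     "third_party", "cuda", "bin", "ptxas")
print(ptxas, "exists:", os.path.exists(ptxas))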

Setup info

Collecting environment information...
PyTorch version: 2.0.0.dev20230218+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.8.15 | packaged by conda-forge | (default, Nov 22 2022, 08:49:35)  [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 525.60.11
cuDNN version: Probably one of the following:
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.5.0
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

[removing CPU info as it's not interesting]

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.1
[pip3] numpydoc==1.5.0
[pip3] pytorch-triton==2.0.0+c8bfe3f548
[pip3] torch==2.0.0.dev20230218+cu118
[pip3] torchaudio==2.0.0.dev20230218+cu118
[pip3] torchvision==0.15.0.dev20230218+cu118
[conda] numpy                     1.24.1                   pypi_0    pypi
[conda] numpydoc                  1.5.0                    pypi_0    pypi
[conda] pytorch-triton            2.0.0+c8bfe3f548          pypi_0    pypi
[conda] torch                     2.0.0.dev20230218+cu118          pypi_0    pypi
[conda] torchaudio                2.0.0.dev20230218+cu118          pypi_0    pypi
[conda] torchvision               0.15.0.dev20230218+cu118          pypi_0    pypi

ptxas in pytorch-triton

The latest nightly pins pytorch-triton==2.0.0+c8bfe3f548, which is correct according to .github/ci_commit_pins/triton.txt.

pytorch-triton searches for the ptxas binary at a user-specified TRITON_PTXAS_PATH or falls back to triton/third_party/cuda/bin/ptxas, as seen in: https://github.com/openai/triton/blob/c8bfe3f548b164f745ada620a560f87f41ab8465/python/triton/compiler.py#L1066-L1067.
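Roughly, a simplified sketch of that lookup (not the exact upstream code, just the fallback order as I understand it):

import os

def find_ptxas(triton_dir):
    # Prefer an explicit TRITON_PTXAS_PATH, otherwise fall back to the
    # ptxas binary that is expected to ship inside the wheel.
    explicit = os.environ.get("TRITON_PTXAS_PATH")
    if explicit and os.path.isfile(explicit):
        return explicit
    bundled = os.path.join(triton_dir, "third_party", "cuda", "bin", "ptxas")
    if os.path.isfile(bundled):
        return bundled
    raise RuntimeError("Cannot find ptxas")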

However, triton/third_party/cuda does not seem to contain the expected bin folder, as seen in:
https://github.com/openai/triton/tree/c8bfe3f548b164f745ada620a560f87f41ab8465/python/triton/third_party/cuda

Example code snippet with failure

Using a simple RN50 with torch.compile:

import torch
import torchvision.models as models

# torch.compile routes the model through TorchInductor, which JIT-compiles
# Triton kernels and therefore needs ptxas at runtime
model = models.resnet50().cuda()
model = torch.compile(model)

x = torch.randn(1, 3, 224, 224).cuda()
out = model(x)
print(out.shape)

fails with:

  File "/usr/local/lib/python3.8/dist-packages/triton/compiler.py", line 1078, in path_to_ptxas
    raise RuntimeError("Cannot find ptxas")
RuntimeError: Cannot find ptxas

Workaround with TRITON_PTXAS_PATH

Setting TRITON_PTXAS_PATH to a valid ptxas location (from a locally installed CUDA toolkit) still fails, either with this error on bare metal:

TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python resnet_compile.py 
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -lcuda
collect2: error: ld returned 1 exit status
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 560, in _worker_compile
    kernel.precompile(warm_cache_only_with_cc=cc)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/triton_ops/autotune.py", line 69, in precompile
    self.launchers = [
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/triton_ops/autotune.py", line 70, in <listcomp>
    self._precompile_config(c, warm_cache_only_with_cc)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/torch/_inductor/triton_ops/autotune.py", line 83, in _precompile_config
    triton.compile(
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/triton/compiler.py", line 1586, in compile
    so_path = make_stub(name, signature, constants)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/triton/compiler.py", line 1475, in make_stub
    so = _build(name, src_path, tmpdir)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/site-packages/triton/compiler.py", line 1390, in _build
    ret = subprocess.check_call(cc_cmd)
  File "/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/lib/python3.8/subprocess.py", line 364, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmp44lexpnp/main.c', '-O3', '-I/usr/local/cuda/include', '-I/home/pbialecki/miniforge3/envs/nightly_pip_cuda118/include/python3.8', '-I/tmp/tmp44lexpnp', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmp44lexpnp/triton_.cpython-38-x86_64-linux-gnu.so']' returned non-zero exit status 1.

or with this error in a docker container:

TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas python tmp.py 
LLVM ERROR: Can't find libdevice at neither /usr/local/lib/python3.8/dist-packages/triton/third_party/cuda/lib/libdevice.10.bc nor /tmp/tmp0fe53ca9/triton/python/triton/third_party/cuda/lib/libdevice.10.bc
LLVM ERROR: Can't find libdevice at neither /usr/local/lib/python3.8/dist-packages/triton/third_party/cuda/lib/libdevice.10.bc nor /tmp/tmp0fe53ca9/triton/python/triton/third_party/cuda/lib/libdevice.10.bc

Missing third_party dependency

The second error is strange, as it claims that not even the libdevice.10.bc file can be found, and indeed the entire third_party folder seems to be missing:

ls /usr/local/lib/python3.8/dist-packages/triton/
_C  __init__.py  __pycache__  compiler.py  impl  language  ops  runtime  testing.py  tools  utils.py

To double-check, I've downloaded the wheel manually via:

wget https://download.pytorch.org/whl/nightly/pytorch_triton-2.0.0%2Bc8bfe3f548-cp38-cp38-linux_x86_64.whl

and after unzipping it I also cannot find any ptxas, libdevice*, or third_party folder.
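For completeness, the same check can be scripted (a wheel is just a zip archive; filename as downloaded above):

import zipfile

with zipfile.ZipFile(
        "pytorch_triton-2.0.0+c8bfe3f548-cp38-cp38-linux_x86_64.whl") as whl:
    hits = [n for n in whl.namelist()
            if "ptxas" in n or "libdevice" in n or "third_party" in n]
    print(hits)  # empty for this wheel, i.e. nothing is bundled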

Possible fixes

My best guess right now would be:

  • openai/triton needs to be updated to ship the bin/ptxas file, as it's missing in third_party/cuda
  • the pytorch-triton wheel needs to be rebuilt, as it's missing the entire third_party folder
  • I don't know how to interpret the /usr/bin/ld: cannot find -lcuda error, as gcc ... -lcuda tries to use my local CUDA toolkit (should this be the case?).

Side note for the last point: my local CUDA toolkit can properly build PyTorch from source as well as a simple CUDA driver API example with -lcuda, so even if we expect a dependency on a locally installed CUDA toolkit, I'm still unsure why it's failing.
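As a sanity check for that last point, this is roughly how I'd verify that the driver library itself is usable at runtime (a sketch; note it only checks dlopen of libcuda.so.1, which is separate from the linker being able to find a libcuda.so for -lcuda):

import ctypes

# libcuda.so.1 is the CUDA driver library that -lcuda ultimately resolves to;
# loading it here only proves the runtime driver works, not that the linker
# search path contains a libcuda.so to link against.
try:
    ctypes.CDLL("libcuda.so.1")
    print("libcuda.so.1 loads fine")
except OSError as e:
    print("libcuda.so.1 not loadable:", e)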

Let me know if I'm missing something.

CC @malfet @atalman @ngimel
