torch.compiled custom Triton kernels can output incorrect results #136550

Description

@mobicham

When a custom Triton kernel is wrapped with torch.compile(fullgraph=True), it starts producing incorrect results.
Here's an example to reproduce it:

# pip install git+https://github.com/mobiusml/gemlite/ -> you need to comment out the extra triton.autotune parameters such as prune_configs_by, since torch.compile doesn't support them yet
###############################################################
from gemlite.core import GemLiteLinearTriton, DType
import torch 

def check_valid(x, W, quant_linear, tol=1e-3):
    # Compare the quantized Triton matmul against a plain float reference
    y_ref = torch.matmul(x, W.T)
    y_q   = quant_linear(x)
    err   = (y_ref - y_q).abs().mean()
    assert err < tol, f'Assertion Failed: mean abs error {err.item():.5f} >= {tol}'

W_nbits, group_size = 4, 128 
in_features, out_features = 4096, 4096

gemlite_linear = GemLiteLinearTriton(W_nbits, group_size=group_size, 
                                    in_features=in_features, 
                                    out_features=out_features, 
                                    input_dtype=DType.FP16, 
                                    output_dtype=DType.FP16, 
                                    acc_dtype=DType.FP16)



###############################################################
device = 'cuda:0'
compute_dtype = torch.float16

orig_shape = (out_features, in_features)

# Random 4-bit quantized weights, per-group scales/zeros, and the dequantized reference W
W_q    = torch.randint(0, 2**W_nbits, (out_features, in_features), dtype=torch.uint8, device=device)
N      = in_features * out_features // group_size
scales = torch.randn((N,), dtype=compute_dtype, device=device).abs() / 500.
zeros  = torch.randint(0, 2**W_nbits - 1, (N,), dtype=compute_dtype, device=device)
W      = ((W_q.reshape([-1, group_size]) - zeros.view((N, 1))) * scales.view((N, 1))).reshape(orig_shape)

gemlite_linear.pack(W_q, scales, zeros, None)
###############################################################

for batch_size in [16]:
    x = torch.randn((batch_size, in_features), dtype=gemlite_linear.compute_dtype, device='cuda:0')/10.
    check_valid(x, W, gemlite_linear)
# OK: the eager kernel matches the reference

gemlite_linear.forward = torch.compile(gemlite_linear.forward, fullgraph=True)

for batch_size in [16]:
    x = torch.randn((batch_size, in_features), dtype=gemlite_linear.compute_dtype, device='cuda:0')/10.
    check_valid(x, W, gemlite_linear)

# Incorrect: the compiled kernel no longer matches the reference
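
For reference, here is a minimal self-contained sketch of the same pattern: an autotuned Triton kernel launched from a function wrapped with torch.compile(fullgraph=True). The add_kernel below is a hypothetical elementwise kernel (not the gemlite one), so it may not trigger the failure on its own, but it shows the call structure involved:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({'BLOCK_SIZE': 128}, num_warps=4),
             triton.Config({'BLOCK_SIZE': 256}, num_warps=4)],
    key=['n_elements'],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Hypothetical elementwise add, used only to illustrate the launch pattern
    pid     = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask    = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out  = torch.empty_like(x)
    n    = out.numel()
    # BLOCK_SIZE is chosen by the autotuner, so it is not passed explicitly
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, out, n)
    return out

compiled_add = torch.compile(add, fullgraph=True)
a = torch.randn(4096, device='cuda', dtype=torch.float16)
b = torch.randn(4096, device='cuda', dtype=torch.float16)
torch.testing.assert_close(compiled_add(a, b), a + b)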

During compilation, torch.compile prints the following warning:

W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0] torch._inductor.codecache.BypassFxGraphCache: Can't cache HigherOrderOperators.
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0]
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0] During handling of the above exception, another exception occurred:
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0]
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0] Traceback (most recent call last):
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 485, in identify_mutated_tensors
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0]     ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 140, in generate_ttir
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0]     raise ValueError("Incorrect number of arguments passed to kernel")
W0924 17:39:26.298000 22657 torch/_higher_order_ops/triton_kernel_wrap.py:506] [0/0] ValueError: Incorrect number of arguments passed to kernel

Here's one of the kernels that break: https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py#L52-L76

Is this something related to the input arguments?
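
In the meantime, here is a sketch of a temporary workaround I'd consider (assuming torch.compiler.disable is acceptable here, and assuming gemlite_linear.forward hasn't already been replaced by the compiled version): it keeps the Triton launch out of the compiled region while the rest of a model can still compile.

# Exclude the broken Triton launch from compilation while this is investigated.
# Whether this is acceptable depends on how much speedup the kernel call itself
# gains from compilation.
eager_forward = torch.compiler.disable(gemlite_linear.forward)

x = torch.randn((16, in_features), dtype=gemlite_linear.compute_dtype, device='cuda:0') / 10.
check_valid(x, W, eager_forward)  # passes, since the kernel now runs eagerly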

Thank you very much in advance!

Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240905+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.35

Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 550.78
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-254
Off-line CPU(s) list: 255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7B12 64-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 0
Frequency boost: enabled
CPU max MHz: 2250.0000
CPU min MHz: 0.0000
BogoMIPS: 4499.65
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca sme sev sev_es
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 512 MiB (32 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-254
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.11.0
[pip3] pytorch-triton==3.0.0+dedb7bdf33
[pip3] torch==2.5.0.dev20240905+cu121
[pip3] torchaudio==2.3.1
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.19.0
[pip3] triton==3.0.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310h5eee18b_0
[conda] mkl_random 1.2.4 py310hdb19cb5_0
[conda] numpy 1.26.4 py310h5f9d8c6_0
[conda] numpy-base 1.26.4 py310hb5e798b_0
[conda] optree 0.11.0 pypi_0 pypi
[conda] pytorch-cuda 12.1 ha16c6d3_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-triton 3.0.0+dedb7bdf33 pypi_0 pypi
[conda] torch 2.5.0.dev20240905+cu121 pypi_0 pypi
[conda] torchaudio 2.3.1 py310_cu121 pytorch
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchvision 0.19.0 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @oulgen @aakhundov @davidberard98

Labels

high priority, module: user triton (related to ability to directly torch.compile Triton kernels), oncall: pt2, triaged