Merged

Changes from all commits (35 commits)
6d92e40  enable build for rocm for fp6_llm  (lcskrishna, Oct 16, 2024)
14b3fce  Merge pull request #1 from lcskrishna/cl/rocm-enablement  (petrex, Oct 17, 2024)
f1a22cf  enable tiled layout extension  (lcskrishna, Oct 23, 2024)
0bef6ca  fix build error related to option  (Oct 23, 2024)
893ae03  require rocm 6.2  (Oct 23, 2024)
a0d3788  enable tensor tiled layout extension with successful compilation  (lcskrishna, Oct 24, 2024)
e4e654d  enable successful build  (lcskrishna, Oct 24, 2024)
3e2c6a1  clean-up  (Oct 29, 2024)
c86880e  Merge pull request #3 from lcskrishna/csrikris_enable_tensor_tile  (petrex, Oct 29, 2024)
91d3c75  fix potential memory access issue  (Oct 29, 2024)
38b7d1c  fix __nv_bfloat162 init  (Nov 12, 2024)
279f4b3  add comment for MI300x isa  (Nov 12, 2024)
612ad14  Merge branch 'main' into rocm_enablement_staging  (petrex, Nov 18, 2024)
bbf5a72  fix build for non-rocm  (lcskrishna, Jan 6, 2025)
735570e  Merge pull request #4 from lcskrishna/rocm_enablement  (petrex, Jan 6, 2025)
253c188  Merge branch 'main' into rocm_enablement_staging  (petrex, Jan 6, 2025)
a2f1736  add sparse_marlin kernel to the build  (Oct 17, 2024)
f817edf  drop .h from conversion  (Oct 17, 2024)
c9bc1bc  cp_asyc4_pred_zfill() AMD implementation  (Oct 17, 2024)
16feff4  implement matching mem utility with amd GCN isa  (Oct 18, 2024)
0b21555  implement mma util with amd gcn isa  (Oct 18, 2024)
f23b194  enable rocm path  (Oct 18, 2024)
ecc3927  update copy from global to lds  (lcskrishna, Oct 22, 2024)
a80730b  implement cvta_to_shared()  (Oct 23, 2024)
d2c7ce4  consolidate code with cvta_to_shared()  (Oct 23, 2024)
15974c7  Merge branch 'main' into rocm_sparse_marlin  (petrex, Jan 8, 2025)
a4e8c30  lint  (Jan 8, 2025)
c678cb0  add GPU arch check for MI300x  (Jan 9, 2025)
08d1cfb  revert change in tensor_core_tile_layout.cu  (Jan 9, 2025)
b96196b  Merge branch 'main' into rocm_sparse_marlin  (petrex, Jan 15, 2025)
aea9d81  lint  (Jan 15, 2025)
f18043d  Merge branch 'main' into rocm_sparse_marlin  (petrex, Feb 25, 2025)
8b34390  Merge branch 'main' into rocm_sparse_marlin  (petrex, Mar 4, 2025)
af7027d  fix setup.py conflict  (jcaip, Mar 4, 2025)
15e29f1  fix setup  (jcaip, Mar 4, 2025)
9 changes: 8 additions & 1 deletion setup.py
@@ -253,7 +253,6 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
-
     if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
         print(
             "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
@@ -324,8 +323,11 @@ def get_extensions():
         ]
     )
 
+    # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")
+
+    # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -339,7 +341,12 @@ def get_extensions():
         hip_sources = list(
             glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
         )
+        extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
+        hip_sources += list(
+            glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
+        )
 
+    # Collect CUDA source files if needed
     if not IS_ROCM and use_cuda:
         sources += cuda_sources
     else:
27 changes: 27 additions & 0 deletions torchao/_models/llama/bsr_bench_results.txt
@@ -0,0 +1,27 @@

20250226151422, tok/s=133.29, tok/s_decode=134.40, ttft=0.0118, mem/s=2000.68 GB/s, peak_mem=16.30 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226151926, tok/s=242.08, tok/s_decode=256.68, ttft=0.0464, mem/s=1182.14 GB/s, peak_mem= 6.74 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226152416, tok/s=252.18, tok/s_decode=267.48, ttft=0.0448, mem/s=1229.49 GB/s, peak_mem= 6.73 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226153215, tok/s=204.19, tok/s_decode=213.86, ttft=0.0438, mem/s=1226.65 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226153628, tok/s=180.14, tok/s_decode=187.54, ttft=0.0433, mem/s=1081.56 GB/s, peak_mem= 8.26 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226160622, tok/s=246.20, tok/s_decode=255.21, ttft=0.0281, mem/s= 956.89 GB/s, peak_mem= 5.56 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226160651, tok/s=145.07, tok/s_decode=163.13, ttft=0.1522, mem/s=1461.87 GB/s, peak_mem=22.76 GB, model_size=10.08 GB quant: None, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8

20250226161533, tok/s=250.71, tok/s_decode=254.78, ttft=0.0121, mem/s= 974.38 GB/s, peak_mem= 5.56 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226161913, tok/s=251.19, tok/s_decode=254.95, ttft=0.0112, mem/s= 976.26 GB/s, peak_mem= 5.63 GB, model_size= 3.89 GB quant: sparse-marlin, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226181326, tok/s=134.44, tok/s_decode=140.82, ttft=0.0669, mem/s= 807.62 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226181520, tok/s=138.03, tok/s_decode=164.08, ttft=0.2295, mem/s=1390.97 GB/s, peak_mem=22.74 GB, model_size=10.08 GB quant: None, sparse: semi-structured, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226181738, tok/s=192.65, tok/s_decode=205.62, ttft=0.0649, mem/s=1157.32 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226182045, tok/s=192.75, tok/s_decode=206.24, ttft=0.0673, mem/s=1157.27 GB/s, peak_mem= 8.26 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226182350, tok/s=236.36, tok/s_decode=257.62, ttft=0.0693, mem/s=1154.19 GB/s, peak_mem= 6.74 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226182712, tok/s=231.24, tok/s_decode=250.55, ttft=0.0661, mem/s=1127.37 GB/s, peak_mem= 6.73 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226183255, tok/s=169.58, tok/s_decode=179.82, ttft=0.0665, mem/s=1018.74 GB/s, peak_mem= 8.27 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226183527, tok/s=184.74, tok/s_decode=196.38, ttft=0.0637, mem/s=1109.18 GB/s, peak_mem= 8.26 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226183734, tok/s=232.60, tok/s_decode=252.51, ttft=0.0673, mem/s=1135.85 GB/s, peak_mem= 6.74 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250226183953, tok/s=232.47, tok/s_decode=251.15, ttft=0.0635, mem/s=1133.40 GB/s, peak_mem= 6.73 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250227084325, tok/s=200.72, tok/s_decode=210.91, ttft=0.0475, mem/s=1205.82 GB/s, peak_mem= 8.00 GB, model_size= 6.01 GB quant: None, sparse: bsr-0.8-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250227084708, tok/s=211.76, tok/s_decode=222.43, ttft=0.0447, mem/s=1271.42 GB/s, peak_mem= 7.99 GB, model_size= 6.00 GB quant: None, sparse: bsr-0.8-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.8-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250227085051, tok/s=241.09, tok/s_decode=255.19, ttft=0.0452, mem/s=1177.31 GB/s, peak_mem= 6.47 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-32, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-32 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250227085446, tok/s=247.53, tok/s_decode=262.94, ttft=0.0468, mem/s=1206.80 GB/s, peak_mem= 6.46 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250227090411, tok/s=250.11, tok/s_decode=263.99, ttft=0.0416, mem/s=1219.39 GB/s, peak_mem= 6.46 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20250227091144, tok/s=249.14, tok/s_decode=263.74, ttft=0.0439, mem/s=1214.68 GB/s, peak_mem= 6.46 GB, model_size= 4.88 GB quant: None, sparse: bsr-0.9-64, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --sparsity bsr-0.9-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
12 changes: 12 additions & 0 deletions torchao/_models/llama/bsr_benchmarks.sh
@@ -0,0 +1,12 @@

# BSR benchmarks
export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B

# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result bsr_bench_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result bsr_bench_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result bsr_bench_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.8-32
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.8-64
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.9-32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result bsr_bench_results.txt --sparsity bsr-0.9-64
2 changes: 1 addition & 1 deletion torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
@@ -52,7 +52,7 @@ static constexpr int min_thread_n = 128;
 static constexpr int tile_size = 16;
 static constexpr int max_par = 64;
 
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && !defined(USE_ROCM)
 
 template <const int num_bits,  // weight bits
           const int threads,   // number of threads in a threadblock
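The region this #if opens is the pre-Ampere fallback stub (the real kernel needs sm_80 mma.sync instructions). Below is a short self-contained sketch of the pattern and of why the added !defined(USE_ROCM) term matters; the kernel name is illustrative, and the note about hipify rewriting __CUDA_ARCH__ is an assumption based on how HIP translation normally behaves, not something stated in the diff.

```cpp
// Illustrative sketch (names assumed, not from the PR). Without the
// !defined(USE_ROCM) term, a hipified build could take the "< 800" branch,
// since hipify rewrites __CUDA_ARCH__ to a HIP macro that evaluates to 1,
// and would compile the empty stub instead of the real kernel.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && !defined(USE_ROCM)
__global__ void sparse_mm_kernel() {
  // Pre-sm_80 fallback: the required tensor-core MMA shapes are
  // unavailable, so the body is intentionally empty.
}
#else
__global__ void sparse_mm_kernel() {
  // Real path: sm_80+ MMA on NVIDIA, or the AMD GCN/CDNA (MI300X)
  // implementation added elsewhere in this PR.
}
#endif
```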