35 commits
d3b84cb  varlen maba (oleksost, Aug 13, 2025)
79a4565  requirement (oleksost, Aug 13, 2025)
1657a1b  docker (oleksost, Aug 13, 2025)
2b171eb  test varlen mamba (oleksost, Aug 15, 2025)
115c1ec  wip (oleksost, Aug 19, 2025)
37d3be8  cleanup (oleksost, Aug 19, 2025)
1b20268  Merge branch 'mamba_varlen' into tp_mamba2 (oleksost, Aug 19, 2025)
35c6f20  wip (oleksost, Aug 20, 2025)
17f86fd  wip (oleksost, Aug 20, 2025)
adb0666  wip (oleksost, Aug 20, 2025)
bc25e74  mamba2 nemotron h tp (oleksost, Aug 21, 2025)
7c5fb0a  modeling (oleksost, Aug 22, 2025)
9cef978  convertion + MIL init (oleksost, Aug 25, 2025)
662e9ef  convertion (oleksost, Aug 25, 2025)
f78055c  undo requirement varlen for m2 testing (oleksost, Aug 25, 2025)
eb8a54e  varlen (oleksost, Aug 25, 2025)
33281d5  wip (oleksost, Aug 25, 2025)
2a5d0f9  rms norm (oleksost, Aug 26, 2025)
7a047b4  clean up (oleksost, Aug 26, 2025)
7a09387  TP RMS norm (oleksost, Sep 16, 2025)
03a7ac2  TP RMS norm (oleksost, Sep 16, 2025)
bd85e85  Merge branch 'hybrid_dev' into tp_mamba2 (oleksost, Sep 16, 2025)
a3cb3e0  nvm (oleksost, Sep 16, 2025)
826f2f0  nvm (oleksost, Sep 17, 2025)
c9c412e  nvm (oleksost, Sep 17, 2025)
33e9597  wip (oleksost, Sep 18, 2025)
7f3bfe9  modelling mamba2 (oleksost, Sep 22, 2025)
bad4c3b  wip (oleksost, Sep 22, 2025)
fd617c8  mamba2 with rms norm not per head (oleksost, Sep 22, 2025)
799ec67  per head norm (oleksost, Sep 22, 2025)
85afd22  per head norm (oleksost, Sep 22, 2025)
157ce73  multihead norm (oleksost, Sep 22, 2025)
dfb75ae  norm per layer (oleksost, Sep 23, 2025)
4003f37  nvm (oleksost, Sep 23, 2025)
70a04e3  clean (oleksost, Sep 23, 2025)
338 changes: 338 additions & 0 deletions fast_llm/functional/triton/normalization.py
@@ -1,3 +1,4 @@
import math
import typing

import torch
@@ -306,3 +307,340 @@ def rms_norm(input_: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
input_dtype = input_.dtype
input_ = input_.to(torch.float32)
return (weight * input_ * torch.rsqrt(input_.pow(2).mean(dim=-1, keepdim=True) + eps)).to(dtype=input_dtype)


# from mamba2
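# Fused (gated) LayerNorm / RMSNorm kernels adapted from the Mamba2 reference Triton implementation.
# The optional Z branch applies SiLU gating: with NORM_BEFORE_GATE the normalized (and affine-transformed)
# output is multiplied by z * sigmoid(z); otherwise the input is gated before normalization.
# The feature dimension can be split into independent groups of size `group_size` (one program per row and group).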
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the other branch
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_z_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
group = tl.program_id(1)
X += row * stride_x_row + group * N
Y += row * stride_y_row + group * N
if HAS_Z:
Z += row * stride_z_row + group * N
if not IS_RMS_NORM:
Mean += group * M
Rstd += group * M
W += group * N
if HAS_BIAS:
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
y = x_hat * w + b if HAS_BIAS else x_hat * w
if HAS_Z and NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=mask).to(tl.float32)
y *= z * tl.sigmoid(z)
# Write output
tl.store(Y + cols, y, mask=mask)

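For reference, a minimal unfused PyTorch sketch of what the forward kernel computes for a single feature group; the function name `gated_norm_reference` is illustrative and not part of this diff:

def gated_norm_reference(x, w, b=None, z=None, eps=1e-5, norm_before_gate=True, is_rms_norm=False):
    # Unfused float32 equivalent of _layer_norm_fwd_1pass_kernel for one feature group.
    x = x.float()
    if z is not None and not norm_before_gate:
        z = z.float()
        x = x * z * torch.sigmoid(z)  # gate the input before normalizing
    if is_rms_norm:
        x_hat = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    else:
        mean = x.mean(dim=-1, keepdim=True)
        var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
        x_hat = (x - mean) * torch.rsqrt(var + eps)
    y = x_hat * w.float() + b.float() if b is not None else x_hat * w.float()
    if z is not None and norm_before_gate:
        z = z.float()
        y = y * z * torch.sigmoid(z)  # gate the normalized output
    return y
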

def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd

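A usage sketch of `_layer_norm_fwd`; the shapes, group size, and device below are illustrative assumptions:

# Illustrative only: grouped RMS norm with SiLU gating, 4 groups of 64 features (CUDA required).
x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
z = torch.randn_like(x)
w = torch.ones(256, device="cuda", dtype=torch.bfloat16)
y, mean, rstd = _layer_norm_fwd(x, w, None, 1e-5, z=z, group_size=64, norm_before_gate=True, is_rms_norm=True)
# y matches x's shape and dtype; mean is None for RMS norm, and rstd has shape (ngroups * M,)
# so it can be reused by the backward pass.
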

@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the other branch
Y, # pointer to the output to be recomputed
DY, # pointer to the output gradient
DX, # pointer to the input gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
DZ, # pointer to the other branch
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_z_row,
stride_y_row,
stride_dy_row,
stride_dx_row,
stride_dz_row,
stride_dw_row,
stride_db_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
rows_per_program,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
RECOMPUTE_OUTPUT: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0)
group = tl.program_id(1)
row_start = row_block_id * rows_per_program
cols = tl.arange(0, BLOCK_N)
mask = cols < N
X += row_start * stride_x_row + group * N
if HAS_Z:
Z += row_start * stride_z_row + group * N
DZ += row_start * stride_dz_row + group * N
DY += row_start * stride_dy_row + group * N
DX += row_start * stride_dx_row + group * N
if RECOMPUTE_OUTPUT:
Y += row_start * stride_y_row + group * N
if not IS_RMS_NORM:
Mean += group * M
Rstd += group * M
W += group * N
w = tl.load(W + cols, mask=mask).to(tl.float32)
if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
B += group * N
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
if HAS_BIAS:
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
row_end = min((row_block_id + 1) * rows_per_program, M)
for row in range(row_start, row_end):
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.load(Mean + row)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32)
x_og = x
x = x_og * z * tl.sigmoid(z)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
xhat = tl.where(mask, xhat, 0.0)
if HAS_Z and NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32)
z_sigmoid = tl.sigmoid(z)
y = xhat * w + b if HAS_BIAS else xhat * w
if RECOMPUTE_OUTPUT:
tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
tl.store(DZ + cols, dz, mask=mask)
dy *= z * z_sigmoid
else:
if RECOMPUTE_OUTPUT:
y = xhat * w + b if HAS_BIAS else xhat * w
tl.store(Y + cols, y, mask=mask)
wdy = w * dy
c1 = tl.sum(xhat * wdy, axis=0) / N
if not IS_RMS_NORM:
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
else:
dx = (wdy - xhat * c1) * rstd
dw += dy * xhat
if HAS_BIAS:
db += dy
if HAS_Z and not NORM_BEFORE_GATE:
z_sigmoid = tl.sigmoid(z)
dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
tl.store(DZ + cols, dz, mask=mask)
dx *= z * z_sigmoid
# Write dx
tl.store(DX + cols, dx, mask=mask)

X += stride_x_row
if HAS_Z:
Z += stride_z_row
DZ += stride_dz_row
if RECOMPUTE_OUTPUT:
Y += stride_y_row
DY += stride_dy_row
DX += stride_dx_row
tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
if HAS_BIAS:
tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)


def _layer_norm_bwd(
dy,
x,
weight,
bias,
eps,
mean,
rstd,
z=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
recompute_output=False,
dz=None,
out=None,
):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
assert dy.stride(-1) == 1
assert dy.shape == (M, N)
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
dx = torch.empty_like(x)
if dz is not None:
assert z is not None
assert dz.shape == z.shape
assert dz.stride(-1) == 1
else:
dz = torch.empty_like(z) if z is not None else None
if recompute_output:
if out is None:
out = torch.empty_like(x)
assert out.shape == x.shape

# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
# If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
# would limit the occupancy.
nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
_dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
_db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
rows_per_program = math.ceil(M / nrow_groups)
grid = (nrow_groups, ngroups)
with torch.cuda.device(x.device.index):
_layer_norm_bwd_kernel[grid](
x,
weight,
bias,
z,
out if recompute_output else None,
dy,
dx,
_dw,
_db,
dz,
mean,
rstd,
x.stride(0),
z.stride(0) if z is not None else 0,
0 if not recompute_output else out.stride(0),
dy.stride(0),
dx.stride(0),
dz.stride(0) if dz is not None else 0,
_dw.stride(0),
_db.stride(0) if _db is not None else 0,
M,
group_size,
eps,
rows_per_program,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
dw = _dw.sum(0).to(weight.dtype)
db = _db.sum(0).to(bias.dtype) if bias is not None else None
return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
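
A minimal sketch of how the forward and backward helpers could be wired into autograd, in the spirit of the upstream Mamba2 wrapper; the class name and argument handling below are assumptions rather than part of this diff:

class _GatedNormFunction(torch.autograd.Function):
    # Sketch only: ties _layer_norm_fwd / _layer_norm_bwd together for training.

    @staticmethod
    def forward(ctx, x, weight, bias, z=None, eps=1e-5, group_size=None, norm_before_gate=True, is_rms_norm=False):
        x_shape = x.shape
        x = x.reshape(-1, x_shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        z = z.reshape(-1, z.shape[-1]).contiguous() if z is not None else None
        y, mean, rstd = _layer_norm_fwd(
            x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm
        )
        ctx.save_for_backward(x, weight, bias, z, mean, rstd)
        ctx.eps, ctx.group_size, ctx.x_shape = eps, group_size, x_shape
        ctx.norm_before_gate, ctx.is_rms_norm = norm_before_gate, is_rms_norm
        return y.reshape(x_shape)

    @staticmethod
    def backward(ctx, dy):
        x, weight, bias, z, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1]).contiguous()
        dx, dw, db, dz = _layer_norm_bwd(
            dy, x, weight, bias, ctx.eps, mean, rstd, z=z, group_size=ctx.group_size,
            norm_before_gate=ctx.norm_before_gate, is_rms_norm=ctx.is_rms_norm,
        )
        dz = dz.reshape(ctx.x_shape) if dz is not None else None
        # One gradient per forward input; eps, group_size and the two flags receive None.
        return dx.reshape(ctx.x_shape), dw, db, dz, None, None, None, None
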
10 changes: 10 additions & 0 deletions fast_llm/layers/common/config.py
@@ -38,6 +38,16 @@ class NormalizationImplementation(str, enum.Enum):
triton = "triton"


class TPRMSNormImplementation(str, enum.Enum):
"""
An enum for the available implementations of tensor-parallel RMS norm.
"""

fused_redtensor = "fused_redtensor"
autograd_redstats = "autograd_redstats"
torch_comp_redstats = "torch_comp_redstats"


@config_class(registry=True)
class NormalizationConfig(BaseModelConfig):
pass