kernel-builder/skills/rocm-kernels/SKILL.md
Provides guidance for writing and benchmarking optimized Triton kernels for AMD GPUs (MI355X, R9700) on ROCm, targeting HuggingFace diffusers (LTX-Video, SD3, FLUX) and transformers. Core kernels: RMSNorm, RoPE 3D, GEGLU, AdaLN. Includes XCD swizzle, autotune, diffusers integration patterns, and LTX-Video pipeline injection.
This skill provides patterns and guidance for developing optimized Triton kernels targeting AMD GPUs (MI355X, R9700) on ROCm, for use with HuggingFace diffusers (LTX-Video, SD3, FLUX) and transformers libraries.
Inject optimized kernels into LTX-Video pipeline:
import os
os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1'
os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1'
import torch
from diffusers import LTXPipeline
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda") # ROCm uses same API via HIP
inject_optimized_kernels(pipe) # BEFORE CPU offloading
pipe.enable_model_cpu_offload()
For a minimal integration example (~150 lines):
python scripts/ltx_kernel_injection_example.py
# All 4 kernels: correctness + performance + bandwidth
python scripts/benchmark_kernels.py
# Single kernel
python scripts/benchmark_kernels.py --kernel rmsnorm
python scripts/benchmark_kernels.py --kernel rope
python scripts/benchmark_kernels.py --kernel geglu
python scripts/benchmark_kernels.py --kernel adaln
# Compare baseline vs Triton vs torch.compile
python scripts/benchmark_e2e.py --mode all
# Quick test
python scripts/benchmark_e2e.py --mode triton --num-frames 9 --steps 5
# Save results for comparison
python scripts/benchmark_e2e.py --mode all --output-json results.json
| Component | Class | Has Weight | Count | Kernel |
|-----------|-------|------------|-------|--------|
| transformer_blocks.*.norm1 | RMSNorm | No (elementwise_affine=False) | 56 | RMSNorm |
| transformer_blocks.*.norm2 | RMSNorm | No | 56 | RMSNorm |
| transformer_blocks.*.attn1.norm_q | torch.nn.RMSNorm | Yes | 28 | RMSNorm |
| transformer_blocks.*.attn1.norm_k | torch.nn.RMSNorm | Yes | 28 | RMSNorm |
| transformer_blocks.*.ff | FeedForward | - | 28 | GELU (not GEGLU!) |
| Rotary position encoding | LTXVideoRotaryPosEmbed | - | 1 | RoPE 3D |
Total RMSNorm modules: 168 (56 with weights, 112 without)
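The totals above can be re-derived for any loaded model by walking its modules. The sketch below is a minimal, framework-free illustration: `count_rmsnorm` is a hypothetical helper (not part of this skill's scripts), and the stand-in classes exist only so the example runs without PyTorch. With a real pipeline you would pass `pipe.transformer.named_modules()`.

```python
from collections import Counter


def count_rmsnorm(named_modules):
    """Tally RMSNorm-like modules by whether they carry a weight.

    `named_modules` is any iterable of (name, module) pairs, such as the
    result of `model.named_modules()` on a PyTorch model.
    """
    tally = Counter()
    for _name, module in named_modules:
        # Match on the class name so that classes from different libraries
        # that are both called "RMSNorm" are counted.
        if type(module).__name__ != "RMSNorm":
            continue
        has_weight = getattr(module, "weight", None) is not None
        tally["with_weight" if has_weight else "without_weight"] += 1
    return tally


# Stand-in classes (hypothetical) so the sketch runs without PyTorch.
class RMSNorm:
    def __init__(self, weight=None):
        self.weight = weight


class Linear:
    weight = [1.0]


modules = [
    ("norm1", RMSNorm()),                 # no weight
    ("norm_q", RMSNorm(weight=[1.0])),    # has weight
    ("proj", Linear()),                   # ignored: not an RMSNorm
]
print(count_rmsnorm(modules))
```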
| Kernel | Use Case | Input Layout | Key Challenge |
|--------|----------|-------------|---------------|
| RMSNorm | Normalization | [..., hidden_size] | Weight may be None; 168 instances |
| RoPE 3D | Video position encoding | [batch, t*h*w, heads, head_dim] | 3D → temporal + spatial decomposition |
| GEGLU | Gated activation (SD3/FLUX) | [batch, seq, 2*hidden] → [batch, seq, hidden] | Gate/value split |
| AdaLN | Conditioned normalization (DiT) | norm(x) * weight * (1+scale) + shift | Fused norm + condition |
| GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | Verified |
|-----|-------------|-----------|--------|--------|-------------|:--------:|
| MI355X | CDNA3+ (gfx950) | Wave64 | 160 KB | 8 TB/s | 32 XCDs, XCD Swizzle for GEMM | Yes |
| R9700 | RDNA4 (gfx1201) | Wave32 | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | Yes |
See MI355X guide | R9700 guide
Use this skill when:
# FORBIDDEN - CUDA only, NOT available on ROCm
tl.libdevice.tanh(x) # Use manual formula below
tl.libdevice.log1p(x) # Use: tl.log(1.0 + x)
tl.math.tanh(x) # Also NOT available on ROCm Triton
# Manual tanh (ONLY reliable method on ROCm).
# Same value as (e2x - 1.0) / (e2x + 1.0), but stays finite when exp() overflows:
e2x = tl.exp(2.0 * x)
tanh_x = 1.0 - 2.0 / (e2x + 1.0)
# FORBIDDEN - Triton limitations on ROCm
break / continue # Use: tl.where()
min(a, b) / max(a, b) # Use: tl.minimum(a, b) / tl.maximum(a, b)
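One caveat with a hand-written tanh: the quotient form `(e2x - 1) / (e2x + 1)` evaluates to inf/inf once `exp` overflows, which yields NaN, while the algebraically equal form `1 - 2 / (e2x + 1)` saturates to 1. The host-side sketch below illustrates this in plain Python. `exp_f32` is a hypothetical stand-in that mimics a float32 `exp` overflowing to infinity; it is not a Triton call.

```python
import math

F32_EXP_LIMIT = 88.72  # exp(v) exceeds the float32 range just above this


def exp_f32(v):
    """exp() that saturates to +inf the way a float32 exp does."""
    return math.inf if v > F32_EXP_LIMIT else math.exp(v)


def tanh_quotient(x):
    e2x = exp_f32(2.0 * x)
    return (e2x - 1.0) / (e2x + 1.0)     # inf / inf -> nan for large x


def tanh_stable(x):
    return 1.0 - 2.0 / (exp_f32(2.0 * x) + 1.0)   # 2 / inf -> 0 for large x


for x in (0.5, 50.0):
    print(x, tanh_quotient(x), tanh_stable(x), math.tanh(x))
```

Both forms agree for moderate inputs; only the second stays finite for large positive arguments.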
import os
os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1'
os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1'
Row-wise reduction pattern. 168 instances in LTX-Video, ~5% of total compute.
CRITICAL: Do NOT autotune BLOCK_D. Autotune may pick BLOCK_D < D, causing partial row processing and wrong results. Always compute BLOCK_D = triton.next_power_of_2(D) in the Python wrapper.
import torch
import triton
import triton.language as tl
@triton.jit
def rmsnorm_kernel(
    x_ptr, weight_ptr, out_ptr,
    stride_x, D,
    eps: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)  # one program per row
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32)
    variance = tl.sum(x * x, axis=0) / D
    rms_inv = tl.rsqrt(variance + eps)
    if HAS_WEIGHT:
        w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32)
        out = x * rms_inv * w
    else:
        out = x * rms_inv
    tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask)
def triton_rmsnorm(x, weight=None, eps=1e-6):
    x_2d = x.contiguous().view(-1, x.shape[-1])
    out = torch.empty_like(x_2d)
    M, D = x_2d.shape
    has_weight = weight is not None
    if not has_weight:
        # Placeholder pointer; never read because HAS_WEIGHT is False
        weight = torch.empty(0, device=x.device)
    BLOCK_D = triton.next_power_of_2(D)  # computed, never autotuned
    num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16)
    rmsnorm_kernel[(M,)](
        x_2d, weight, out, x_2d.stride(0), D, eps, has_weight,
        BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2,
    )
    return out.view_as(x)
LTX-Video pitfall: Weight may be None!
has_weight = hasattr(module, 'weight') and module.weight is not None
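To see why BLOCK_D must cover the whole row, it helps to emulate one kernel program on the host. The sketch below is a plain-Python model of the single-block, masked reduction used above, not the Triton kernel itself. `rmsnorm_row` and its `block_d` override are hypothetical names introduced for illustration, and `next_power_of_2` mirrors what `triton.next_power_of_2` returns.

```python
import math


def next_power_of_2(n):
    """Smallest power of two >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def rmsnorm_row(x, weight=None, eps=1e-6, block_d=None):
    """Emulate one kernel program: a single masked block over one row."""
    d = len(x)
    block_d = next_power_of_2(d) if block_d is None else block_d
    # offs = arange(0, block_d); lanes with offs >= d are masked out.
    lanes = [i for i in range(block_d) if i < d]
    variance = sum(x[i] * x[i] for i in lanes) / d
    rms_inv = 1.0 / math.sqrt(variance + eps)
    out = [0.0] * d  # positions the block never reaches stay unwritten
    for i in lanes:
        w = 1.0 if weight is None else weight[i]
        out[i] = x[i] * rms_inv * w
    return out


row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]           # D = 6
full = rmsnorm_row(row)                        # block_d = 8 covers the row
partial = rmsnorm_row(row, block_d=4)          # block_d < D: tail is dropped
print(full)
print(partial)
```

With `block_d=4` the last two elements are never written and the variance is computed from only four values, so every output element is wrong, which matches the symptom described for an autotuned BLOCK_D.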
Element-wise pattern. LTX-Video splits head_dim into temporal + spatial components.
CRITICAL: cos/sin have shape [seq_len, head_dim]. When grid flattens batch dimension (batch * seq_len), use pid_s % seq_len to index cos/sin, otherwise batch > 1 causes OOB GPU crash.
@triton.jit
def rope_3d_kernel(
    qk_ptr, cos_ptr, sin_ptr, out_ptr,
    seq_len, num_heads, head_dim,
    stride_s, stride_h, stride_d,
    BLOCK_HD: tl.constexpr,
):
    pid_s = tl.program_id(0)  # batch * seq_len
    pid_h = tl.program_id(1)  # head index
    half_dim = head_dim // 2
    offs = tl.arange(0, BLOCK_HD)
    mask = offs < half_dim
    base = pid_s * stride_s + pid_h * stride_h
    x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32)
    x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32)
    seq_idx = pid_s % seq_len  # wrap for batch > 1
    cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32)
    sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32)
    out0 = x0 * cos_val - x1 * sin_val
    out1 = x0 * sin_val + x1 * cos_val
    tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask)
    tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask)
def triton_rope_3d(qk, cos, sin):
    qk = qk.contiguous()
    out = torch.empty_like(qk)
    batch, seq_len, num_heads, head_dim = qk.shape
    qk_flat = qk.view(batch * seq_len, num_heads, head_dim)
    out_flat = out.view(batch * seq_len, num_heads, head_dim)
    BLOCK_HD = triton.next_power_of_2(head_dim // 2)
    num_warps = 4 if BLOCK_HD <= 64 else 8
    rope_3d_kernel[(batch * seq_len, num_heads)](
        qk_flat, cos, sin, out_flat,
        seq_len, num_heads, head_dim,
        qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2),
        BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2,
    )
    return out
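The `pid_s % seq_len` wrap can be checked with a few lines of index arithmetic. The sketch below is a plain-Python illustration; `cos_row_offsets` is a hypothetical helper, and the tiny sizes are chosen only to keep the output readable. It lists the start offset each program would use into a `[seq_len, head_dim]` cos/sin table, with and without the wrap.

```python
def cos_row_offsets(batch, seq_len, head_dim, wrap=True):
    """Start offset into a [seq_len, head_dim] cos/sin table per program."""
    offsets = []
    for pid_s in range(batch * seq_len):      # grid axis 0 of the kernel
        seq_idx = pid_s % seq_len if wrap else pid_s
        offsets.append(seq_idx * head_dim)
    return offsets


batch, seq_len, head_dim = 2, 3, 8
table_size = seq_len * head_dim               # elements in cos (and in sin)

wrapped = cos_row_offsets(batch, seq_len, head_dim, wrap=True)
unwrapped = cos_row_offsets(batch, seq_len, head_dim, wrap=False)
print(wrapped)     # [0, 8, 16, 0, 8, 16]
print(unwrapped)   # [0, 8, 16, 24, 32, 40]
print(max(unwrapped) >= table_size)   # True: second batch reads past the table
```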
Element-wise gated activation. Input [batch, seq, 2*hidden] → Output [batch, seq, hidden].
Same BLOCK_SIZE rule: compute dynamically, do NOT autotune.
@triton.jit
def geglu_kernel(
    input_ptr, output_ptr,
    stride_in, stride_out, hidden_size,
    BLOCK_H: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_H)
    mask = offs < hidden_size
    gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32)
    value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32)
    # GELU approx: manual tanh (tl.math.tanh NOT available on ROCm).
    # 1 - 2/(e2x + 1) equals (e2x - 1)/(e2x + 1) but stays finite if exp() overflows.
    k = 0.7978845608028654  # sqrt(2/pi)
    tanh_arg = k * (gate + 0.044715 * gate * gate * gate)
    e2x = tl.exp(2.0 * tanh_arg)
    tanh_val = 1.0 - 2.0 / (e2x + 1.0)
    gate_gelu = 0.5 * gate * (1.0 + tanh_val)
    result = gate_gelu * value
    tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask)
def triton_geglu(x):
    x = x.contiguous()
    *batch_dims, double_h = x.shape
    hidden_size = double_h // 2
    x_2d = x.view(-1, double_h)
    M = x_2d.shape[0]
    out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype)
    BLOCK_H = triton.next_power_of_2(hidden_size)  # computed, never autotuned
    num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16)
    geglu_kernel[(M,)](
        x_2d, out, x_2d.stride(0), out.stride(0), hidden_size,
        BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2,
    )
    return out.view(*batch_dims, hidden_size)
Warning: LTX-Video uses GELU, NOT GEGLU. GEGLU is for SD3/FLUX.
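For checking a GEGLU kernel on small inputs, a host-side reference is handy. The sketch below is a minimal plain-Python version of the same math, assuming the layout used above (gate in the first half of the row, value in the second). `gelu_tanh` and `geglu_row` are hypothetical names, and `math.tanh` is used here only because this runs on the CPU rather than inside a Triton kernel.

```python
import math

SQRT_2_OVER_PI = 0.7978845608028654


def gelu_tanh(g):
    """Tanh approximation of GELU, same constants as the kernel."""
    inner = SQRT_2_OVER_PI * (g + 0.044715 * g ** 3)
    return 0.5 * g * (1.0 + math.tanh(inner))


def geglu_row(row):
    """row has 2*hidden entries: gate half first, then value half."""
    hidden = len(row) // 2
    gate, value = row[:hidden], row[hidden:]
    return [gelu_tanh(g) * v for g, v in zip(gate, value)]


# hidden = 2: gate = [0.0, 1.0], value = [2.0, 3.0]
print(geglu_row([0.0, 1.0, 2.0, 3.0]))
```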
Fused normalization + conditioning: norm(x) * weight * (1 + scale) + shift
Same BLOCK_D rule: compute dynamically.
@triton.jit
def adaln_kernel(
    x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr,
    stride_x, stride_cond, D,
    eps: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32)
    variance = tl.sum(x * x, axis=0) / D
    rms_inv = tl.rsqrt(variance + eps)
    x_norm = x * rms_inv
    w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32)
    scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32)
    shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32)
    out = x_norm * w * (1.0 + scale) + shift
    tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask)
def triton_adaln(x, weight, scale, shift, eps=1e-6):
    x_flat = x.contiguous().view(-1, x.shape[-1])
    scale_flat = scale.contiguous().view(-1, x.shape[-1])
    shift_flat = shift.contiguous().view(-1, x.shape[-1])
    out = torch.empty_like(x_flat)
    M, D = x_flat.shape
    BLOCK_D = triton.next_power_of_2(D)
    num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16)
    adaln_kernel[(M,)](
        x_flat, weight, scale_flat, shift_flat, out,
        x_flat.stride(0), scale_flat.stride(0), D, eps,
        BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2,
    )
    return out.view_as(x)
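The fused formula `norm(x) * weight * (1 + scale) + shift` can be sanity-checked against a per-row host reference. The sketch below is plain Python under the same RMS-style normalization the kernel uses. `adaln_row` is a hypothetical helper for illustration, not part of the skill's scripts.

```python
import math


def adaln_row(x, weight, scale, shift, eps=1e-6):
    """Reference for one row: rms-normalize, then apply weight, scale, shift."""
    d = len(x)
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    return [
        x[i] * rms_inv * weight[i] * (1.0 + scale[i]) + shift[i]
        for i in range(d)
    ]


x = [3.0, 4.0]
ones, zeros = [1.0, 1.0], [0.0, 0.0]

plain = adaln_row(x, ones, zeros, zeros)            # reduces to RMSNorm
cond = adaln_row(x, ones, [1.0, 1.0], [0.5, 0.5])   # 2 * norm(x) + 0.5
print(plain)
print(cond)
```

With zero scale and shift the result is the plain normalized row, which gives a quick way to cross-check an AdaLN kernel against an RMSNorm kernel.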
See diffusers-integration.md for the complete guide.
def patch_rmsnorm_modules(model):
    """Patch all RMSNorm modules to use custom Triton kernel."""
    for name, module in model.named_modules():
        if type(module).__name__ == 'RMSNorm':
            eps = getattr(module, 'eps', 1e-6)
            has_weight = hasattr(module, 'weight') and module.weight is not None
            if has_weight:
                def make_forward(mod, epsilon):
                    def forward(x):
                        return triton_rmsnorm(x, mod.weight, eps=epsilon)
                    return forward
                module.forward = make_forward(module, eps)
            else:
                def make_forward(epsilon):
                    def forward(x):
                        w = torch.ones(x.shape[-1], device=x.device, dtype=x.dtype)
                        return triton_rmsnorm(x, w, eps=epsilon)
                    return forward
                module.forward = make_forward(eps)
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")
patch_rmsnorm_modules(pipe.transformer)
pipe.enable_model_cpu_offload()
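The `make_forward` factory in the patching loop matters because Python closures look up loop variables when they are called, not when they are defined. The sketch below is a framework-free illustration of that behavior; the tags are placeholder strings, not real module names from the pipeline.

```python
def make_forward(tag):
    """Factory: each call creates a new scope that pins its own `tag`."""
    def forward(x):
        return f"{tag}:{x}"
    return forward


late, bound = [], []
for tag in ("norm1", "norm2"):
    late.append(lambda x: f"{tag}:{x}")   # looks up `tag` when called
    bound.append(make_forward(tag))       # captures this iteration's `tag`

print([f(0) for f in late])    # ['norm2:0', 'norm2:0']
print([f(0) for f in bound])   # ['norm1:0', 'norm2:0']
```

Without the factory, every patched module would end up using the `module` and `eps` from the last loop iteration.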
- Weight may be None: modules built with elementwise_affine=False have no weight.
- Match modules by type(module).__name__, not isinstance().
- Call inject_kernels() first, then enable_model_cpu_offload().

| Kernel | Avg Speedup | Best Config Speedup | Status |
|--------|:-----------:|:-------------------:|:------:|
| RMSNorm | 1.71x | 2.44x ([4×4096×3072]) | PASS |
| RoPE 3D | 1.21x | 1.52x ([2×4096×16×128]) | PASS |
| GEGLU | 1.43x | 2.13x ([4×4096×8192]) | PASS |
| AdaLN | 2.22x | 2.77x ([4×4096×3072]) | PASS |
RMSNorm bandwidth utilization: 3554 GB/s (MI355X theoretical: 8 TB/s, ~44%).
| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup |
|------|:--------:|:------------:|:-------------:|:-------:|
| baseline | 1.20 | 0.040 | 18.58 | 1.00x |
| triton | 0.98 | 0.033 | 18.58 | 1.22x |
| torch.compile | 0.78 | 0.026 | 18.58 | 1.54x |
Key finding: MI355X Triton E2E speedup (22%) is significantly higher than H100 CUDA reference (6%), because MI355X's default PyTorch RMSNorm path has more room for optimization.
| Kernel | Avg Speedup | Best Config Speedup | Status |
|--------|:-----------:|:-------------------:|:------:|
| RMSNorm | 2.90x | 3.97x ([1×8192×2048]) | PASS |
| RoPE 3D | 2.09x | 2.38x ([1×1024×16×64]) | PASS |
| GEGLU | 1.69x | 1.93x ([2×1024×8192]) | PASS |
| AdaLN | 3.00x | 3.67x ([4×4096×3072]) | PASS |
RMSNorm bandwidth utilization: 483 GB/s (R9700 theoretical: ~608 GB/s, ~79%).
R9700 speedups are higher than MI355X because PyTorch's default RDNA4 backend is less mature, leaving more room for Triton optimization. The bandwidth utilization (79%) is also significantly better than MI355X (44%).
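The utilization percentages quoted above follow from dividing achieved bandwidth by the theoretical peak. The sketch below reproduces that arithmetic and shows one way to estimate achieved bandwidth from a timing. The helper names are hypothetical, and the traffic model (one read of the input plus one write of the output, weight traffic ignored) is an assumption rather than the accounting used by the skill's benchmark script.

```python
def utilization_pct(achieved_gbs, peak_gbs):
    """Achieved bandwidth as a percentage of theoretical peak."""
    return 100.0 * achieved_gbs / peak_gbs


def rmsnorm_bytes(rows, d, bytes_per_elem=2):
    """Assumed traffic: read x once, write out once (bf16 = 2 bytes)."""
    return 2 * rows * d * bytes_per_elem


def achieved_bandwidth_gbs(bytes_moved, seconds):
    """Bandwidth in GB/s (1 GB = 1e9 bytes) for a measured kernel time."""
    return bytes_moved / seconds / 1e9


print(round(utilization_pct(3554, 8000)))   # MI355X figures from the text
print(round(utilization_pct(483, 608)))     # R9700 figures from the text
print(rmsnorm_bytes(4 * 4096, 3072))        # bytes for a [4×4096×3072] bf16 input
```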
| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup |
|------|:--------:|:------------:|:-------------:|:-------:|
| baseline (mean of 3) | 6.91 | 0.231 | 18.58 | 1.00x |
| triton (mean of 3) | 6.10 | 0.203 | 18.58 | 1.13x |
| torch.compile (single run) | 5.05 | 0.168 | 18.58 | 1.37x |
Reviewer-facing benchmark files for this comparison live in examples/ltx-video-benchmark/, including:
gen_time_s, time_per_step_s, peak_memory_gb, and speedup
examples/ltx-video-benchmark/benchmark_results.json
examples/ltx-video-benchmark/trace/opencode_live/results.json
examples/ltx-video-benchmark/trace/opencode_live/opencode_trace_result.json

| Test | Result |
|------|--------|
| Transformers injection (TinyLlama 1.1B) | PASS: 45 RMSNorm patched, 99.9 tokens/s |
| HuggingFace Kernels Hub integration | PASS: Hub kernel loads and runs on ROCm |
| Local Triton vs Hub kernel (small shape) | Local 5.92x vs Hub 1.27x (lower launch overhead) |
| Local Triton vs Hub kernel (large shape) | Local 3.59x vs Hub 3.57x (comparable) |
| num_warps sweep (2/4/8/16/32) | Default heuristic (4/8/16) is near-optimal; nw=32 always worst |
| rocprof kernel fusion analysis | Triton fuses 4 PyTorch kernels (pow+mean+rsqrt+mul) into 1 |
| Shape | Custom (ms) | PyTorch (ms) | Speedup |
|:---|:---:|:---:|:---:|
| [1×1024×2048] | 0.019 | 0.065 | 3.37x |
| [2×4096×3072] | 0.087 | 0.208 | 2.41x |
H100 E2E: ~6% (RMSNorm is ~5% of total compute).
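The link between a kernel's share of compute and the end-to-end gain is Amdahl's law. The sketch below is a back-of-the-envelope check, not a measurement. `amdahl_speedup` is a hypothetical helper, and the 5% fraction is the approximate figure quoted above, so the outputs are only as precise as that estimate.

```python
import math


def amdahl_speedup(fraction, kernel_speedup):
    """End-to-end speedup when `fraction` of runtime gets `kernel_speedup`."""
    return 1.0 / ((1.0 - fraction) + fraction / kernel_speedup)


# Roughly 5% of compute accelerated by 1.71x (MI355X RMSNorm average):
print(round(amdahl_speedup(0.05, 1.71), 4))
# Ceiling if that 5% became free:
print(round(amdahl_speedup(0.05, math.inf), 4))
```

At a 5% share the ceiling is only about 1.05x, which is in the same range as the approximate H100 figure. Larger end-to-end gains therefore imply that the patched operations account for a larger share of baseline time on that platform.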
| Kernel | MI355X | R9700 | Target | Priority |
|--------|:------:|:-----:|:------:|:--------:|
| RMSNorm | 1.71x | 2.90x | >3x (R9700) | P0: MI355X bandwidth util (44%→60%+) |
| AdaLN | 2.22x | 3.00x | >3.5x (R9700) | P1: already strong on both |
| GEGLU | 1.43x | 1.69x | >2x | P1: tanh overhead |
| RoPE 3D | 1.21x | 2.09x | >2.5x (R9700) | P2: small head_dim launch overhead |
| Issue | Symptom | Fix |
|-------|---------|-----|
| Autotune BLOCK_D | Wrong results (max_abs 4-8+) | Never autotune BLOCK_D. Use triton.next_power_of_2(D) |
| RoPE batch OOB | GPU crash (Memory access fault) | Use pid_s % seq_len for cos/sin indexing |
| tl.libdevice | Not found on AMD | Use manual math formulas |
| tl.tanh / tl.math.tanh | Not on ROCm | Manual: e2x=exp(2x); 1 - 2/(e2x+1), which equals (e2x-1)/(e2x+1) but stays finite when exp overflows |
| Python min/max | Runtime error | tl.minimum()/tl.maximum() |
| LDS overflow | HIP OOM | Reduce num_stages to 2 |
| Weight is None | AttributeError | Check elementwise_affine |
| isinstance() miss | RMSNorm not patched | Use type(module).__name__ |
See troubleshooting.md for all common issues.
rocprof --stats python your_kernel.py
rocprofv3 -i metrics.txt python your_kernel.py
rocm-bandwidth-test
rocminfo | grep -E "Name|Compute Unit|Wavefront"