skills/model-config-guide/SKILL.md
Create GPU config files to support existing MaxText model definitions on AMD GPU clusters. Use when the user wants to add a model, create a config, support a new model, or asks about model configs, parallelism, batch size, OOM, quantization, or .gpu.yml files.
Create a .gpu.yml config to run a MaxText model on AMD GPUs. Most configs reference models already defined in MaxText's configs/models/, but configs can also define or tweak model architectures inline when needed.
Batch size, memory budget, quantization support, and peak FLOPS all depend on the GPU. Check what's available:
rocm-smi --showproductname # AMD
nvidia-smi -L # NVIDIA
Key differences by GPU (see utils/monkey_patch_maxtext.py for the full table):
| GPU | VRAM | FP8 TFLOPS | BF16 TFLOPS |
|-----|------|------------|-------------|
| MI355X (CDNA 4) | 288 GB | 5000 | 2500 |
| MI300X (CDNA 3) | 192 GB | 2614 | 1307 |
More VRAM means larger per_device_batch_size and max_target_length. FP8 support enables quantization: "fp8" for memory savings and higher throughput on models that benefit from it.
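As a rough illustration of how VRAM translates into a memory budget, the sketch below multiplies the VRAM figures from the table by a pre-allocation fraction. It assumes the fraction behaves like `XLA_PYTHON_CLIENT_MEM_FRACTION` (default .85, covered later in this guide); `GpuSpec`, `GPUS`, and `prealloc_gb` are hypothetical names, not part of the repo.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GpuSpec:
    vram_gb: int
    fp8_tflops: int
    bf16_tflops: int


# Values copied from the table above.
GPUS = {
    "MI355X": GpuSpec(vram_gb=288, fp8_tflops=5000, bf16_tflops=2500),
    "MI300X": GpuSpec(vram_gb=192, fp8_tflops=2614, bf16_tflops=1307),
}


def prealloc_gb(gpu: str, mem_fraction: float = 0.85) -> float:
    """GB of VRAM covered by the given pre-allocation fraction (simple product)."""
    return GPUS[gpu].vram_gb * mem_fraction


for name in GPUS:
    print(f"{name}: {prealloc_gb(name):.1f} GB at .85, {prealloc_gb(name, 0.93):.1f} GB at .93")
```

This is only an upper bound on what a run can use; actual headroom depends on weights, optimizer state, activations, and communication buffers.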
ls /workspace/maxtext/src/MaxText/configs/models/
If the model is there, use a model-name reference (simplest path). If not — or if you need to customize architecture parameters — define the model inline (see Inline models below).
For model-name configs, read the MaxText model file to understand the architecture (dense vs MoE, size, head counts, etc.) — this determines parallelism and memory decisions:
cat /workspace/maxtext/src/MaxText/configs/models/<model>.yml
Create configs/<model-name>.gpu.yml. The filename (minus .gpu.yml) becomes the model's short name for submit.sh and run_local.sh. Choose a name that is unambiguous: the short name is matched as a substring of the config filenames, so submit.sh 70b must match exactly one file.
Use the template below, filling in decisions from the following steps.
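The uniqueness requirement on short names can be sketched as follows. This assumes plain substring matching against the filename stem; the real resolution logic lives in the launch scripts and may differ, and `resolve_config` is a hypothetical helper.

```python
SUFFIX = ".gpu.yml"


def resolve_config(short_name: str, filenames: list[str]) -> str:
    """Return the single config file whose stem contains short_name."""
    matches = sorted(
        f for f in filenames
        if f.endswith(SUFFIX) and short_name in f[: -len(SUFFIX)]
    )
    if len(matches) != 1:
        raise ValueError(f"{short_name!r} matches {len(matches)} configs: {matches}")
    return matches[0]


configs = ["llama2-70b.gpu.yml", "llama3.1-405b.gpu.yml", "mixtral-8x22b.gpu.yml"]
print(resolve_config("70b", configs))  # unique match
# resolve_config("llama", configs) would raise: two files contain "llama"
```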
Key constraint: at most one -1 (auto-sharded) axis per group (DCN and ICI each). The product of all explicit axes in ICI must equal 8 (GPUs per node).
Standard pattern for dense models: shard data across nodes, FSDP within each node:
dcn_data_parallelism: -1 # auto-sharded across nodes
ici_fsdp_parallelism: 8 # shard weights across 8 GPUs per node
For very large dense models (405B+), FSDP on both levels:
dcn_fsdp_parallelism: -1 # auto-sharded
ici_fsdp_parallelism: -1 # auto-sharded
Standard pattern for MoE models: FSDP across nodes, expert parallelism within each node:
dcn_fsdp_parallelism: 8 # or -1 for auto
ici_expert_parallelism: 8 # spread experts across 8 GPUs per node
All other axes default to 1. Always set them explicitly for clarity (see template).
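The two axis constraints can be checked before submitting a job. The sketch below is a hypothetical pre-flight check, not part of MaxText or the launch scripts. It assumes that when an ICI axis is -1, the explicit ICI axes only need to divide the GPU count (the auto axis fills the remainder).

```python
from math import prod


def check_axes(config: dict, gpus_per_node: int = 8) -> list[str]:
    """Return a list of constraint violations (empty list means OK)."""
    errors = []
    for group in ("dcn", "ici"):
        axes = {
            key: value for key, value in config.items()
            if key.startswith(f"{group}_") and key.endswith("_parallelism")
        }
        auto = [key for key, value in axes.items() if value == -1]
        if len(auto) > 1:
            errors.append(f"{group.upper()}: more than one -1 axis: {auto}")
        if group == "ici":
            explicit = prod(value for value in axes.values() if value != -1)
            if auto and gpus_per_node % explicit != 0:
                errors.append(f"ICI: explicit product {explicit} does not divide {gpus_per_node}")
            elif not auto and explicit != gpus_per_node:
                errors.append(f"ICI: explicit product {explicit} != {gpus_per_node}")
    return errors


dense = {"dcn_data_parallelism": -1, "ici_fsdp_parallelism": 8}
moe = {"dcn_fsdp_parallelism": 8, "ici_expert_parallelism": 8}
print(check_axes(dense), check_axes(moe))
```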
Prefer FSDP over pipeline parallelism. Pipeline parallelism in MaxText has known issues — likely related to JAX's SPMD model — and is harder to debug and tune. Use it only as a last resort when FSDP alone cannot fit the model. If needed, set num_layers_per_pipeline_stage (layers must divide evenly) and num_pipeline_microbatches, and disable sparse_matmul:
dcn_pipeline_parallelism: 4
num_layers_per_pipeline_stage: 7 # total_layers / pipeline_stages
num_pipeline_microbatches: 32
sparse_matmul: false # required with pipeline parallelism
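The layer split follows the comment above (total_layers / pipeline_stages). A minimal sketch, assuming a simple even split with no repeated stages; `layers_per_stage` is a hypothetical helper:

```python
def layers_per_stage(total_layers: int, pipeline_stages: int) -> int:
    """Value for num_layers_per_pipeline_stage; layers must divide evenly."""
    if total_layers % pipeline_stages != 0:
        raise ValueError(
            f"{total_layers} layers do not divide evenly into {pipeline_stages} stages"
        )
    return total_layers // pipeline_stages


# A 28-layer model over 4 pipeline stages gives the value used in the snippet above.
print(layers_per_stage(28, 4))
```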
dcn_sequence_parallelism / dcn_tensor_parallelism: never recommended (comment: # never recommended)
Start from a known-working config for a similar-sized model and adjust. These are typical values on MI355X (288 GB); reduce for GPUs with less VRAM:
| Model size | Typical per_device_batch_size | Typical max_target_length |
|------------|--------------------------------|----------------------------|
| < 20B | 32–64 | 2048–4096 |
| 70B | 5–8 | 4096–8192 |
| 300B+ MoE | 10–22 | 4096–8192 |
| 400B+ dense | 3–5 | 8192 |
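To compare rows of the table, it can help to convert them into tokens per step. The sketch below is plain arithmetic and assumes every sequence is filled to max_target_length (as with packing or synthetic data); `tokens_per_step` is a hypothetical helper.

```python
def tokens_per_step(per_device_batch_size: int, max_target_length: int,
                    num_devices: int = 1) -> int:
    """Tokens processed per training step across num_devices GPUs."""
    return per_device_batch_size * max_target_length * num_devices


per_gpu = tokens_per_step(8, 4096)                  # one GPU, 70B-style settings
per_node = tokens_per_step(8, 4096, num_devices=8)  # one 8-GPU node
print(per_gpu, per_node)
```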
If OOM, reduce per_device_batch_size first, then max_target_length. If the model barely fits at batch size 1, increase XLA_PYTHON_CLIENT_MEM_FRACTION (default .85 in train_env.sh) — this controls what fraction of GPU memory JAX pre-allocates. Setting it higher (e.g., .90, .93) gives more room for model weights and activations, but too high risks RCCL allocation failures. Note: XLA may inflate allocations when more memory is available, so increasing the fraction doesn't always yield proportional headroom.
For per-run testing, pass it via CLI:
submit.sh 70b -N 1 -- _env_XLA_PYTHON_CLIENT_MEM_FRACTION=.90
For permanent per-model overrides, create a configs/<model>.env.sh file (sourced after train_env.sh, before CLI _env_ overrides):
# configs/my-large-model.env.sh
export XLA_PYTHON_CLIENT_MEM_FRACTION=.93
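The override order described above (train_env.sh, then configs/<model>.env.sh, then CLI `_env_` arguments) can be pictured as successive dictionary updates. This is an illustrative model only; the real scripts source shell files, and `effective_env` is a hypothetical name.

```python
def effective_env(train_env: dict, model_env: dict, cli_args: list[str]) -> dict:
    """Merge environment settings; later sources win."""
    env = dict(train_env)      # defaults from train_env.sh
    env.update(model_env)      # per-model configs/<model>.env.sh
    for arg in cli_args:       # per-run _env_KEY=VALUE passthrough
        if arg.startswith("_env_") and "=" in arg:
            key, value = arg[len("_env_"):].split("=", 1)
            env[key] = value
    return env


env = effective_env(
    {"XLA_PYTHON_CLIENT_MEM_FRACTION": ".85"},
    {"XLA_PYTHON_CLIENT_MEM_FRACTION": ".93"},
    ["_env_XLA_PYTHON_CLIENT_MEM_FRACTION=.90", "steps=5"],
)
print(env["XLA_PYTHON_CLIENT_MEM_FRACTION"])
```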
If the model is not yet supported by the image's default MaxText branch, set MAXTEXT_PATCH_BRANCH in the .env.sh to check out a branch that adds support:
# configs/new-model.env.sh
export MAXTEXT_PATCH_BRANCH="feature/new-model-support"
Default for most models:
dtype: bfloat16
weight_dtype: bfloat16
For large models where memory is tight, add FP8 quantization:
quantization: "fp8" # or "nanoo_fp8"
quantize_kvcache: false
kv_quant_axis: heads_and_dkv
kv_quant_dtype: int8
checkpoint_is_quantized: false # true only when restoring from quantized checkpoint
MoE runtime settings (MoE models only):
megablox: false
sparse_matmul: false # cannot use with pipeline parallelism
capacity_factor: 1.25 # 1.0 for exact routing, >1.0 for overflow buffer
Add expert_balance: true if the model benefits from load-balanced routing.
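The effect of capacity_factor can be illustrated with the common token-capacity formula used in capacity-based MoE routing. This is an assumption about the general technique; MaxText's internal computation may differ in detail, and `expert_capacity` is a hypothetical helper.

```python
import math


def expert_capacity(tokens: int, num_experts: int, num_experts_per_tok: int,
                    capacity_factor: float) -> int:
    """Max tokens each expert accepts, given a perfectly even split scaled by the factor."""
    even_share = tokens * num_experts_per_tok / num_experts
    return math.ceil(capacity_factor * even_share)


# 4096 tokens, 8 experts, top-2 routing
print(expert_capacity(4096, 8, 2, 1.0))   # even share, no overflow buffer
print(expert_capacity(4096, 8, 2, 1.25))  # 25% overflow buffer
```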
run_local.sh <short-name> -- steps=5 dataset_type=synthetic
Watch for OOM errors; if they occur, reduce per_device_batch_size. Then submit to the cluster:
submit.sh <short-name> -N <nodes> -- steps=15 dataset_type=synthetic
# ── Run Config ───────────────────────────────────────────────────────────────
base_config: "base.yml"
run_name: "<model>_train_test"
hardware: "gpu"
model_name: "<model>"
# ── Training ─────────────────────────────────────────────────────────────────
steps: 15
# ── Parallelism ──────────────────────────────────────────────────────────────
# At most one DCN axis can be unspecified (-1)
dcn_data_parallelism: -1 # auto-sharded
dcn_fsdp_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
# At most one ICI axis can be unspecified (-1)
ici_data_parallelism: 1
ici_fsdp_parallelism: 8
ici_sequence_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
ici_expert_parallelism: 1
# ── Batch & Sequence ─────────────────────────────────────────────────────────
per_device_batch_size: 8
max_target_length: 4096
packing: true
max_segments_per_seq: 32
# ── Attention & Compute ──────────────────────────────────────────────────────
attention: "cudnn_flash_te"
remat_policy: "full"
use_iota_embed: True
scan_layers: True
logits_dot_in_fp32: False
shardy: false
# ── Profiler ─────────────────────────────────────────────────────────────────
profiler: "" #"xplane"
upload_all_profiler_results: true
skip_first_n_steps_for_profiler: 3
profiler_steps: 1
# ── Checkpointing ────────────────────────────────────────────────────────────
enable_checkpointing: False
async_checkpointing: False
# ── Dataset ──────────────────────────────────────────────────────────────────
dataset_type: "synthetic"
For MoE models, also add Dtypes and MoE Runtime sections between Attention and Profiler (see steps 5–6).
Leave enable_checkpointing: False in the YAML. Enable it via CLI passthrough instead:
submit.sh 70b -N 1 -- enable_checkpointing=true
Why CLI, not YAML: The launch system detects enable_checkpointing=true in the passthrough args to switch the output directory from job-based (outputs/12345-JAX-llama2-70b/) to model-based (outputs/llama2-70b/). This is what makes checkpoints persist across job restarts. Setting it only in the YAML would checkpoint to a job-specific directory that is lost on restart.
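The directory switch can be sketched as below. The function is a hypothetical stand-in for the launch system's logic, using only the two directory shapes shown above; the actual detection code may handle more cases.

```python
def output_dir(model: str, job_id: str, passthrough: list[str]) -> str:
    """Pick a model-based directory only when checkpointing is enabled via CLI args."""
    enabled = any(arg.lower() == "enable_checkpointing=true" for arg in passthrough)
    if enabled:
        return f"outputs/{model}/"             # persists across job restarts
    return f"outputs/{job_id}-JAX-{model}/"    # job-specific


print(output_dir("llama2-70b", "12345", ["enable_checkpointing=true"]))
print(output_dir("llama2-70b", "12345", ["steps=15"]))
```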
See docs/job-submission.md for model name aliases (isolating checkpoint dirs for parallel experiments) and async checkpointing caveats.
Use dataset_type: "synthetic" for benchmarking. For real data, set the dataset type and tokenizer:
dataset_type: "hf" # or "grain"
hf_path: "allenai/c4"
hf_data_dir: "en"
tokenizer_path: "meta-llama/Llama-2-7b-hf" # match the model family
Some models require a specific tokenizer — check the MaxText model definition or existing configs. If using a gated model's tokenizer (e.g., meta-llama/*), set hf_access_token or log in via huggingface-cli login.
These settings override MaxText's TPU-oriented defaults for GPU runs:
| Setting | Value | Why |
|---------|-------|-----|
| hardware | "gpu" | base.yml defaults to tpu |
| attention | "cudnn_flash_te" | Flash attention via cuDNN / Transformer Engine |
| use_iota_embed | True | Efficient embedding for GPU |
| shardy | false | Shardy is TPU-oriented |
These settings depend on model size:

| Setting | Small (< 20B) | Large (70B+) |
|---------|---------------|--------------|
| remat_policy | "minimal_flash" | "full" |
| scan_layers | False | True |
full rematerialization uses more compute but less memory — required for large models. scan_layers: True reduces compilation time for deep models.
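The size-dependent choices can be written as a small lookup. The table leaves the 20B to 70B range unspecified; the sketch below assumes the memory-conservative "large" settings there, and `size_defaults` is a hypothetical helper.

```python
def size_defaults(params_billion: float) -> dict:
    """Suggested remat_policy / scan_layers for a model of the given size."""
    if params_billion < 20:
        return {"remat_policy": "minimal_flash", "scan_layers": False}
    # Assumption: anything from 20B upward uses the large-model settings.
    return {"remat_policy": "full", "scan_layers": True}


print(size_defaults(8))
print(size_defaults(70))
```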
For models not in MaxText's registry, define architecture parameters directly at the top of the config. Wrap in #=== banners and comment out model_name:
#===============================================================================
# Model Architecture: <Name> <Size>
base_emb_dim: 6144
base_num_query_heads: 48
base_num_kv_heads: 8
base_mlp_dim: 32768
base_num_decoder_layers: 64
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 131072
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-5
# MoE params (if applicable):
num_experts: 8
num_experts_per_tok: 2
decoder_block: "mixtral"
#===============================================================================
# ── Run Config ───────────────────────────────────────────────────────────────
base_config: "base.yml"
run_name: "<name>_train_test"
hardware: "gpu"
# NOTE: inline model; omit model_name!
#model_name:
# (rest of config follows the same template)
Use inline when the model doesn't exist in MaxText's registry, or when you need to tweak architecture parameters (e.g., adjusting expert count, hidden dims, or decoder block type) without modifying upstream.
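When defining a model inline, a quick parameter-count estimate helps with the parallelism and memory decisions. The sketch below is an approximation: it counts attention projections, MLP weights (one input projection per entry in mlp_activations plus one output projection), and embeddings, and ignores norms, router weights, and biases. `estimate_params` is a hypothetical helper.

```python
def estimate_params(cfg: dict) -> int:
    """Approximate parameter count from inline architecture fields."""
    emb = cfg["base_emb_dim"]
    q_dim = cfg["base_num_query_heads"] * cfg["head_dim"]
    kv_dim = cfg["base_num_kv_heads"] * cfg["head_dim"]
    attn = 2 * emb * q_dim + 2 * emb * kv_dim                  # Q and O, K and V
    mlp_mats = len(cfg["mlp_activations"]) + 1                 # input projections + output
    mlp = mlp_mats * emb * cfg["base_mlp_dim"] * cfg.get("num_experts", 1)
    layers = cfg["base_num_decoder_layers"] * (attn + mlp)
    embed_tables = 1 if cfg.get("logits_via_embedding", False) else 2
    return layers + embed_tables * cfg["vocab_size"] * emb


example = {
    "base_emb_dim": 6144, "base_num_query_heads": 48, "base_num_kv_heads": 8,
    "base_mlp_dim": 32768, "base_num_decoder_layers": 64, "head_dim": 128,
    "mlp_activations": ["silu", "linear"], "vocab_size": 131072,
    "logits_via_embedding": False, "num_experts": 8,
}
print(f"{estimate_params(example) / 1e9:.1f}B parameters")
```

For the example values above, the estimate comes out at roughly 316B total parameters.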
Common pitfalls:

| Problem | Cause | Fix |
|---------|-------|-----|
| sparse_matmul + pipeline parallelism | ValueError at runtime | Set sparse_matmul: false |
| Two -1 axes in same group | Invalid parallelism config | Only one auto-sharded axis per DCN/ICI group |
| ICI axes product != 8 | Doesn't match GPUs per node | Ensure explicit ICI axes multiply to 8 |
| OOM | Batch too large, no quantization, or low mem fraction | Reduce per_device_batch_size, add quantization: "fp8", or raise XLA_PYTHON_CLIENT_MEM_FRACTION (default .85). Memory budget depends on GPU — check with rocm-smi (AMD) or nvidia-smi (NVIDIA). |
| Ambiguous model name | Substring matches multiple .gpu.yml files | Use a more specific short name |
| model_name set in inline config | MaxText tries to load model file AND inline params | Comment out model_name |
| remat_policy: minimal_flash on large model | OOM during training | Switch to remat_policy: full |
| Checkpoints lost on job restart | enable_checkpointing set in YAML instead of CLI | Pass -- enable_checkpointing=true via CLI (see Checkpointing section) |
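Several rows of the table above can be caught mechanically before a run. The sketch below is a hypothetical lint pass over a config already loaded into a dict; it covers only three of the pitfalls and is not part of the repo.

```python
def lint_config(cfg: dict) -> list[str]:
    """Return warnings for a few known pitfalls (empty list means none found)."""
    problems = []
    pipeline = max(cfg.get("dcn_pipeline_parallelism", 1),
                   cfg.get("ici_pipeline_parallelism", 1))
    if cfg.get("sparse_matmul") and pipeline > 1:
        problems.append("sparse_matmul must be false with pipeline parallelism")
    # base_config is the run config pointer, not an architecture field.
    inline = any(k.startswith("base_") and k != "base_config" for k in cfg)
    if inline and cfg.get("model_name"):
        problems.append("inline architecture and model_name are both set")
    if cfg.get("enable_checkpointing"):
        problems.append("enable_checkpointing set in YAML; pass it via CLI instead")
    return problems


risky = {
    "base_config": "base.yml", "base_emb_dim": 6144, "model_name": "grok-1",
    "sparse_matmul": True, "dcn_pipeline_parallelism": 4, "enable_checkpointing": True,
}
for problem in lint_config(risky):
    print(problem)
```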
Existing configs for reference:

| Config | Type | Parallelism pattern | Notes |
|--------|------|---------------------|-------|
| llama2-70b | model-name, dense | DCN data / ICI FSDP | Baseline dense config |
| llama3.1-405b | model-name, dense | DCN FSDP / ICI FSDP (both -1) | FP8, very large dense |
| mixtral-8x22b | model-name, MoE | DCN FSDP=4 / ICI expert=8 | FP8, has pipeline params (prefer FSDP) |
| deepseek3-671b | model-name, MoE | DCN FSDP=8 / ICI expert=8 | Large MoE |
| kimi-k2-1t | model-name, MoE | DCN FSDP=8 / ICI expert=8 | Same pattern as deepseek3 |
| qwen3-480b-a35b | model-name, MoE | DCN FSDP=8 / ICI expert=8 | 160 experts, batch=24 sweep-tuned |
| grok-1 | inline, MoE | DCN FSDP=4 / ICI expert=8 | Inline architecture, FP8 |
| grok-2 | inline, MoE | DCN FSDP=4 / ICI expert=8 | Inline, FP8, softcapping, scan_layers=false (XLA ROCm bug) |
| default | model-name, small | DCN data / ICI FSDP | Smoke test, HF dataset |
See docs/model-configs.md for config resolution and CLI override details.