skills/skillxiv-v0.0.2-claude-opus-4.6/dynamic-mask-sparse-attention/SKILL.md
Sparse attention mechanism combining content-aware and position-aware sparse patterns through dynamic masking. Achieves 10x speedup while maintaining model quality on long-context benchmarks through hardware-friendly implementation.
npx skillsauth add ADu2021/skillXiv dynamic-mask-sparse-attentionInstall this skill globally with one command. Works with Claude Code, Cursor, and Windsurf.
3 of 9 scanners reported clean
Some scanners were skipped, did not run, or reported a non-clean status. Review each row below.
Dynamic Mask Sparse Attention (DMA) addresses the quadratic complexity bottleneck in standard self-attention for long-context language models. By combining position-aware and content-aware sparse attention through differentiable dynamic masking, DMA achieves dramatic speedups while preserving or improving model performance.
The fundamental insight is that not all token pairs need full attention computation. DMA:
The framework consists of:
Step 1: Implement content-aware mask generation
Learn which tokens are important for attention:
import torch
import torch.nn as nn
from typing import Tuple, Optional
import math
class ContentAwareMaskGenerator(nn.Module):
"""Generates attention masks based on content importance"""
def __init__(self, hidden_size: int, num_heads: int, sparsity: float = 0.9):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.sparsity = sparsity # Fraction of attention to mask out
self.head_dim = hidden_size // num_heads
# Learnable content scorer
self.value_scorer = nn.Linear(self.head_dim, 1)
self.temperature = nn.Parameter(torch.tensor(1.0))
def forward(self, values: torch.Tensor,
queries: torch.Tensor) -> torch.Tensor:
"""
Generate content-aware attention mask.
Args:
values: (batch, num_heads, seq_len, head_dim)
queries: (batch, num_heads, seq_len, head_dim)
Returns:
Mask of shape (batch, num_heads, seq_len, seq_len)
"""
batch_size, num_heads, seq_len, head_dim = values.shape
# Score token importance using values
# Tokens with high-magnitude values are important
value_importance = torch.norm(values, dim=-1) # (batch, heads, seq_len)
# Score query complexity
query_complexity = torch.norm(queries, dim=-1) # (batch, heads, seq_len)
# Combined importance: higher for queries needing attention
importance = query_complexity.unsqueeze(-1) * value_importance.unsqueeze(-2)
# Shape: (batch, heads, seq_len, seq_len)
# Compute attention mask: keep top tokens
num_keep = max(1, int(seq_len * (1 - self.sparsity)))
# For each query, select top-k values to attend to
mask = torch.zeros_like(importance)
for b in range(batch_size):
for h in range(num_heads):
for q in range(seq_len):
# Top-k indices for this query
topk_vals, topk_idx = torch.topk(
importance[b, h, q, :],
k=num_keep
)
mask[b, h, q, topk_idx] = 1.0
return mask
def compute_loss(self, mask: torch.Tensor) -> torch.Tensor:
"""
Regularization loss to encourage sparsity.
"""
# L1 norm encourages sparse masks
sparsity_loss = torch.abs(mask).sum() / mask.numel()
return sparsity_loss
Step 2: Implement position-aware sparse patterns
Create efficient fixed sparsity patterns based on position:
class PositionAwarePattern(nn.Module):
"""Fixed sparse attention patterns based on position"""
def __init__(self, pattern_type: str = 'local'):
super().__init__()
self.pattern_type = pattern_type
def get_local_pattern(self, seq_len: int,
window_size: int = 64) -> torch.Tensor:
"""
Local attention: each token attends to nearby tokens.
Args:
seq_len: Sequence length
window_size: Local window size
Returns:
Mask of shape (seq_len, seq_len)
"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
# Attend to window around position i
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2)
mask[i, start:end] = True
return mask
def get_strided_pattern(self, seq_len: int,
stride: int = 8) -> torch.Tensor:
"""
Strided attention: each token attends to every stride-th token.
Reduces quadratic complexity to linear while maintaining coverage.
"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
# Attend to every stride-th position
mask[i, ::stride] = True
# Also attend to surrounding local window
mask[i, max(0, i-2):min(seq_len, i+3)] = True
return mask
def get_dilated_pattern(self, seq_len: int,
dilation: int = 4) -> torch.Tensor:
"""
Dilated attention: attend with dilated receptive field.
Captures long-range dependencies efficiently.
"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(seq_len):
# Attend to local window
mask[i, max(0, i-4):min(seq_len, i+5)] = True
# Attend to dilated positions
for offset in range(-seq_len, seq_len, dilation):
j = i + offset
if 0 <= j < seq_len:
mask[i, j] = True
return mask
def forward(self, seq_len: int,
device: torch.device = None) -> torch.Tensor:
"""Get sparse pattern for given sequence length"""
if self.pattern_type == 'local':
pattern = self.get_local_pattern(seq_len)
elif self.pattern_type == 'strided':
pattern = self.get_strided_pattern(seq_len)
elif self.pattern_type == 'dilated':
pattern = self.get_dilated_pattern(seq_len)
else:
# Full attention
pattern = torch.ones(seq_len, seq_len, dtype=torch.bool)
if device is not None:
pattern = pattern.to(device)
return pattern
Step 3: Combine masks into dynamic sparse attention
Merge content and position masks for efficient computation:
class DynamicMaskAttention(nn.Module):
"""Combines content and position masks for sparse attention"""
def __init__(self, hidden_size: int, num_heads: int,
position_pattern: str = 'local',
sparsity: float = 0.9):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.content_mask_gen = ContentAwareMaskGenerator(
hidden_size, num_heads, sparsity
)
self.position_pattern = PositionAwarePattern(position_pattern)
# Projection layers
self.query_proj = nn.Linear(hidden_size, hidden_size)
self.key_proj = nn.Linear(hidden_size, hidden_size)
self.value_proj = nn.Linear(hidden_size, hidden_size)
self.out_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]:
"""
Compute sparse attention with dynamic masking.
Args:
hidden_states: (batch, seq_len, hidden_size)
attention_mask: Optional padding mask
Returns:
(output, attention_stats)
"""
batch_size, seq_len, hidden_size = hidden_states.shape
# Project to Q, K, V
query = self.query_proj(hidden_states)
key = self.key_proj(hidden_states)
value = self.value_proj(hidden_states)
# Reshape for multi-head attention
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim)
query = query.transpose(1, 2) # (batch, heads, seq_len, head_dim)
key = key.view(batch_size, seq_len, self.num_heads, self.head_dim)
key = key.transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.head_dim)
value = value.transpose(1, 2)
# Compute attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
# (batch, heads, seq_len, seq_len)
# Generate content-aware mask
content_mask = self.content_mask_gen(value, query)
# Get position-aware pattern
pos_pattern = self.position_pattern(seq_len, device=hidden_states.device)
# Combine masks: attend where BOTH patterns say to attend
# (intersection is stricter, keeps both content and position constraints)
pos_pattern = pos_pattern.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
combined_mask = content_mask & pos_pattern
# Apply combined mask to scores
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(~combined_mask, mask_value)
# Softmax and attention
attention_weights = torch.softmax(scores, dim=-1)
# Masked attention weights have zeros where mask is False
attention_weights = attention_weights.masked_fill(~combined_mask, 0.0)
# Apply attention to values
context = torch.matmul(attention_weights, value)
# (batch, heads, seq_len, head_dim)
# Reshape back
context = context.transpose(1, 2)
context = context.contiguous().view(batch_size, seq_len, hidden_size)
# Output projection
output = self.out_proj(context)
# Compute statistics
stats = {
'mask_density': combined_mask.float().mean().item(),
'content_density': content_mask.float().mean().item(),
'pos_density': pos_pattern.float().mean().item(),
'sparsity_loss': self.content_mask_gen.compute_loss(combined_mask)
}
return output, stats
Step 4: Implement hardware-optimized sparse kernels
Create efficient implementations for sparse attention computation:
class SparseAttentionKernel:
"""Hardware-optimized sparse attention computation"""
@staticmethod
def sparse_matmul(query: torch.Tensor,
key: torch.Tensor,
mask: torch.Tensor,
scaling_factor: float) -> torch.Tensor:
"""
Efficient sparse matrix multiplication for attention.
Uses mask to avoid computing attention for masked positions.
"""
batch_size, num_heads, seq_len, head_dim = query.shape
# Standard attention with masking (simplified)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)
# Convert boolean mask to float and apply
mask_float = mask.float()
mask_value = torch.finfo(scores.dtype).min
# Zero out masked positions
scores = scores.masked_fill(~mask, mask_value)
# Apply scaling to maintain variance
attention = torch.softmax(scores, dim=-1)
return attention
@staticmethod
def compute_flops_reduction(seq_len: int,
mask_density: float) -> float:
"""
Estimate FLOPs reduction from sparsity.
Args:
seq_len: Sequence length
mask_density: Fraction of attention computed
Returns:
FLOPs reduction ratio (1.0 = no reduction, lower = more sparse)
"""
full_flops = seq_len * seq_len # O(n^2)
sparse_flops = seq_len * seq_len * mask_density
reduction = full_flops / sparse_flops
return reduction
Step 5: Integrate into training loop
Train the sparse attention mechanism end-to-end:
class DMATrainer:
"""Trains Dynamic Mask Attention with full architecture"""
def __init__(self, model, attention_module: DynamicMaskAttention,
learning_rate: float = 1e-4):
self.model = model
self.attention = attention_module
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate
)
def training_step(self, batch: Dict) -> Dict:
"""
Single training step with sparse attention.
Args:
batch: Contains 'input_ids', 'labels'
Returns:
Training metrics
"""
input_ids = batch['input_ids']
labels = batch.get('labels', input_ids)
# Forward pass with sparse attention
hidden_states = self.model.embed(input_ids)
# Apply sparse attention
output, att_stats = self.attention(hidden_states)
# Compute language modeling loss
logits = self.model.lm_head(output)
lm_loss = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1)
)
# Sparsity regularization loss
sparsity_loss = att_stats['sparsity_loss']
# Total loss: LM loss + sparsity regularization
total_loss = lm_loss + 0.01 * sparsity_loss
# Backward pass
self.optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
return {
'lm_loss': lm_loss.item(),
'sparsity_loss': sparsity_loss.item(),
'total_loss': total_loss.item(),
'mask_density': att_stats['mask_density'],
'flops_reduction': SparseAttentionKernel.compute_flops_reduction(
input_ids.size(1),
att_stats['mask_density']
)
}
def train_epoch(self, dataloader, num_epochs: int = 3):
"""Train for multiple epochs"""
for epoch in range(num_epochs):
total_metrics = {}
for batch_idx, batch in enumerate(dataloader):
metrics = self.training_step(batch)
# Accumulate metrics
for key, value in metrics.items():
if key not in total_metrics:
total_metrics[key] = []
total_metrics[key].append(value)
if batch_idx % 100 == 0:
print(f"Epoch {epoch}, Batch {batch_idx}: Loss={metrics['total_loss']:.4f}, "
f"Density={metrics['mask_density']:.3f}, "
f"Speedup={metrics['flops_reduction']:.1f}x")
# Print epoch summary
avg_metrics = {k: sum(v) / len(v) for k, v in total_metrics.items()}
print(f"\nEpoch {epoch} Summary:")
for key, value in avg_metrics.items():
print(f" {key}: {value:.4f}")
When to use Dynamic Mask Sparse Attention:
When NOT to use Dynamic Mask Sparse Attention:
Key hyperparameters:
sparsity: 0.85-0.95 typical (85-95% of attention masked)position_pattern: 'local' for general, 'strided' for efficiencysparsity_loss_weight: 0.001-0.01 for regularizationwindow_size (local): 64-256 typicalExpected characteristics:
Performance benchmarks:
Trainable Dynamic Mask Sparse Attention. arXiv:2508.02124
testing
Uses flow maps as look-ahead operators to enable principled reward-guided diffusion by predicting trajectory endpoints at any denoising step. Deploy when applying rewards or preferences to diffusion trajectories with meaningful gradients throughout generation.
testing
Train language models where each expert learns independently on closed datasets, enabling flexible inference with selective data inclusion or exclusion. 41% performance improvement while allowing users to opt out of specific data sources without retraining.
data-ai
Understand how token generation flexibility in diffusion LMs paradoxically constrains reasoning, as models exploit ordering flexibility to avoid uncertain tokens, and apply simplified approaches that preserve parallel decoding benefits. Use when optimizing diffusion-based language models for reasoning tasks.
devops
Enable LLM agents to improve continuously during deployment by constructing structured experience libraries through self-reflection on successes and failures—achieving 23% improvement on reasoning without gradient-based parameter updates or external training.