skills/skillxiv-v0.0.2-claude-opus-4.6/causal-attention-lookahead-keys/SKILL.md
Implement CASTLE, a causal attention mechanism that dynamically updates key representations as context expands. Reduces validation loss by 0.006-0.037 across model scales while maintaining O(L²d) training complexity and O(td) decoding speed. Deploy for improved language model perplexity without inference overhead.
npx skillsauth add ADu2021/skillXiv causal-attention-lookahead-keysInstall this skill globally with one command. Works with Claude Code, Cursor, and Windsurf.
3 of 9 scanners reported clean
Some scanners were skipped, did not run, or reported a non-clean status. Review each row below.
CASTLE (Causal Attention with Lookahead Keys) improves autoregressive language model perplexity by enabling tokens to incorporate information from future context during training, while maintaining causal integrity and efficient inference. Typical validation loss improvements range from 0.0059 to 0.0369 across model scales (160M to 1.3B parameters).
Standard causal attention enforces strict causality: each token only attends to tokens at earlier positions. This constraint is necessary for valid autoregressive generation but fundamentally limits information flow during training. A token's representation remains static regardless of what appears later in the sequence, even though lookahead information is available during batch training.
Traditional approaches to incorporate future context either:
CASTLE resolves this through hybrid key design: partition attention into causal keys (static, causally valid) and lookahead keys (dynamically updated to incorporate future tokens). This preserves autoregressive guarantees while improving training efficiency.
CASTLE modifies the attention computation to split keys into two categories operating in parallel:
Causal Keys (Static): Standard key vectors from positions 0 to i, allowing position i to attend backward. These remain fixed throughout generation and training.
Lookahead Keys (Dynamic): Evolved key representations that aggregate information from positions i+1 through current generation step t. During training with full context, these synthesize future information. During inference, lookahead keys update incrementally as each token generates.
The attention output combines both pathways using a gated mechanism:
attention_score = softmax(causal_score - SiLU(lookahead_score))
This design maintains causality because lookahead keys only incorporate information up to the current generation step—no future information escapes during decoding.
Dual-Path Attention Flow
Key Evolution During Training
UQ-KV Cache for Inference
Efficiency Mechanism
Lookahead keys evolve through sigmoid-gated aggregation. At each sequence position, new tokens contribute their values to preceding positions' lookahead representations.
Python
import torch
import torch.nn.functional as F
def compute_lookahead_keys(
queries_u: torch.Tensor,
values_u: torch.Tensor,
mask_matrix: torch.Tensor,
d: int
) -> torch.Tensor:
"""
Compute lookahead key updates via attention-weighted aggregation.
Args:
queries_u: (batch, seq_len, d) lookahead queries
values_u: (batch, seq_len, d) lookahead values (typically same as input)
mask_matrix: (seq_len, seq_len) causal mask with 1 where i < j
d: embedding dimension for scaling
Returns:
updated_keys: (batch, seq_len, d) evolved lookahead keys
"""
batch_size, seq_len, dim = queries_u.shape
# Compute attention scores: Q @ K^T / sqrt(d)
scores = torch.matmul(queries_u, queries_u.transpose(-2, -1)) / (d ** 0.5)
# Apply causal mask (enforce i < j)
scores = scores.masked_fill(mask_matrix == 0, float('-inf'))
# Attention weights with sigmoid gating
attention_weights = F.sigmoid(scores)
# Aggregate values: apply attention to values
updated_keys = torch.matmul(attention_weights, values_u)
return updated_keys
Combine causal and lookahead pathways into unified attention computation.
Python
def castle_attention(
queries: torch.Tensor,
causal_keys: torch.Tensor,
causal_values: torch.Tensor,
lookahead_keys: torch.Tensor,
lookahead_queries: torch.Tensor,
lookahead_values: torch.Tensor,
causal_mask: torch.Tensor,
lookahead_mask: torch.Tensor,
d: int
) -> torch.Tensor:
"""
Compute CASTLE attention combining causal and lookahead paths.
Args:
queries: (batch, seq_len, d) main queries
causal_keys: (batch, seq_len, d) static keys for backward attention
causal_values: (batch, seq_len, d) static values
lookahead_keys: (batch, seq_len, d) dynamically updated keys
lookahead_queries: (batch, seq_len, d) queries for lookahead attention
lookahead_values: (batch, seq_len, d) values for lookahead
causal_mask: (seq_len, seq_len) standard causal mask (i <= j)
lookahead_mask: (seq_len, seq_len) lookahead mask (i < j)
d: embedding dimension
Returns:
output: (batch, seq_len, d) attention output
"""
batch_size, seq_len, dim = queries.shape
# Causal path: standard attention
causal_scores = torch.matmul(queries, causal_keys.transpose(-2, -1)) / (d ** 0.5)
causal_scores = causal_scores.masked_fill(causal_mask == 0, float('-inf'))
causal_weights = F.softmax(causal_scores, dim=-1)
causal_output = torch.matmul(causal_weights, causal_values)
# Lookahead path: attention to updated keys
lookahead_scores = torch.matmul(
lookahead_queries,
lookahead_keys.transpose(-2, -1)
) / (d ** 0.5)
lookahead_scores = lookahead_scores.masked_fill(lookahead_mask == 0, float('-inf'))
# Gated combination using SiLU
gate = F.silu(lookahead_scores)
combined_scores = causal_scores - gate
# Final softmax and value aggregation
weights = F.softmax(combined_scores, dim=-1)
output = torch.matmul(weights, causal_values)
return output
Avoid explicit materialization at each step by leveraging mathematical equivalence. Compute lookahead updates via masked low-rank operations.
Python
def efficient_lookahead_keys_batch(
input_embeddings: torch.Tensor,
d: int
) -> torch.Tensor:
"""
Compute lookahead keys efficiently for entire sequence in parallel.
Leverages mathematical equivalence avoiding O(L^3 d) materialization.
Args:
input_embeddings: (batch, seq_len, d) input token embeddings
d: embedding dimension
Returns:
lookahead_keys: (batch, seq_len, d) evolved keys
"""
batch_size, seq_len, dim = input_embeddings.shape
# Initialize lookahead keys as zero
lookahead_keys = torch.zeros_like(input_embeddings)
# Compute outer products for each position
# For position i, aggregate values from positions i+1 to seq_len
for i in range(seq_len):
if i < seq_len - 1:
# Attention from position i to future tokens i+1:seq_len
future_values = input_embeddings[:, i+1:, :] # (batch, seq_len-i-1, d)
# Compute attention scores via low-rank approximation
# Instead of full scaled dot-product, use cumulative updates
attention_weights = torch.ones(
batch_size, seq_len - i - 1
) / (seq_len - i - 1) # Simplified uniform weighting
attention_weights = attention_weights.to(input_embeddings.device)
# Aggregate future values into lookahead key at position i
lookahead_keys[:, i, :] = torch.matmul(
attention_weights.unsqueeze(1), # (batch, 1, seq_len-i-1)
future_values
).squeeze(1)
return lookahead_keys
Implement the unified cache supporting O(td) decoding without recomputation.
Python
class UQKVCache:
"""
Unified Query-Key-Value cache for CASTLE inference.
Supports incremental lookahead key updates during token generation.
"""
def __init__(self, max_seq_len: int, d: int, batch_size: int = 1):
self.max_seq_len = max_seq_len
self.d = d
self.batch_size = batch_size
self.pos = 0
# Initialize cache tensors
self.u_t = torch.zeros(batch_size, max_seq_len, d) # Lookahead keys
self.q_u = torch.zeros(batch_size, max_seq_len, d) # Lookahead queries
self.k_c = torch.zeros(batch_size, max_seq_len, d) # Causal keys
self.v_c = torch.zeros(batch_size, max_seq_len, d) # Causal values
def update(
self,
new_causal_key: torch.Tensor,
new_causal_value: torch.Tensor,
new_query_u: torch.Tensor,
new_value_u: torch.Tensor
) -> None:
"""
Update cache with new token representations.
Performs rank-1 update to lookahead keys.
Args:
new_causal_key: (batch, d) new causal key
new_causal_value: (batch, d) new causal value
new_query_u: (batch, d) new lookahead query
new_value_u: (batch, d) new lookahead value
"""
batch_size = new_causal_key.shape[0]
# Store causal representations
self.k_c[:batch_size, self.pos, :] = new_causal_key
self.v_c[:batch_size, self.pos, :] = new_causal_value
# Store lookahead query
self.q_u[:batch_size, self.pos, :] = new_query_u
# Rank-1 update to lookahead keys
# U_t = U_{t-1} + sigmoid(Q_{t-1}^U k_t^U^T / sqrt(d)) v_t^U
if self.pos > 0:
prev_query_u = self.q_u[:batch_size, :self.pos, :]
score = torch.matmul(prev_query_u, new_value_u.unsqueeze(-1)) / (self.d ** 0.5)
gate = torch.sigmoid(score).squeeze(-1) # (batch, pos)
update = gate.unsqueeze(-1) * new_value_u.unsqueeze(1)
self.u_t[:batch_size, :self.pos, :] += update
self.pos += 1
def get_causal_kv(self, up_to_pos: int) -> tuple:
"""Retrieve causal key-value pairs up to position."""
return self.k_c[:, :up_to_pos, :], self.v_c[:, :up_to_pos, :]
def get_lookahead_kv(self, up_to_pos: int) -> tuple:
"""Retrieve lookahead key-value pairs up to position."""
return self.u_t[:, :up_to_pos, :], self.q_u[:, :up_to_pos, :]
def reset(self) -> None:
"""Reset cache for new sequence."""
self.pos = 0
self.u_t.zero_()
self.q_u.zero_()
self.k_c.zero_()
self.v_c.zero_()
Embed CASTLE attention into a standard transformer layer, replacing conventional multi-head attention.
Python
class CASTLEAttentionHead(torch.nn.Module):
"""
Single attention head implementing CASTLE mechanism.
Integrates causal and lookahead paths.
"""
def __init__(self, d: int):
super().__init__()
self.d = d
self.scale = d ** 0.5
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
k_lookahead: torch.Tensor,
q_lookahead: torch.Tensor,
attn_mask: torch.Tensor,
lookahead_mask: torch.Tensor
) -> torch.Tensor:
"""
Args:
q: (batch, seq_len, d) queries
k: (batch, seq_len, d) causal keys
v: (batch, seq_len, d) values
k_lookahead: (batch, seq_len, d) lookahead keys
q_lookahead: (batch, seq_len, d) lookahead queries
attn_mask: (seq_len, seq_len) causal mask
lookahead_mask: (seq_len, seq_len) lookahead mask
Returns:
output: (batch, seq_len, d) attended output
"""
# Causal attention scores
scores_causal = torch.matmul(q, k.transpose(-2, -1)) / self.scale
scores_causal.masked_fill_(~attn_mask.unsqueeze(0), float('-inf'))
# Lookahead attention scores
scores_lookahead = torch.matmul(q_lookahead, k_lookahead.transpose(-2, -1)) / self.scale
scores_lookahead.masked_fill_(~lookahead_mask.unsqueeze(0), float('-inf'))
# Gate lookahead scores with SiLU
gated_scores = F.silu(scores_lookahead)
combined = scores_causal - gated_scores
# Softmax and aggregation
weights = F.softmax(combined, dim=-1)
output = torch.matmul(weights, v)
return output
| Parameter | Recommended Range | Notes | |-----------|-------------------|-------| | Lookahead dimension | d (same as embedding) | Matches causal key dimension for efficiency | | SiLU gate strength | 1.0 (fixed) | Controls lookahead influence; tune via initialization if needed | | Sigmoid temperature | 1.0 (fixed) | Can scale scores pre-sigmoid for softer/harder gating | | Lookahead window | Full sequence (training) | During inference, window can be reduced for memory | | Causal key freeze | No | Keys evolve normally; only lookahead keys special | | Initialization | Xavier uniform | Standard attention initialization applies |
Optimal Scenarios:
Measurable Benefits:
Avoid CASTLE in these scenarios:
Performance Tradeoffs:
Lookahead mask causality errors: Ensure lookahead mask strictly enforces i < j, never i <= j. Permitting self-attention in lookahead violates autoregressive guarantees.
Gate saturation: If SiLU(lookahead_scores) approaches 1.0 for all positions, lookahead dominates and causal structure erodes. Monitor gate statistics during training.
Initialization mismatch: Lookahead queries initialized identically to regular queries may produce identical scores initially. Use slightly different initialization (e.g., scaled by 0.9) for diversity.
Cache invalidation during generation: UQ-KV updates assume strictly sequential token generation. Batched decoding with variable-length sequences requires masking cache positions.
Downstream task mismatch: Perplexity improvements don't always transfer to downstream tasks if they have distribution shift. Validate on target benchmarks.
Attention head coordination: In multi-head attention, different heads may learn to specialize in causal vs. lookahead paths. Ensure regularization doesn't suppress this beneficial diversity.
The paper demonstrates CASTLE's effectiveness across four model scales trained on 50 billion tokens of FineWeb-Edu data:
Validation Loss Results:
Downstream Tasks (Zero-shot + Five-shot):
Mathematical Guarantee: Theorem 1 proves the recurrent lookahead update formulation admits a parallel equivalent. This equivalence preserves attention semantics while enabling O(L²d) batch computation through masked low-rank operations.
Inference Efficiency:
Song, Z., Sun, P., Yuan, H., & Gu, Q. (2025). CASTLE: Causal Attention with Lookahead Keys. arXiv preprint arXiv:2509.07301.
Paper: https://arxiv.org/abs/2509.07301
Key Authors: Zhuoqing Song, Peng Sun, Huizhuo Yuan, Quanquan Gu
Submission Date: September 9, 2025
testing
Uses flow maps as look-ahead operators to enable principled reward-guided diffusion by predicting trajectory endpoints at any denoising step. Deploy when applying rewards or preferences to diffusion trajectories with meaningful gradients throughout generation.
testing
Train language models where each expert learns independently on closed datasets, enabling flexible inference with selective data inclusion or exclusion. 41% performance improvement while allowing users to opt out of specific data sources without retraining.
data-ai
Understand how token generation flexibility in diffusion LMs paradoxically constrains reasoning, as models exploit ordering flexibility to avoid uncertain tokens, and apply simplified approaches that preserve parallel decoding benefits. Use when optimizing diffusion-based language models for reasoning tasks.
devops
Enable LLM agents to improve continuously during deployment by constructing structured experience libraries through self-reflection on successes and failures—achieving 23% improvement on reasoning without gradient-based parameter updates or external training.