skills/nlp/tpu-multicore-inference/SKILL.md
Distributes inference across multiple TPU cores with torch_xla: each core writes a CSV shard, and the shards are then merged with a groupby mean.
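A single TPU v3-8 has 8 cores, so running inference on one core leaves 7/8 of the available compute idle. Use torch_xla multiprocessing to distribute the test set across all cores with DistributedSampler. Each core writes its predictions to a separate CSV shard. When the test set size is not divisible by the number of cores, the sampler pads its index list by repeating samples from the start, so a few IDs appear in more than one shard; averaging by ID when merging collapses those duplicates.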
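To make the padding concrete, here is a minimal sketch of the index arithmetic that DistributedSampler performs with `shuffle=False` and `drop_last=False`. It is a plain-Python re-implementation for illustration, not the library code itself; `shard_indices` is a hypothetical helper name, and the sizes (10 samples, 8 replicas) are made up.

```python
import math
from collections import Counter

def shard_indices(num_samples, num_replicas, rank):
    """Dataset indices one replica sees (shuffle=False, drop_last=False)."""
    per_replica = math.ceil(num_samples / num_replicas)
    total = per_replica * num_replicas
    indices = list(range(num_samples))
    # Pad by wrapping around to the start until the length divides evenly.
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Each replica takes every num_replicas-th index, offset by its rank.
    return indices[rank:total:num_replicas]

# Hypothetical example: 10 test rows spread over 8 cores.
shards = [shard_indices(10, 8, r) for r in range(8)]
counts = Counter(i for shard in shards for i in shard)
duplicated = sorted(i for i, n in counts.items() if n > 1)

print(shards[0])   # [0, 8]
print(shards[2])   # [2, 0]
print(duplicated)  # [0, 1, 2, 3, 4, 5]
```

Every core gets exactly two rows, and rows 0 through 5 are scored twice. Those are the duplicate IDs the merge step has to collapse.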
import pandas as pd
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader, DistributedSampler

FLAGS = {}  # placeholder for per-run settings passed to every process

def _mp_fn(rank, flags):
    device = xm.xla_device()
    model = MyModel()
    model.load_state_dict(torch.load("model.pt", map_location="cpu"))
    model.to(device).eval()

    # Each process reads only its own slice of the test set.
    sampler = DistributedSampler(
        test_dataset, num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(), shuffle=False)
    loader = DataLoader(test_dataset, batch_size=32,
                        sampler=sampler, drop_last=False)

    ids, preds = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(batch["input_ids"].to(device))
            preds.append(out.cpu())
            # Assumes the dataset yields an "id" field alongside "input_ids".
            batch_ids = batch["id"]
            ids.extend(batch_ids.tolist() if torch.is_tensor(batch_ids) else batch_ids)

    # ravel() assumes the model returns one scalar prediction per example.
    df = pd.DataFrame({"id": ids, "pred": torch.cat(preds).numpy().ravel()})
    df.to_csv(f"shard_{rank}.csv", index=False)

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method="fork")

# Merge shards: average duplicates from sampler padding
submission = (pd.concat([pd.read_csv(f"shard_{i}.csv") for i in range(8)])
              .groupby("id").mean().reset_index())
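To spell out what the merge does, the sketch below is a plain-Python equivalent of `groupby("id").mean()` over the shard files, written with the standard library rather than the pandas call the skill uses. `merge_shards` is a hypothetical helper, and the shard contents are invented toy values in which id `a` is a padded duplicate.

```python
import csv
import os
import tempfile
from collections import defaultdict

def merge_shards(paths):
    """Average the "pred" column per "id" across CSV shard files."""
    values = defaultdict(list)
    for path in paths:
        with open(path, newline="") as f:
            for row in csv.DictReader(f):
                values[row["id"]].append(float(row["pred"]))
    # Sort by id (as text) so the output order is deterministic.
    return {key: sum(v) / len(v) for key, v in sorted(values.items())}

# Toy shards: id "a" appears in both files, as a padded sample would.
toy_shards = {0: [("a", 0.25), ("b", 0.5)], 1: [("c", 1.0), ("a", 0.75)]}

with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for rank, rows in toy_shards.items():
        path = os.path.join(tmp, f"shard_{rank}.csv")
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["id", "pred"])
            writer.writerows(rows)
        paths.append(path)
    merged = merge_shards(paths)

print(merged)  # {'a': 0.5, 'b': 0.5, 'c': 1.0}
```

A cheap sanity check after merging is to compare the number of merged IDs with the size of the test set; a mismatch usually means a shard file is missing or stale.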
- Launch the worker processes with xmp.spawn (one per TPU core)
- Each process uses DistributedSampler for its slice of the test set
- Use "fork" as the start method on TPU VMs; "spawn" on Colab
- groupby("id").mean() handles duplicate rows from sampler padding
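The start-method rule ("fork" on TPU VMs, "spawn" on Colab) could be wrapped in a small helper so the same script runs in both places. This is a sketch under an assumption: it detects Colab by checking whether `google.colab` is already loaded in `sys.modules`, which is a common idiom and not something the skill itself specifies. `pick_start_method` is a hypothetical name.

```python
import sys

def pick_start_method(loaded_modules=None):
    """Return "spawn" when running inside Colab, otherwise "fork"."""
    # Assumption: Colab kernels have the google.colab package loaded.
    modules = sys.modules if loaded_modules is None else loaded_modules
    return "spawn" if "google.colab" in modules else "fork"

# Intended use with the launch call from the skill:
# xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8,
#           start_method=pick_start_method())
```

The optional argument exists only so the helper can be exercised without a Colab runtime.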