Distributed Checkpointing System
A 26-day training run on ~100,000 GPUs hits a hardware fault roughly every 30 minutes. A checkpoint is the only thing standing between that fault and losing hours (or days) of compute, because in synchronous training a single dead rank stalls the entire job. The hard part is that the full training state is multi-terabyte, so the naive "pause everyone and write it to disk" approach costs minutes of stall every interval and quietly burns MFU (model FLOPs utilization). A good checkpointing system resolves that tension along four axes: snapshot asynchronously so training barely pauses, write sharded so bandwidth scales with the fleet, stage across a tiered hierarchy so the common failure recovers from a neighbour rather than from cold storage, and publish consistently so you never restore from a half-written mess.
Requirements
Checkpointing is failure insurance for a batch job. The premium is the overhead you pay to take a snapshot; the payout is that any crash only rolls you back to the last save instead of to step zero. At 100k-GPU scale the question is never whether to checkpoint but how cheaply and how often — because the fault rate is measured in minutes and a restart from scratch can mean a week of redone work.
The subtlety is that a resumable checkpoint is far more than the model weights. To continue a run deterministically you must capture the entire state that influences the next step:
- Model weights — the parameters themselves (often bf16/fp16).
- Optimizer state — for Adam, the fp32 master weights plus the first-moment `m` and second-moment `v` tensors. This usually dwarfs the weights.
- RNG state — per-rank CPU and CUDA generator state, so dropout, data augmentation, and sampling resume identically.
- Dataloader cursor — exactly which samples have been consumed (shard index + offset), so you neither skip nor replay data.
- Step / epoch counter, LR-scheduler state, and the AMP loss-scaler — all of which change with the schedule.
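Here is a minimal Python sketch that makes the state list concrete. It assumes plain dicts stand in for tensors and Python's `random.Random` stands in for the CPU/CUDA generators. `TrainingState`, `ToyDataLoader`, `train_step`, `capture`, and `restore` are hypothetical names, not a framework API. In PyTorch the RNG pieces would come from calls such as `torch.get_rng_state()` and `torch.cuda.get_rng_state()`. The sketch shows that restoring weights, RNG state, and the dataloader cursor together replays exactly the same steps.

```python
import copy
import random
from dataclasses import dataclass, field


@dataclass
class TrainingState:
    """Hypothetical bundle of everything a resumable checkpoint must capture."""
    weights: dict        # parameter name -> value (stand-in for bf16 tensors)
    optimizer: dict      # stand-in for fp32 master weights and Adam m / v
    rng_state: tuple     # Python RNG state; real code also saves CPU/CUDA generator state
    shard_index: int     # dataloader cursor: which data shard ...
    offset: int          # ... and how many samples into it
    step: int
    lr_scheduler: dict = field(default_factory=dict)
    loss_scale: float = 65536.0   # AMP loss-scaler value (illustrative default)


class ToyDataLoader:
    """Deterministic loader over fixed shards that can resume from (shard_index, offset)."""

    def __init__(self, shards, shard_index=0, offset=0):
        self.shards, self.shard_index, self.offset = shards, shard_index, offset

    def next_sample(self):
        sample = self.shards[self.shard_index][self.offset]
        self.offset += 1
        if self.offset == len(self.shards[self.shard_index]):  # move to the next shard
            self.shard_index = (self.shard_index + 1) % len(self.shards)
            self.offset = 0
        return sample


def train_step(weights, rng, loader):
    sample = loader.next_sample()
    keep = rng.random() > 0.1             # RNG-driven decision, like a dropout mask
    weights["w"] += sample if keep else 0
    return sample, keep


def capture(weights, optimizer, rng, loader, step):
    # Deep copies, so later steps cannot mutate what was saved.
    return TrainingState(copy.deepcopy(weights), copy.deepcopy(optimizer),
                         rng.getstate(), loader.shard_index, loader.offset, step)


def restore(state, shards):
    rng = random.Random()
    rng.setstate(state.rng_state)
    loader = ToyDataLoader(shards, state.shard_index, state.offset)
    return copy.deepcopy(state.weights), copy.deepcopy(state.optimizer), rng, loader, state.step


# Usage: checkpoint after 4 steps, then check that a restored run replays the next 5 exactly.
shards = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
rng, loader, weights = random.Random(0), ToyDataLoader(shards), {"w": 0.0}
for _ in range(4):
    train_step(weights, rng, loader)
ckpt = capture(weights, {"m": 0.0, "v": 0.0}, rng, loader, step=4)
original = [train_step(weights, rng, loader) for _ in range(5)]
w2, _, rng2, loader2, _ = restore(ckpt, shards)
resumed = [train_step(w2, rng2, loader2) for _ in range(5)]
assert resumed == original and w2 == weights
```

Restoring only the weights, without `rng_state` or the cursor, would make the resumed run draw different samples and masks. That is exactly the silent divergence the list above guards against.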
Functional requirements
- Save the complete training state at a chosen step as a single logical, atomic checkpoint.
- Restore deterministically and resume — statistically equivalent always, bit-wise on identical hardware.
- Reshard on restore: load a checkpoint onto a different topology when the world size or the TP x PP x DP degrees change.
- Version & retain: keep a rolling last K plus periodic milestones; garbage-collect the rest.
- Verify & publish atomically: checksum every shard and only ever expose complete checkpoints to readers.
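The resharding requirement is easiest to see on a single parameter flattened to 1-D. The sketch below assumes each rank owned a contiguous slice and that the checkpoint records each slice's global offset. A new rank then reads only the old shards that overlap its new range. The names `shard_bounds`, `save_shards`, and `load_resharded` are hypothetical. Real TP/PP resharding also has to slice along specific tensor dimensions, which this one-dimensional sketch does not attempt.

```python
def shard_bounds(total, world, rank):
    """Contiguous [start, end) range owned by `rank` when `total` elements split over `world` ranks."""
    base, extra = divmod(total, world)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


def save_shards(flat, world):
    """Split a flat parameter into per-rank shards, recording each shard's global offset."""
    shards = []
    for rank in range(world):
        start, end = shard_bounds(len(flat), world, rank)
        shards.append({"start": start, "data": flat[start:end]})
    return shards


def load_resharded(shards, total, new_world, new_rank):
    """Rebuild the slice `new_rank` owns under `new_world` by reading only overlapping old shards."""
    want_start, want_end = shard_bounds(total, new_world, new_rank)
    out = []
    for shard in shards:  # shards are ordered by global offset
        s_start = shard["start"]
        s_end = s_start + len(shard["data"])
        lo, hi = max(want_start, s_start), min(want_end, s_end)
        if lo < hi:  # this old shard overlaps the range we need
            out.extend(shard["data"][lo - s_start:hi - s_start])
    return out


# Usage: saved on 4 ranks, restored on 3.
flat = list(range(10))
old = save_shards(flat, world=4)
new_parts = [load_resharded(old, len(flat), new_world=3, new_rank=r) for r in range(3)]
```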
Non-functional requirements
- Low overhead / minimal stall — the training loop must barely pause; checkpointing cost should be a small single-digit percent of wall-clock.
- High write bandwidth — land a multi-terabyte checkpoint durably in seconds, not hours.
- Fast restore — parallel reads, and prefer the fastest tier that still has the data.
- Durability — survive node, rack, and (for milestones) datacenter loss.
- Scalability — aggregate cost grows sub-linearly with the fleet; no central writer.
Recovery is throughput
Every minute spent stalled for a snapshot, and every step redone after a crash, is a minute the world's most expensive hardware sits idle. So the two metrics that matter are checkpoint overhead (how long training waits to save) and recovery cost (the work redone since the last snapshot, plus restore time). Everything below is an attack on one of those two numbers.
Scale & back-of-the-envelope
Sizing a checkpoint starts from one formula: checkpoint_bytes = params x bytes_per_param, where bytes_per_param counts everything persisted for each parameter. For mixed-precision Adam that is about 14 bytes: bf16 weights (2) + fp32 master copy (4) + Adam m (4) + Adam v (4). Gradients are transient and are not saved. Everything else (write bandwidth, snapshot stall, and how often to checkpoint) falls out of that number.
| Quantity | Figure | Math / why it matters |
|---|---|---|
| Footprint / param | ~14 bytes | bf16 weights (2) + fp32 master (4) + Adam m (4) + Adam v (4); grads not saved. |
| 1T-param checkpoint | ~14 TB | 1e12 x 14 = full training state for one save. |
| 175B checkpoint | ~2.45 TB | 1.75e11 x 14 — still far past any single host's RAM. |
| 7B checkpoint | ~98 GB | 7e9 x 14 — even "small" models are big on disk. |
| Naive single-stream write | ~3.9 hours | 14 TB / 1 GB/s — gather-to-rank-0 is hopeless on its own. |
| Sharded write | ~28 s @ 500 GB/s | Aggregate bandwidth scales with writer count until storage caps it. |
| Per-rank shard | ~1.75 GB | 14 TB / 8,000 writers (after de-duplicating DP replicas). |
| Snapshot stall (host copy) | ~70 ms | 1.75 GB / ~25 GB/s PCIe into pinned host RAM — all training waits for. |
| Fleet MTBF | ~30 min | node_MTBF / N at 100k GPUs (see the training-cluster design). |
| Optimal interval | tau ~ sqrt(2 x C x M) | Young/Daly; C = checkpoint cost, M = MTBF. C=30s, M=1800s -> ~5.5 min. |
| Min wasted fraction | ~ sqrt(2 x C / M) | Async shrinks C to seconds -> fewer % wasted and you can save more often. |
| Restore from peer RAM | ~tens of ms | 1.75 GB over ~400 Gb/s IB vs seconds-to-minutes from the object store. |
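The table's arithmetic is easy to re-derive for other model sizes. This sketch hard-codes the table's own assumptions: 14 bytes per parameter, ~1 GB/s for a single stream, ~25 GB/s PCIe into pinned memory, and ~400 Gb/s InfiniBand. It ignores protocol overhead and contention, so treat the outputs as lower bounds. The function names are illustrative.

```python
BYTES_PER_PARAM = 2 + 4 + 4 + 4  # bf16 weights + fp32 master + Adam m + Adam v; grads not saved


def checkpoint_bytes(params):
    return params * BYTES_PER_PARAM


def envelope(params, writers, agg_bw, single_bw=1e9, pcie_bw=25e9, ib_bits=400e9):
    """Back-of-the-envelope figures; bandwidths in bytes/s except ib_bits (bits/s)."""
    total = checkpoint_bytes(params)
    shard = total / writers                       # per-rank shard after DP de-duplication
    return {
        "total_tb": total / 1e12,
        "naive_write_h": total / single_bw / 3600,  # gather-to-rank-0, one stream
        "sharded_write_s": total / agg_bw,          # all writers in parallel
        "shard_gb": shard / 1e9,
        "snapshot_stall_ms": shard / pcie_bw * 1e3,  # GPU -> pinned host copy
        "peer_restore_ms": shard * 8 / ib_bits * 1e3,  # shard pulled from a peer's RAM
    }


figures = envelope(1e12, writers=8000, agg_bw=500e9)
```

For a 1T-parameter model with 8,000 writers at 500 GB/s aggregate, it reproduces the table: 14 TB, about 3.9 hours naive, 28 s sharded, 1.75 GB per shard, and about 70 ms of snapshot stall. The peer-restore estimate comes out near 35 ms, consistent with "tens of ms".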
How often to checkpoint
Checkpoint too rarely and a crash throws away a long stretch of compute; checkpoint too often and the saves themselves dominate. The classic Young/Daly result balances the two: the optimal interval is roughly tau ~ sqrt(2 x C x M), where C is the cost of taking one checkpoint and M is the mean time between failures. This is a first-order approximation that holds when C is much smaller than M. The unavoidable overhead at that optimum is about sqrt(2 x C / M). On a crash you lose, on average, half an interval of work (tau / 2) plus the restore time.
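A few lines of Python make the Young/Daly trade-off tangible. The functions below simply evaluate the first-order formulas from the text, and their names are illustrative. Trust them only when C is much smaller than M.

```python
import math


def young_daly_interval(cost_s, mtbf_s):
    """First-order optimal checkpoint interval: tau ~ sqrt(2 * C * M)."""
    return math.sqrt(2 * cost_s * mtbf_s)


def wasted_fraction(cost_s, mtbf_s):
    """Approximate overhead at the optimum: ~ sqrt(2 * C / M)."""
    return math.sqrt(2 * cost_s / mtbf_s)


def expected_loss_per_crash(interval_s, restore_s):
    """On average a crash redoes half an interval, then pays the restore."""
    return interval_s / 2 + restore_s


M = 30 * 60  # fleet MTBF in seconds, from the table
for label, cost in [("sync, 5 min write", 300.0), ("table example", 30.0), ("async, 1 s", 1.0)]:
    tau = young_daly_interval(cost, M)
    print(f"{label:>18}: tau = {tau / 60:5.1f} min, waste ~ {wasted_fraction(cost, M):.1%}")
```

With M = 30 min, a 5-minute blocking write gives roughly a 17 min interval and about 58% waste. That is far outside the regime where the approximation holds, which is itself the warning sign. C = 30 s gives 5.5 min and about 18%, and C = 1 s gives 1 min and about 3%.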
Why async changes the whole equation
The lever in Young/Daly is C. If checkpointing blocks training for the full multi-terabyte write, C is minutes and the optimum forces long intervals (more redo on a crash). If the only on-critical-path cost is the fast host snapshot (tens of milliseconds, plus a brief barrier), C collapses from minutes to seconds at most. The optimal interval then shrinks, and the wasted fraction drops from double digits to low single digits. Cheap checkpoints let you checkpoint often, which is what actually bounds your loss on a crash.
Asynchronous checkpointing: the core deep dive
A synchronous checkpoint blocks every rank while terabytes stream to durable storage. That is a multi-minute stall on every interval, and it translates directly into lost MFU. The decisive idea is to split the operation along its two very different speeds: a fast copy out of GPU memory, and a slow flush to storage that runs in the background while training continues.
Two phases
- Snapshot (fast, blocking) — copy each rank's GPU tensors into pinned (page-locked) host memory with `cudaMemcpyAsync` on a dedicated CUDA stream. Training pauses only for this copy, which takes well under a second per shard. The copy must be taken at a consistent cut: right after `optimizer.step()` and before the next iteration mutates any of the captured state.
- Flush (slow, background) — a background thread or sidecar process writes the staged host buffer out to durable storage, computes checksums, and publishes the checkpoint, all while the GPUs march on to the next steps. Training never waits for the storage write.
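The two phases can be sketched with the standard library alone. The sketch assumes `bytearray` objects stand in for GPU tensors, and a plain `bytes(...)` copy stands in for `cudaMemcpyAsync` into pinned host memory followed by a stream synchronize. `AsyncCheckpointer` is a hypothetical class, not a library API. `snapshot()` is the only blocking call. Pickling, `fsync`, and the atomic `os.replace` publish all happen on a background thread.

```python
import os
import pickle
import threading


class AsyncCheckpointer:
    """Two-phase checkpoint: blocking snapshot into host staging, background flush to disk."""

    def __init__(self, out_dir):
        self.out_dir = out_dir
        self._flush_thread = None

    def snapshot(self, step, live_state):
        """Phase 1 (blocking): copy live state into host staging at a consistent cut."""
        self.wait()  # a single staging buffer: the previous flush must finish first
        staged = {name: bytes(buf) for name, buf in live_state.items()}
        # Phase 2 (background): hand the staged copy to a writer thread and return.
        self._flush_thread = threading.Thread(target=self._flush, args=(step, staged))
        self._flush_thread.start()

    def _flush(self, step, staged):
        final = os.path.join(self.out_dir, f"ckpt-{step:06d}.pkl")
        tmp = final + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(staged, f)
            f.flush()
            os.fsync(f.fileno())      # make the bytes durable before publishing
        os.replace(tmp, final)        # atomic rename: readers never see a partial file

    def wait(self):
        if self._flush_thread is not None:
            self._flush_thread.join()
            self._flush_thread = None
```

This version owns a single staging buffer, so a second `snapshot()` waits for the previous flush to finish. Lifting that limit is the job of double-buffering.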
Double-buffering
The snapshot for step N must not be overwritten while it is still being flushed. Use double-buffering: two staging buffers (or copy-on-write) so the next snapshot can land in buffer B while buffer A is still draining to storage. This bounds the host-RAM cost to about shard_size x num_buffers and removes the flush from the critical path entirely. Practical caveats:
- Host memory pressure — staging buffers consume RAM equal to the shard times the buffer count; bound the buffers and apply back-pressure if storage falls behind.
- Consistency — never let the optimizer update a tensor that is mid-copy; the snapshot copy must complete (or the tensor must be protected by copy-on-write) before the next `step()` touches it.
- Failure mid-flush — if a node dies while flushing, that checkpoint is simply never published; the previous good one still stands.
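Here is a minimal sketch of bounded double-buffering with back-pressure. It assumes a thread-safe `queue.Queue` of preallocated `bytearray`s stands in for pinned host buffers, and `sink` stands in for the durable write. The class and parameter names are hypothetical. Host memory is capped at num_buffers x shard_size. `snapshot()` blocks only when every buffer is still draining, which is the back-pressure caveat above.

```python
import queue
import threading


class DoubleBufferedStager:
    """Bounded staging pool: snapshot N+1 lands in a free buffer while snapshot N drains."""

    def __init__(self, shard_size, sink, num_buffers=2):
        self.free = queue.Queue()
        for _ in range(num_buffers):
            self.free.put(bytearray(shard_size))   # preallocated, like pinned host buffers
        self.pending = queue.Queue()
        self.sink = sink                            # callable(step, bytes): the durable write
        self.writer = threading.Thread(target=self._drain, daemon=True)
        self.writer.start()

    def snapshot(self, step, shard):
        buf = self.free.get()           # back-pressure: blocks only if every buffer is draining
        buf[:len(shard)] = shard        # the fast copy out of "GPU" memory
        self.pending.put((step, buf, len(shard)))

    def _drain(self):
        while True:
            step, buf, n = self.pending.get()
            self.sink(step, bytes(buf[:n]))   # slow durable write, off the critical path
            self.free.put(buf)                # the buffer is reusable only after the write
            self.pending.task_done()

    def flush(self):
        """Wait until every staged snapshot has been handed to the sink."""
        self.pending.join()
```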
```mermaid
sequenceDiagram
participant T as Trainer (GPU)
participant H as Host buffer (pinned)
participant W as Background writer
participant S as Durable storage
T->>T: optimizer.step() at step N
T->>H: snapshot shard via cudaMemcpyAsync
H-->>T: copy complete (sub-second)
T->>T: resume training at step N+1
W->>H: read staged shard
W->>S: flush shard plus checksum
S-->>W: durable, then atomic publish
Note over T,S: training never waited for storage
```
The trainer pays only the fast host copy; the background writer drains that buffer to durable storage while training continues, so the multi-terabyte write is fully hidden behind compute.
The stall is the snapshot, not the write
Done well, training pauses only for the ~70 ms host copy, while the ~28 s (or longer) durable write amortizes invisibly in the background. That single decoupling is what turns checkpoint overhead from a first-order MFU tax into a rounding error, and it is the reason you can afford to checkpoint frequently.
Tiered storage & fast restore
Not all failures are equal, and not all storage is equal — so checkpoints live in a hierarchy from fast-and-volatile to slow-and-durable. The goal is to make the common failure (a single node) recover from a fast tier, and reserve the slow durable tier for rare, correlated losses.
- GPU HBM — the live training state itself.
- Host RAM / peer GPU (in-memory checkpoint) — the snapshot target; fastest to write and to read back.
- Local NVMe SSD — node-local, survives a process crash, and absorbs frequent cheap saves at GB/s.
- Durable object store / parallel filesystem — survives node and rack loss; the home for periodic milestones.
In-memory redundancy
The key trick for fast restore is peer replication. As soon as a rank snapshots into host RAM, it also ships a copy over the fast InfiniBand fabric to a peer's host memory, using simple replication or erasure coding across a small group. The peer must be on a different node, so a single node loss cannot take out both copies. Because the overwhelming majority of failures are single-node, a replacement worker can rehydrate that shard from the peer's RAM in tens of milliseconds, without reading from the slow object store at all. Only a correlated, multi-node loss has to fall back to the durable tier. This is the idea behind Gemini-style in-memory checkpointing. CheckFreq is better known for the two-phase snapshot-then-persist split and for tuning checkpoint frequency.
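The ring variant of peer replication fits in a few lines. This sketch assumes plain replication with no erasure coding, and one rank per node. Rank r's replica lives on rank (r + 1) mod the group size, and `host_ram` is a list of dicts standing in for each node's pinned memory. All names are hypothetical.

```python
class PeerReplicaGroup:
    """Ring replication: node r holds its own snapshot plus a replica of node r-1's.

    host_ram[r] stands in for node r's pinned host memory.
    """

    def __init__(self, world):
        self.world = world
        self.host_ram = [dict() for _ in range(world)]

    def buddy(self, rank):
        return (rank + 1) % self.world        # the node that holds rank's replica

    def snapshot(self, rank, step, shard):
        self.host_ram[rank][("own", rank)] = (step, shard)
        # In a real system this second copy travels over the fabric to the buddy node.
        self.host_ram[self.buddy(rank)][("replica", rank)] = (step, shard)

    def fail(self, rank):
        self.host_ram[rank] = {}              # losing the node wipes its host memory

    def recover(self, rank, durable):
        """Fast path from the buddy's RAM; fall back to durable storage otherwise."""
        entry = self.host_ram[self.buddy(rank)].get(("replica", rank))
        if entry is not None:
            return "peer", entry
        return "durable", durable[rank]       # correlated loss: rank and buddy both gone
```

When recovery has to fall back to the durable tier, the shard it gets comes from an older step. The whole job must then roll back to that step to keep the cut consistent.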
```mermaid
flowchart LR
HBM["GPU HBM (live state)"] --> HOST["Host RAM (pinned snapshot)"]
HOST --> PEER["Peer host RAM (replica over IB)"]
HOST --> NVME["Local NVMe (fast, node-local)"]
NVME --> OBJ["Durable object store / parallel FS"]
PEER -. fast recover .-> HBM
OBJ -. durable recover .-> HBM
```
State flows down the tiers on write; on a single-node failure the replacement pulls its shard back from a peer's RAM (fast path), falling back to durable storage only for correlated losses.
Recover from a neighbour, not from disk
The durable tier is itself a hard design problem — an ML-optimized distributed filesystem tuned for huge parallel sequential checkpoint I/O (see the companion ML-Optimized Distributed File System design). Tiering exists precisely so you touch that slow tier as rarely as possible: most recoveries are served from peer RAM or local NVMe in well under a second.
Consistency
A checkpoint is only useful if every shard belongs to the same training step and the whole set is provably complete. Two failure modes to design out: shards from different steps stitched together, and a partially written checkpoint read back as if it were whole.
- Globally consistent cut — coordinate the snapshot with a barrier at step N: no rank advances the model state to N+1 until its snapshot copy is taken. Because the snapshot is a fast host copy, the barrier stall is tiny; you pay milliseconds for a globally aligned cut.
- Atomic publish — write all shards to an uncommitted path, then commit by writing the manifest (or a `_COMPLETE` marker) last, via an atomic rename or a conditional put. Readers resolve a checkpoint only through that manifest, so they can never observe a half-written set.
- Versioning & retention — name checkpoints by step (`ckpt-001000`, `ckpt-002000`, ...); keep a rolling last K plus sparse milestones for the long haul, and advance a small last-known-good pointer only after verification. GC everything else.
- Corruption detection — store a per-shard checksum (CRC32 / xxHash) in the manifest and verify it on read. This catches partial writes and silent data corruption (SDC) in storage or transit; on a bad shard, fall back to the peer replica or the previous checkpoint.
- Failure during checkpoint — an unfinished checkpoint is never published, so a crash mid-flush simply leaves the prior good checkpoint in place. Atomicity guarantees there is no half-committed state to clean up.
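Here is a single-process sketch of atomic publish with per-shard checksums. It assumes a POSIX filesystem where `os.replace` within one directory is atomic. `publish_checkpoint` and `verify_checkpoint` are hypothetical helpers, and the `MANIFEST.json` name is an assumption. CRC32 from `zlib` stands in for whichever checksum you choose. In a real job, each rank would write its own shard, and a coordinator would commit the manifest only after every rank reports success. A fully durable version would also fsync the directory, and an object store would use a conditional put instead of a rename.

```python
import json
import os
import zlib


def publish_checkpoint(root, step, shards):
    """Write every shard durably, then commit by renaming the manifest into place last."""
    ckpt_dir = os.path.join(root, f"ckpt-{step:06d}")
    os.makedirs(ckpt_dir, exist_ok=True)
    entries = {}
    for name, data in shards.items():
        with open(os.path.join(ckpt_dir, name), "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())
        entries[name] = {"bytes": len(data), "crc32": zlib.crc32(data)}
    tmp = os.path.join(ckpt_dir, "MANIFEST.json.tmp")
    with open(tmp, "w") as f:
        json.dump({"step": step, "shards": entries}, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, os.path.join(ckpt_dir, "MANIFEST.json"))  # the commit point
    return ckpt_dir


def verify_checkpoint(ckpt_dir):
    """True only if the manifest was committed and every shard matches its checksum."""
    manifest_path = os.path.join(ckpt_dir, "MANIFEST.json")
    if not os.path.exists(manifest_path):
        return False                      # never committed: treat as absent
    with open(manifest_path) as f:
        manifest = json.load(f)
    for name, meta in manifest["shards"].items():
        path = os.path.join(ckpt_dir, name)
        if not os.path.exists(path):
            return False
        with open(path, "rb") as f:
            data = f.read()
        if len(data) != meta["bytes"] or zlib.crc32(data) != meta["crc32"]:
            return False                  # partial write or corruption
    return True
```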
Never publish a partial checkpoint
The invariant that makes everything else safe: a checkpoint becomes visible only when its manifest commits, and the manifest commits only after every shard is durable and checksummed. Restore always resolves the newest complete, verified checkpoint — so a process can die at any instant during a save without ever corrupting your recovery point.
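Resolution and retention both walk the same list of checkpoint directories. In the sketch below, `is_valid` is a caller-supplied check, for example "manifest present and checksums matching". `published_checkpoints`, `resolve_latest`, and `plan_gc` are hypothetical helpers. The `ckpt-NNNNNN` naming follows the text.

```python
import os
import re

CKPT_DIR = re.compile(r"^ckpt-(\d+)$")   # matches ckpt-001000, not ckpt-001000.tmp


def published_checkpoints(root):
    """(step, path) for every checkpoint directory under root, newest first."""
    found = []
    for name in os.listdir(root):
        match = CKPT_DIR.match(name)
        path = os.path.join(root, name)
        if match and os.path.isdir(path):
            found.append((int(match.group(1)), path))
    return sorted(found, reverse=True)


def resolve_latest(root, is_valid):
    """Newest checkpoint for which is_valid(path) holds, or None to start from scratch."""
    for step, path in published_checkpoints(root):
        if is_valid(path):        # e.g. manifest committed and every checksum matches
            return step, path
    return None


def plan_gc(valid_steps, keep_last=3, milestone_every=10_000):
    """Steps safe to delete: everything except the newest keep_last and the milestones."""
    keep = set(sorted(valid_steps, reverse=True)[:keep_last])
    keep |= {s for s in valid_steps if s % milestone_every == 0}
    return sorted(s for s in valid_steps if s not in keep)
```

`plan_gc` should only be fed verified steps, so an incomplete directory is never mistaken for a recovery point. Cleaning up abandoned partial directories is a separate pass.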
Bottlenecks & scaling
As models and fleets grow, each defence above eventually becomes the next bottleneck. The recurring pattern: trade host memory and network for less stall and faster restore.
| Bottleneck | Symptom | Mitigation |
|---|---|---|
| Storage write bandwidth | Flush can't keep up; back-pressure stalls the next snapshot | Sharded parallel writes, more writers, de-dup DP replicas, faster durable tier |
| Stall time (snapshot) | Training pauses too long to copy state out | Async non-blocking snapshot to pinned host, separate stream, double-buffering |
| Restore time | Every failure does a slow cold read from the object store | Tiered peer-RAM / NVMe, in-memory redundancy, parallel reads, rehydrate only the failed shard |
| Metadata / small files | Millions of tiny objects throttle the store; index lookups crawl | Coalesce shards, one manifest, fewer larger objects, a metadata service |
| Host memory pressure | Staging buffers consume RAM (shard x buffers) | Bound buffer count, stream-write, spill to local NVMe, back-pressure |
| Checkpoint frequency | Too frequent = overhead; too rare = big redo on a crash | Young/Daly interval; frequent cheap-tier saves, sparse durable milestones |
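One way to attack the small-files row is to coalesce a rank's many small tensors into a single object with an index of byte offsets. The layout below is an illustrative assumption, not a standard format: an 8-byte little-endian index length, then a JSON index, then the concatenated data. `pack_shard` and `read_tensor` are hypothetical names. Each entry records its offset, size, and CRC32, so a reader can fetch one tensor with a ranged read instead of listing and opening thousands of objects.

```python
import json
import struct
import zlib


def pack_shard(tensors):
    """Coalesce many small named blobs into one object: [index length][JSON index][data]."""
    index, chunks, offset = {}, [], 0
    for name in sorted(tensors):
        data = tensors[name]
        index[name] = {"offset": offset, "bytes": len(data), "crc32": zlib.crc32(data)}
        chunks.append(data)
        offset += len(data)
    header = json.dumps(index).encode()
    return struct.pack("<Q", len(header)) + header + b"".join(chunks)


def read_tensor(blob, name):
    """Fetch one entry by offset without unpacking the rest (a ranged read on an object store)."""
    (header_len,) = struct.unpack_from("<Q", blob, 0)
    index = json.loads(blob[8:8 + header_len])
    meta = index[name]
    start = 8 + header_len + meta["offset"]
    data = blob[start:start + meta["bytes"]]
    if zlib.crc32(data) != meta["crc32"]:
        raise ValueError(f"checksum mismatch for {name}")
    return data
```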
Summary
Distributed checkpointing turns "a failure every 30 minutes" from a catastrophe into a few minutes of redo. The four pillars compose: snapshot asynchronously to pinned host memory so training stalls for milliseconds, not minutes; write sharded from every rank in parallel so a 14 TB state lands in seconds and can reshard on restore; stage across a tiered hierarchy with in-memory peer redundancy so the common single-node failure recovers from a neighbour in well under a second; and publish consistently with a barrier-aligned cut, atomic manifest, and per-shard checksums so you never resume from a partial or corrupt checkpoint. Get those four right and the most expensive fleet on earth spends its time training — not recovering.