System Design Notes

AI / ML Infrastructure

Distributed Checkpointing System

A 26-day training run on ~100,000 GPUs hits a hardware fault roughly every 30 minutes. A checkpoint is the only thing standing between that fault and losing hours — or days — of compute, because in synchronous training a single dead rank stalls the entire job. The hard part is that the full training state is multi-terabyte, so the naive "pause everyone and write it to disk" costs minutes of stall every interval and quietly burns MFU. A good checkpointing system resolves that tension along four axes: snapshot asynchronously so training barely pauses, write it sharded so bandwidth scales with the fleet, stage it across a tiered hierarchy so the common failure recovers from a neighbour and not from cold storage, and publish it consistently so you never restore from a half-written mess.

Requirements

Checkpointing is failure insurance for a batch job. The premium is the overhead you pay to take a snapshot; the payout is that any crash only rolls you back to the last save instead of to step zero. At 100k-GPU scale the question is never whether to checkpoint but how cheaply and how often — because the fault rate is measured in minutes and a restart from scratch can mean a week of redone work.

The subtlety is that a resumable checkpoint is far more than the model weights: to continue a run deterministically, you must capture every piece of state that influences the next step.
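The following is a minimal Python sketch of such a state bundle. The names (`TrainingState`, `capture`, `restore`, `data_cursor`) are hypothetical. The set of fields is an assumption about what a typical mixed-precision Adam run needs: weights, the fp32 master copy, both Adam moments, scheduler state, the data-loader position, and RNG state. The check at the end tests the property that matters. After a restore, the RNG must produce exactly the draws the uninterrupted run would have produced.

```python
import copy
import random
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TrainingState:
    """Hypothetical bundle of everything the next training step depends on."""
    step: int = 0
    weights: list = field(default_factory=list)         # bf16 in a real run
    master_weights: list = field(default_factory=list)  # fp32 master copy
    adam_m: list = field(default_factory=list)          # Adam first moment
    adam_v: list = field(default_factory=list)          # Adam second moment
    scheduler: dict = field(default_factory=dict)       # e.g. current LR, warmup progress
    data_cursor: int = 0                                # samples already consumed
    rng_state: Any = None                               # dropout / shuffling randomness


def capture(state: TrainingState, rng: random.Random) -> TrainingState:
    """Deep-copy the state and record the RNG position alongside it."""
    snap = copy.deepcopy(state)
    snap.rng_state = rng.getstate()
    return snap


def restore(snap: TrainingState) -> tuple[TrainingState, random.Random]:
    """Rebuild the state plus an RNG that resumes exactly where the snapshot left off."""
    state = copy.deepcopy(snap)
    rng = random.Random()
    rng.setstate(snap.rng_state)
    return state, rng


rng = random.Random(0)
state = TrainingState(step=42, weights=[0.5], data_cursor=42 * 1024)
snap = capture(state, rng)
original_next = [rng.random() for _ in range(3)]   # what the uninterrupted run draws
resumed_state, resumed_rng = restore(snap)
assert [resumed_rng.random() for _ in range(3)] == original_next
```

If the RNG state or the data cursor is dropped, the resumed run still trains. It just stops being a continuation of the original run, which makes crash-related regressions hard to debug.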

Functional requirements

Non-functional requirements

Recovery is throughput

Every minute spent stalled for a snapshot, and every step redone after a crash, is a minute the world's most expensive hardware sits idle. So the two metrics that matter are checkpoint overhead (how long training waits to save) and recovery cost (the work since the last snapshot that a crash throws away, plus restore time). Everything below is an attack on one of those two numbers.

Scale & back-of-the-envelope

Sizing a checkpoint starts from one formula: checkpoint_bytes = num_params x bytes_per_param, where bytes_per_param sums every persisted tensor for a parameter (weights plus optimizer state). For mixed-precision Adam that comes to about 14 bytes per parameter: bf16 weights (2) + fp32 master copy (4) + Adam m (4) + Adam v (4). Gradients are transient and are not saved. Write bandwidth, snapshot stall, and how often to checkpoint all follow from that number.

| Quantity | Figure | Math / why it matters |
| --- | --- | --- |
| Footprint per param | ~14 bytes | bf16 weights (2) + fp32 master (4) + Adam m (4) + Adam v (4); gradients not saved. |
| 1T-param checkpoint | ~14 TB | 1e12 x 14 = full training state for one save. |
| 175B checkpoint | ~2.45 TB | 1.75e11 x 14; still far beyond any single host's RAM. |
| 7B checkpoint | ~98 GB | 7e9 x 14; even "small" models are big on disk. |
| Naive single-stream write | ~3.9 hours | 14 TB / 1 GB/s; gather-to-rank-0 is hopeless on its own. |
| Sharded write | ~28 s @ 500 GB/s | 14 TB / 500 GB/s; aggregate bandwidth scales with writer count until storage caps it. |
| Per-rank shard | ~1.75 GB | 14 TB / 8,000 writers (after de-duplicating DP replicas). |
| Snapshot stall (host copy) | ~70 ms | 1.75 GB / ~25 GB/s PCIe into pinned host RAM; this is all training waits for. |
| Fleet MTBF | ~30 min | node_MTBF / N at 100k GPUs (see the training-cluster design). |
| Optimal interval tau | ~ sqrt(2 x C x M) | Young/Daly; C = checkpoint cost, M = MTBF. C = 30 s, M = 1800 s -> ~5.5 min. |
| Min wasted fraction | ~ sqrt(2 x C / M) | C = 30 s, M = 1800 s -> ~18%; async shrinks C to seconds, cutting the waste and letting you save more often. |
| Restore from peer RAM | ~tens of ms | 1.75 GB over ~400 Gb/s (~50 GB/s) InfiniBand is ~35 ms, vs seconds to minutes from the object store. |
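The table's arithmetic can be reproduced in a few lines of Python. This is a back-of-the-envelope sketch. The function names are hypothetical, and the defaults simply restate the table's assumptions: 8,000 writers, 1 GB/s single stream, 500 GB/s aggregate storage, ~25 GB/s PCIe, and 400 Gb/s InfiniBand.

```python
BYTES_PER_PARAM = 2 + 4 + 4 + 4   # bf16 weights + fp32 master + Adam m + Adam v


def checkpoint_bytes(params: float) -> float:
    return params * BYTES_PER_PARAM


def sizing_report(params: float, writers: int = 8_000, single_stream: float = 1e9,
                  aggregate: float = 500e9, pcie: float = 25e9,
                  ib: float = 400e9 / 8) -> dict:
    """Back-of-the-envelope checkpoint figures; bandwidths are in bytes per second."""
    total = checkpoint_bytes(params)
    shard = total / writers                     # per-rank shard after DP de-duplication
    return {
        "checkpoint_TB": total / 1e12,
        "naive_write_hours": total / single_stream / 3600,
        "sharded_write_s": total / aggregate,
        "shard_GB": shard / 1e9,
        "snapshot_stall_ms": shard / pcie * 1e3,
        "peer_restore_ms": shard / ib * 1e3,
    }


for key, value in sizing_report(1e12).items():   # the 1T-parameter row
    print(f"{key:>18}: {value:,.2f}")
```

Changing a single default shows which term dominates. For example, halving `aggregate` doubles the background write but leaves the ~70 ms snapshot stall unchanged.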

How often to checkpoint

Checkpoint too rarely and a crash throws away a long stretch of compute; checkpoint too often and the saves themselves dominate. The classic Young/Daly result balances the two: the optimal interval is roughly tau ~ sqrt(2 x C x M), where C is the cost of taking one checkpoint and M is the mean time between failures. The unavoidable overhead at that optimum is about sqrt(2 x C / M). On a crash you lose, on average, half an interval of work (tau / 2) plus the restore time.
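This Python sketch evaluates the Young/Daly formulas directly. `wasted_fraction` is the first-order overhead model: C / tau for taking saves, plus tau / (2M) for the expected redo. At the optimum, the two terms are equal and sum to sqrt(2C/M). The C = 1 s case is an assumed illustrative value for an async save, not a measured figure.

```python
import math


def young_daly_interval(c: float, m: float) -> float:
    """Optimal checkpoint interval tau ~ sqrt(2 * C * M), in the units of C and M."""
    return math.sqrt(2.0 * c * m)


def wasted_fraction(c: float, m: float, tau: float) -> float:
    """First-order overhead: save cost per interval plus expected redo of tau/2 per failure."""
    return c / tau + tau / (2.0 * m)


M = 1800.0                    # fleet MTBF: ~30 minutes, in seconds
for c in (30.0, 1.0):         # blocking save vs. an assumed ~1 s async save
    tau = young_daly_interval(c, M)
    print(f"C = {c:4.0f} s   tau = {tau / 60:4.2f} min   waste = {wasted_fraction(c, M, tau):.1%}")
```

With C = 30 s, this gives tau of about 5.5 min and about 18% waste. With C = 1 s, it gives tau = 1 min and about 3.3% waste, which is the effect the next callout describes.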

Why async changes the whole equation

The lever in Young/Daly is C. If checkpointing blocks training for the full multi-terabyte write, C is minutes and the optimum forces long intervals (more redo on a crash). If the only on-critical-path cost is the fast host snapshot (tens of milliseconds), then C collapses from minutes to seconds or less, the optimal interval shrinks, and the wasted fraction drops from double digits to low single digits. Cheap checkpoints let you checkpoint often, which is what actually bounds your loss on a crash.

Asynchronous checkpointing: the core deep dive

A synchronous checkpoint blocks every rank while terabytes stream to durable storage — a multi-minute stall on every interval that translates directly into lost MFU. The decisive idea is to split the operation along its two very different speeds: a fast copy out of GPU memory, and a slow flush to storage that runs in the background while training continues.

Two phases

Double-buffering

The snapshot for step N must not be overwritten while it is still being flushed. Use double-buffering: two staging buffers (or copy-on-write) so the next snapshot can land in buffer B while buffer A is still draining to storage. This bounds the host-RAM cost to about shard_size x num_buffers and removes the flush from the critical path entirely. Practical caveats:

sequenceDiagram
  participant T as Trainer (GPU)
  participant H as Host buffer (pinned)
  participant W as Background writer
  participant S as Durable storage
  T->>T: optimizer.step() at step N
  T->>H: snapshot shard via cudaMemcpyAsync
  H-->>T: copy complete (sub-second)
  T->>T: resume training at step N+1
  W->>H: read staged shard
  W->>S: flush shard plus checksum
  S-->>W: durable, then atomic publish
  Note over T,S: training never waited for storage
      

The trainer pays only the fast host copy; the background writer drains that buffer to durable storage while training continues, so the multi-terabyte write is fully hidden behind compute.

The stall is the snapshot, not the write

Done well, training pauses only for the ~70 ms host copy, while the ~28 s (or longer) durable write amortizes invisibly in the background. That single decoupling is what turns checkpoint overhead from a first-order MFU tax into a rounding error — and it is the reason you can afford to checkpoint frequently.
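Here is a minimal, CPU-only Python sketch of the snapshot/flush split with bounded double-buffering. `AsyncCheckpointer` and its methods are hypothetical. `copy.deepcopy` stands in for the `cudaMemcpyAsync` copy into pinned host memory. A semaphore with `num_buffers` permits stands in for the staging buffers, so the trainer blocks only when every buffer is still draining (back-pressure). Each flush writes to a temporary name and renames it, so a half-written file is never visible.

```python
import copy
import os
import pickle
import tempfile
import threading


class AsyncCheckpointer:
    """Illustrative double-buffered async checkpointer (hypothetical API)."""

    def __init__(self, out_dir: str, num_buffers: int = 2):
        self.out_dir = out_dir
        self._free = threading.Semaphore(num_buffers)  # free staging buffers
        self._writers: list[threading.Thread] = []

    def snapshot(self, step: int, state: dict) -> None:
        """The only part on the critical path: claim a buffer and copy state into it."""
        self._free.acquire()            # waits only if all buffers are still flushing
        staged = copy.deepcopy(state)   # stands in for the GPU -> pinned-host copy
        t = threading.Thread(target=self._flush, args=(step, staged))
        t.start()
        self._writers.append(t)

    def _flush(self, step: int, staged: dict) -> None:
        """Background phase: persist the staged copy while training continues."""
        try:
            path = os.path.join(self.out_dir, f"step_{step:08d}.pkl")
            tmp = path + ".tmp"
            with open(tmp, "wb") as f:
                pickle.dump(staged, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, path)       # publish atomically
        finally:
            self._free.release()        # this buffer may now be reused

    def wait(self) -> None:
        for t in self._writers:
            t.join()


with tempfile.TemporaryDirectory() as d:
    ckpt = AsyncCheckpointer(d)
    state = {"step": 0, "weights": [0.0, 0.0]}
    for step in range(1, 4):
        state = {"step": step, "weights": [w + 1.0 for w in state["weights"]]}
        ckpt.snapshot(step, state)      # returns after the copy, not the write
    ckpt.wait()
    print(sorted(os.listdir(d)))
```

Because the copy is taken before `snapshot` returns, later mutations of the live state cannot leak into a checkpoint that is still being written. That is exactly the hazard double-buffering guards against.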

Sharded & distributed checkpoints

The naive approach gathers every shard to rank 0 and writes one monolithic file. It fails three ways at scale: rank 0's network link and local disk become the bottleneck, the write is serial, and rank 0's host RAM cannot even hold a 14 TB state. The answer is SPMD writers: every rank writes its own shard in parallel, so aggregate bandwidth scales with the number of writers.

Resharding on restore

The trap is binding the on-disk layout to the runtime topology — then you can only restore onto the exact same TP x PP x DP shape. Instead, store a topology-agnostic logical layout: per tensor, record the global shape and each shard's (offset, extent) within it. On restore, each new rank computes the byte ranges it needs and reads exactly those — possibly spanning several shard files — then reassembles its local slice. This decoupling of storage layout from runtime layout (as in PyTorch Distributed Checkpoint) is what lets a job resume at a different world size after a failure or a deliberate rescale.

flowchart TD
  subgraph Ranks["Training ranks (SPMD writers)"]
    R0["Rank 0 writes shard 0"]
    R1["Rank 1 writes shard 1"]
    R2["Rank 2 writes shard 2"]
    RN["Rank N writes shard N"]
  end
  R0 --> OS["Object store / parallel FS"]
  R1 --> OS
  R2 --> OS
  RN --> OS
  META["Global index: tensor maps to shard plus offset"] --> OS
  OS --> RESHARD["Restore: each new rank reads only the byte ranges it needs (reshard)"]
      

Every rank writes its own shard in parallel into shared storage; a global index records where each tensor lives so restore can read arbitrary byte ranges and reshard onto a different topology.

Storage layout is not runtime layout

Persist tensors by their global coordinates, never by "rank 7's buffer". Get this right and a 1,024-GPU run can resume on 768 GPUs after a rack fails; get it wrong and every failure that changes the topology forces a full conversion pass before you can train again.
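The resharding restore can be sketched in Python for a single flattened 1-D tensor. The sketch rests on several simplifying assumptions:

- Files are modelled as a dict.
- The split is an even partition.
- Every name (`ShardMeta`, `save_sharded`, `restore_resharded`) is hypothetical. None of this is the PyTorch Distributed Checkpoint API, which handles N-D chunks and richer metadata.

The index stores each shard's global (offset, extent). Each new rank intersects its own range with those spans and reads only the overlapping slices.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardMeta:
    """Where one saved shard sits in the global, flattened tensor."""
    file: str     # hypothetical per-rank file name
    offset: int   # first global element held by this shard
    extent: int   # number of elements in the shard


def even_split(n: int, world: int) -> list[tuple[int, int]]:
    """(offset, extent) per rank when n elements are split as evenly as possible."""
    base, rem = divmod(n, world)
    spans, off = [], 0
    for rank in range(world):
        ext = base + (1 if rank < rem else 0)
        spans.append((off, ext))
        off += ext
    return spans


def save_sharded(tensor: list, world: int):
    """Each rank writes only its slice; the index stores global coordinates, not ranks."""
    files, shards = {}, []
    for rank, (off, ext) in enumerate(even_split(len(tensor), world)):
        name = f"shard_{rank}"
        files[name] = tensor[off:off + ext]         # stands in for a parallel file write
        shards.append(ShardMeta(name, off, ext))
    return files, {"global_len": len(tensor), "shards": shards}


def load_range(files: dict, index: dict, want_off: int, want_ext: int) -> list:
    """Assemble [want_off, want_off + want_ext) from only the overlapping shard ranges."""
    out = [None] * want_ext
    want_end = want_off + want_ext
    for m in index["shards"]:
        lo, hi = max(m.offset, want_off), min(m.offset + m.extent, want_end)
        if lo >= hi:
            continue                                # no overlap: this file is never read
        out[lo - want_off:hi - want_off] = files[m.file][lo - m.offset:hi - m.offset]
    return out


def restore_resharded(files: dict, index: dict, new_world: int) -> list[list]:
    """Each new rank computes its own range and reads just that."""
    return [load_range(files, index, off, ext)
            for off, ext in even_split(index["global_len"], new_world)]


files, index = save_sharded(list(range(10)), world=4)    # saved on 4 ranks
print(restore_resharded(files, index, new_world=3))      # resumed on 3 ranks
```

Nothing in the index mentions the old world size, which is why the same files restore onto 3, 7, or any other number of ranks.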

Tiered storage & fast restore

Not all failures are equal, and not all storage is equal — so checkpoints live in a hierarchy from fast-and-volatile to slow-and-durable. The goal is to make the common failure (a single node) recover from a fast tier, and reserve the slow durable tier for rare, correlated losses.

In-memory redundancy

The key trick for fast restore is peer replication: as soon as a rank snapshots into host RAM, it also ships a copy to a neighbour rank's host memory (simple replication, or erasure-coded across a small group) over the fast InfiniBand fabric. Because the overwhelming majority of failures are single-node, a replacement worker can rehydrate that shard from the neighbour's RAM in tens of milliseconds, never touching the slow object store. Only a correlated, multi-node loss has to fall back to the durable tier. This is the idea behind Gemini-style in-memory checkpointing. CheckFreq, often mentioned alongside it, is known instead for its two-phase snapshot-then-persist pipeline and for tuning checkpoint frequency automatically.
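The two redundancy options can be sketched in Python. Here `ring_peer` models simple replication to the next rank. A single XOR parity block (RAID-5 style) is assumed as the simplest possible erasure code; production systems may use stronger codes. All function names are hypothetical, and shards in a parity group are assumed to be padded to equal size.

```python
from functools import reduce


def xor_bytes(a: bytes, b: bytes) -> bytes:
    if len(a) != len(b):
        raise ValueError("shards in a parity group must be the same size")
    return bytes(x ^ y for x, y in zip(a, b))


def ring_peer(rank: int, world: int) -> int:
    """Simple replication: rank r mirrors its snapshot into the host RAM of rank r+1."""
    return (rank + 1) % world


def make_parity(group: list[bytes]) -> bytes:
    """One XOR parity block for a group of equal-sized shards."""
    return reduce(xor_bytes, group)


def rebuild_lost(survivors: list[bytes], parity: bytes) -> bytes:
    """The single missing shard equals the parity XORed with every survivor."""
    return reduce(xor_bytes, survivors, parity)


group = [b"shard-A!", b"shard-B!", b"shard-C!"]      # host snapshots, equal size
parity = make_parity(group)                          # kept in another host's RAM
lost = 1                                             # the host holding shard B dies
survivors = [s for i, s in enumerate(group) if i != lost]
assert rebuild_lost(survivors, parity) == group[lost]
```

The trade-off is memory against recovery work. Replication doubles host RAM per shard but rebuilds from one peer. Parity across k shards costs only 1/k extra RAM, but it tolerates just one loss per group and must read every survivor to rebuild.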

flowchart LR
  HBM["GPU HBM (live state)"] --> HOST["Host RAM (pinned snapshot)"]
  HOST --> PEER["Peer host RAM (replica over IB)"]
  HOST --> NVME["Local NVMe (fast, node-local)"]
  NVME --> OBJ["Durable object store / parallel FS"]
  PEER -. fast recover .-> HBM
  OBJ -. durable recover .-> HBM
      

State flows down the tiers on write; on a single-node failure the replacement pulls its shard back from a peer's RAM (fast path), falling back to durable storage only for correlated losses.

Recover from a neighbour, not from disk

The durable tier is itself a hard design problem — an ML-optimized distributed filesystem tuned for huge parallel sequential checkpoint I/O (see the companion ML-Optimized Distributed File System design). Tiering exists precisely so you touch that slow tier as rarely as possible: most recoveries are served from peer RAM or local NVMe in well under a second.
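Here is a minimal Python sketch of that fastest-first restore path. `restore_shard` and the tier names are hypothetical. Each tier is modelled as a plain lookup function (a dict's `.get` in the demo), and SHA-256 stands in for whatever per-shard checksum the manifest records. A copy that fails verification is treated like a missing one, so a stale or corrupt peer replica falls through to the next tier instead of being loaded.

```python
import hashlib
from typing import Callable, Optional

# A tier is (name, fetch), where fetch returns the shard bytes or None if absent.
Tier = tuple[str, Callable[[str], Optional[bytes]]]


def restore_shard(shard_id: str, expected_sha256: str, tiers: list[Tier]) -> tuple[str, bytes]:
    """Try tiers fastest-first; accept the first copy whose checksum verifies."""
    for name, fetch in tiers:
        data = fetch(shard_id)
        if data is None:
            continue                          # this tier lost the shard (e.g. dead host)
        if hashlib.sha256(data).hexdigest() == expected_sha256:
            return name, data
        # checksum mismatch: treat the copy as corrupt and fall through
    raise RuntimeError(f"no valid copy of {shard_id} in any tier")


shard = b"rank-7 optimizer shard"
digest = hashlib.sha256(shard).hexdigest()
peer_ram, nvme, object_store = {}, {}, {"rank7": shard}   # peer and NVMe copies lost
tiers = [("peer-ram", peer_ram.get), ("nvme", nvme.get), ("object-store", object_store.get)]
print(restore_shard("rank7", digest, tiers)[0])           # falls back to the object store
```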

Consistency

A checkpoint is only useful if every shard belongs to the same training step and the whole set is provably complete. Two failure modes to design out: shards from different steps stitched together, and a partially written checkpoint read back as if it were whole.

Never publish a partial checkpoint

The invariant that makes everything else safe: a checkpoint becomes visible only when its manifest commits, and the manifest commits only after every shard is durable and checksummed. Restore always resolves the newest complete, verified checkpoint — so a process can die at any instant during a save without ever corrupting your recovery point.
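The commit protocol can be sketched in Python on a local filesystem. The directory layout, `MANIFEST.json`, and the function names are hypothetical, and step directories are zero-padded so that a lexical sort matches step order. Every shard is written and fsynced first. The manifest, holding each shard's SHA-256, is written last via an atomic rename. On POSIX, a production version would also fsync the parent directory so that the rename itself is durable. Restore walks from the newest step down and accepts the first checkpoint that is both committed and verified.

```python
import hashlib
import json
import os
import tempfile


def _atomic_write(path: str, data: bytes) -> None:
    """Write to a temp name, fsync, then rename so readers never see a partial file."""
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)


def save_checkpoint(root: str, step: int, shards: dict[str, bytes]) -> None:
    ckpt = os.path.join(root, f"step_{step:08d}")    # zero-padded so names sort by step
    os.makedirs(ckpt, exist_ok=True)
    digests = {}
    for name, data in shards.items():
        _atomic_write(os.path.join(ckpt, name), data)
        digests[name] = hashlib.sha256(data).hexdigest()
    # Commit point: the manifest is written last, only after every shard is durable.
    manifest = {"step": step, "shards": digests}
    _atomic_write(os.path.join(ckpt, "MANIFEST.json"), json.dumps(manifest).encode())


def _verify(ckpt: str):
    """Return the manifest if the checkpoint is committed and every checksum matches."""
    try:
        with open(os.path.join(ckpt, "MANIFEST.json")) as f:
            manifest = json.load(f)
        for name, digest in manifest["shards"].items():
            with open(os.path.join(ckpt, name), "rb") as f:
                if hashlib.sha256(f.read()).hexdigest() != digest:
                    return None                      # corrupt shard
    except OSError:
        return None                                  # uncommitted save or missing shard
    return manifest


def latest_complete_step(root: str):
    """Newest checkpoint that is both committed and verified, or None."""
    for name in sorted(os.listdir(root), reverse=True):
        manifest = _verify(os.path.join(root, name))
        if manifest is not None:
            return manifest["step"]
    return None


with tempfile.TemporaryDirectory() as root:
    save_checkpoint(root, 100, {"rank0": b"a", "rank1": b"b"})
    save_checkpoint(root, 200, {"rank0": b"c", "rank1": b"d"})
    os.makedirs(os.path.join(root, "step_00000300"))  # crash mid-save: no manifest
    print(latest_complete_step(root))                 # -> 200
```

Both failure modes from the previous paragraph are handled by the same rule. A save that died before its manifest is invisible, and a shard that no longer matches its checksum disqualifies the whole step, so restore falls back to an older one.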

Bottlenecks & scaling

As models and fleets grow, each defence above eventually becomes the next bottleneck. The recurring pattern: trade host memory and network for less stall and faster restore.

| Bottleneck | Symptom | Mitigation |
| --- | --- | --- |
| Storage write bandwidth | Flush can't keep up; back-pressure stalls the next snapshot | Sharded parallel writes, more writers, de-duplicated DP replicas, a faster durable tier |
| Stall time (snapshot) | Training pauses too long to copy state out | Async non-blocking snapshot to pinned host memory on a separate stream, double-buffering |
| Restore time | Every failure does a slow cold read from the object store | Tiered peer-RAM / NVMe, in-memory redundancy, parallel reads, rehydrating only the failed shard |
| Metadata / small files | Millions of tiny objects throttle the store; index lookups crawl | Coalesced shards, one manifest, fewer larger objects, a metadata service |
| Host memory pressure | Staging buffers consume RAM (shard size x buffer count) | Bounded buffer count, streaming writes, spill to local NVMe, back-pressure |
| Checkpoint frequency | Too frequent = overhead; too rare = big redo on a crash | Young/Daly interval; frequent cheap-tier saves, sparse durable milestones |

Summary

Distributed checkpointing turns "a failure every 30 minutes" from a catastrophe into a few minutes of redo. The four pillars compose: snapshot asynchronously to pinned host memory so training stalls for milliseconds, not minutes; write sharded from every rank in parallel so a 14 TB state lands in tens of seconds and can reshard on restore; stage across a tiered hierarchy with in-memory peer redundancy so the common single-node failure recovers from a neighbour in well under a second; and publish consistently with a barrier-aligned cut, atomic manifest, and per-shard checksums so you never resume from a partial or corrupt checkpoint. Get those four right and the most expensive fleet on earth spends its time training, not recovering.