System Design Notes

AI / ML Infrastructure

Distributed Parameter Server

The parameter server (PS) splits a training job into two kinds of process: stateless workers that compute gradients on their slice of data, and stateful servers that hold the model's parameters, sharded across machines, and apply the optimizer update. Workers push gradients and pull the latest weights; servers own the source of truth. This decoupling is what lets you train a multi-terabyte recommendation model — billions of sparse embedding features that no single GPU can hold — across thousands of workers, and it is the natural home for asynchronous and elastic training. The whole design is a negotiation between three forces: throughput (workers never wait), convergence (updates can't be too stale), and network (a few servers must absorb the push/pull storm).

Requirements

A parameter server is fundamentally a distributed, sharded key-value store specialized for machine learning: the keys are parameter blocks (e.g. one embedding row per feature ID, or a slice of a dense weight matrix), the values are the weights plus their optimizer state, and the only write operation is "add this gradient and re-run the optimizer". Because the values are numeric and the writes are (approximately) commutative additions, the store can relax consistency far more aggressively than a general database — that relaxation is the entire performance story.
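A minimal sketch of that contract for a single shard, assuming Adagrad as the optimizer (one accumulator float per weight) and zero-initialised rows that are created the first time a key is touched. `AdagradShard`, `pull`, and `push` are hypothetical names, not the API of any particular PS library; values are plain Python lists to keep it dependency-free.

```python
import math
from dataclasses import dataclass, field


@dataclass
class AdagradShard:
    """One hypothetical PS shard: key -> weight row plus Adagrad accumulator."""
    dim: int
    lr: float = 0.1
    eps: float = 1e-8
    weights: dict = field(default_factory=dict)
    accum: dict = field(default_factory=dict)

    def _row(self, key):
        # Sparse embeddings: create a zero row the first time a key is touched.
        if key not in self.weights:
            self.weights[key] = [0.0] * self.dim
            self.accum[key] = [0.0] * self.dim
        return self.weights[key], self.accum[key]

    def pull(self, keys):
        """Return copies of the current weights so workers cannot mutate server state."""
        return {k: list(self._row(k)[0]) for k in keys}

    def push(self, grads):
        """The only write: fold the gradient into the optimizer and update in place."""
        for key, g in grads.items():
            w, acc = self._row(key)
            for i, gi in enumerate(g):
                acc[i] += gi * gi
                w[i] -= self.lr * gi / (math.sqrt(acc[i]) + self.eps)


shard = AdagradShard(dim=2, lr=0.1)
shard.push({42: [1.0, -2.0]})
print(shard.pull([42]))  # roughly {42: [-0.1, 0.1]}
```

Pushes to different keys never interact; pushes to the same key commute only approximately, because Adagrad's step size depends on the accumulator, which is the "(approximately) commutative" caveat above.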

The paradigm

Functional requirements

Non-functional requirements

Where the parameter server shines (and where it doesn't)

PS is the right tool when the model is huge and sparse — recommendation and ads models with billion-scale embedding tables, where a single mini-batch touches only a tiny fraction of the parameters, so a worker pulls kilobytes-to-megabytes instead of terabytes. It also wins when you need asynchrony, elasticity, or heterogeneous hardware. For dense synchronous training — LLMs and CNNs where every parameter is updated on every step — all-reduce (covered in its own section) is almost always better, because there is no sparsity to exploit and a central server becomes a bottleneck.

Scale & back-of-the-envelope

Make it concrete with a production-style recsys / ads ranking model: a wide-and-deep architecture dominated by sparse embedding tables. The numbers below show why the model can't live on one machine, and why sparsity is the property that makes the push/pull traffic survivable.

| Quantity | Figure | Math / why it matters |
|---|---|---|
| Sparse feature vocab | ~10^10 feature IDs | User IDs, item IDs, and cross features: billions of rows in the embedding tables. |
| Embedding dim | 64 floats/row | Each feature maps to a small dense vector; dim trades capacity against cost. |
| Weight bytes | ~2.6 TB | 1e10 x 64 x 4 B (fp32). Already ~30x a single 80 GB GPU. |
| + optimizer state | ~5 TB total | Adagrad accumulator adds ~4 B/param; Adam (m + v) would add ~8 B/param. |
| Dense MLP params | ~50 M | Tiny next to the embeddings, but updated every step (no sparsity). |
| Workers | ~1,000 | CPU or GPU trainers, each on its own data shard. |
| Server shards | ~200 | 5 TB / 200 ~ 25 GB of params per shard: comfortably in RAM. |
| Touched keys / worker / step | ~25,000 | Batch 256 x ~100 nonzero features, deduplicated: a sliver of 10^10. |
| Pull (or push) per worker / step | ~6 MB | 25k x 64 x 4 B. All-reduce would instead move a gradient for every row (~2.6 TB). |
| Per-worker param bandwidth | ~250 MB/s | ~6 MB push + ~6 MB pull at ~20 steps/s. |
| Aggregate to server fleet | ~250 GB/s | 1,000 workers x 250 MB/s, spread over 200 shards ~ 1.25 GB/s per shard. |
| RPC rate | ~40k push + pull requests/s | 1,000 workers x 20 steps/s x 2, with one batched request per direction per step (each fanning out to the shards it touches); underneath, ~1 billion key lookups/s across the fleet. |
| Checkpoint size | ~5 TB | A naive single-stream write takes hours, so it needs sharded, async I/O. |

Sparsity is the whole point

A mini-batch touches ~25,000 of 10,000,000,000 embedding rows, about 0.00025% of the model. The parameter server moves only that touched slice (megabytes), while a dense all-reduce would synchronize a gradient for every row, ~2.6 TB, on every step. That ratio is why PS is the default for huge-embedding recsys/ads models and why the same design is a poor fit for a dense LLM where every parameter is live on every step.
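The arithmetic in the table can be re-derived in a few lines. The inputs are the table's own figures; the 1 GB/s per checkpoint writer is an added assumption, used only to put a number on "hours" for the naive checkpoint. The table rounds these results (6.4 MB becomes ~6 MB, 256 MB/s becomes ~250 MB/s).

```python
def envelope(vocab=10**10, dim=64, bytes_per_float=4, workers=1_000,
             shards=200, touched_keys=25_000, steps_per_s=20,
             writer_bytes_per_s=1e9):   # 1 GB/s per checkpoint writer: an assumption
    weights = vocab * dim * bytes_per_float          # fp32 embedding weights
    total = 2 * weights                              # + one Adagrad accumulator per weight
    msg = touched_keys * dim * bytes_per_float       # one pull (or push) per worker per step
    worker_bw = 2 * msg * steps_per_s                # push + pull
    fleet_bw = workers * worker_bw
    return {
        "weights_TB": weights / 1e12,
        "with_optimizer_TB": total / 1e12,
        "per_shard_GB": total / shards / 1e9,
        "msg_MB": msg / 1e6,
        "worker_MB_per_s": worker_bw / 1e6,
        "fleet_GB_per_s": fleet_bw / 1e9,
        "per_shard_GB_per_s": fleet_bw / shards / 1e9,
        "touched_percent": 100 * touched_keys / vocab,
        "ckpt_single_stream_h": total / writer_bytes_per_s / 3600,
        "ckpt_one_writer_per_shard_s": total / (shards * writer_bytes_per_s),
    }


for name, value in envelope().items():
    print(f"{name:>28}: {value:.5g}")
```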

High-level design

The fleet is two tiers plus a control plane. Workers read data shards, pull the parameters for their mini-batch, compute gradients, and push them back. Parameter servers are a key-partitioned store: each shard owns a disjoint key range, applies the optimizer on push, and answers pulls. A coordinator assigns key ranges, watches health, and drives checkpointing. Crucially, workers talk to servers but not to each other — there is no all-reduce ring.

```mermaid
flowchart LR
  D["Training data shards"] --> W0["Worker 0"]
  D --> W1["Worker 1"]
  D --> WN["Worker N"]
  subgraph WK["Worker fleet (stateless compute)"]
    W0
    W1
    WN
  end
  subgraph PS["Parameter servers (key-sharded store)"]
    S0["PS shard 0"]
    S1["PS shard 1"]
    SM["PS shard M"]
  end
  W0 -->|"push gradients"| S0
  W1 -->|"push gradients"| S1
  WN -->|"push gradients"| SM
  S0 -->|"pull params"| W0
  S1 -->|"pull params"| W1
  SM -->|"pull params"| WN
  S0 --> U0["Apply optimizer (Adam / Adagrad)"]
  S1 --> U1["Apply optimizer"]
  SM --> UM["Apply optimizer"]
  C["Coordinator / scheduler"] -.-> WK
  C -.-> PS
```

Workers pull the keys for their batch, compute gradients, and push them to the server shard that owns each key; the server applies the optimizer in place. A real worker fans out to many shards per step (one per key range it touches).

The worker loop
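The loop described above (pull the batch's keys, compute gradients, push each gradient to the shard that owns the key) can be sketched as follows. Everything here is hypothetical: `ToyShard` is a stand-in server using plain SGD on one scalar weight per key, the model is a toy linear regression, and keys are routed with a simple stable hash modulo the shard count rather than the consistent hashing discussed under Partitioning.

```python
import hashlib


def shard_of(key, num_shards):
    """Route a key to its owning shard with a stable hash (simple modulo routing)."""
    digest = hashlib.blake2b(str(key).encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") % num_shards


class ToyShard:
    """Stand-in PS shard: one scalar weight per key, updated with plain SGD."""

    def __init__(self, lr=0.5):
        self.lr, self.w = lr, {}

    def pull(self, keys):
        return {k: self.w.get(k, 0.0) for k in keys}

    def push(self, grads):
        for k, g in grads.items():
            self.w[k] = self.w.get(k, 0.0) - self.lr * g


def worker_step(shards, batch):
    """One iteration over a batch of (feature_keys, label) examples."""
    keys = sorted({k for features, _ in batch for k in features})   # dedupe
    by_shard = {}
    for k in keys:
        by_shard.setdefault(shard_of(k, len(shards)), []).append(k)
    params = {}
    for s, ks in by_shard.items():        # fan out: one pull per touched shard
        params.update(shards[s].pull(ks))
    grads = {k: 0.0 for k in keys}
    for features, label in batch:         # toy model: prediction = sum of weights
        err = sum(params[k] for k in features) - label
        for k in features:
            grads[k] += err / len(batch)  # gradient of mean 0.5 * err**2
    for s, ks in by_shard.items():        # fan out: one push per touched shard
        shards[s].push({k: grads[k] for k in ks})


shards = [ToyShard() for _ in range(4)]
worker_step(shards, [([1, 2], 1.0), ([2, 3], 0.0)])
```

The worker keeps no state between steps beyond its data cursor, which is what makes it cheap to restart or preempt.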

Why split state from compute

Synchronization: the core deep dive

The single most important design choice in a parameter server is when a worker is allowed to read parameters relative to other workers' writes. This is the synchronization model, and it directly trades statistical efficiency (steps to converge) against hardware efficiency (wall-clock per step) and straggler tolerance. There are three classic points on the spectrum: BSP, ASP, and SSP.

The three models

| Model | Barrier | Convergence (per step) | Throughput | Straggler tolerance | Use when |
|---|---|---|---|---|---|
| BSP (sync) | Every step | Best: equals large-batch SGD; reproducible | Lowest: gated by slowest worker | Poor: one straggler stalls all | Homogeneous fleet; need stability/reproducibility |
| ASP (async) | None | Worst: stale gradients can slow or destabilize | Highest: no waiting | Excellent: slow workers don't block | Huge sparse models; heterogeneous / preemptible workers |
| SSP (bounded) | Only when bound s exceeded | Tunable: near-BSP for small s | High: blocks only the runaway worker | Good: absorbs transient slowdowns | Real-world middle ground; mild heterogeneity |
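The three models differ only in how far ahead of the slowest worker any worker may run. A minimal sketch of that gate, assuming each worker reports a clock (its count of completed steps) to a coordinator; `SSPClock` and its methods are hypothetical names. A bound of s = 0 gives BSP and s = infinity gives ASP.

```python
import math


class SSPClock:
    """Stale-synchronous-parallel gate over per-worker clocks (completed steps).

    staleness=0 reproduces BSP; staleness=math.inf reproduces ASP.
    """

    def __init__(self, num_workers, staleness):
        self.clock = [0] * num_workers
        self.staleness = staleness

    def can_start_next(self, worker):
        # A worker may start its next step only while it is at most
        # `staleness` steps ahead of the slowest worker.
        return self.clock[worker] - min(self.clock) <= self.staleness

    def finish_step(self, worker):
        self.clock[worker] += 1


gate = SSPClock(num_workers=3, staleness=2)
for _ in range(3):
    gate.finish_step(0)          # worker 0 races ahead; workers 1 and 2 are still at 0
print(gate.can_start_next(0))    # False: 3 steps ahead exceeds the bound of 2
gate.finish_step(1)
gate.finish_step(2)
print(gate.can_start_next(0))    # True: the slowest worker is now within 2 steps
```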

The diagram below shows the asynchronous case: two workers pull the same version, but because there is no barrier, W1's gradient is applied after the weights have already advanced — it is stale by one update.

```mermaid
sequenceDiagram
  participant W0 as Worker 0
  participant W1 as Worker 1
  participant PS as Param Server (shard)
  W0->>PS: pull params (version 10)
  W1->>PS: pull params (version 10)
  Note over W0,W1: compute gradients independently
  W0->>PS: push grad g0
  PS->>PS: apply update (version 11)
  Note over PS: W1 still on v10 (stale)
  W1->>PS: push grad g1 (stale)
  PS->>PS: apply update (version 12)
  PS-->>W0: params (version 12)
  Note over W0,PS: no global barrier between workers
```

Asynchronous push/pull: W1 computed on version 10 but its gradient lands after version 11, so it is applied "stale". Bounded staleness (SSP) caps how far this gap may grow before a worker is forced to wait.

Why sparse models tolerate staleness well

Staleness hurts least exactly where the parameter server is used most. In a billion-feature embedding table, two workers usually touch disjoint rows, so their "concurrent" updates rarely conflict — an async push to feature A doesn't make a pull of feature B any more stale. This is why async/SSP is the norm for recsys/ads: the conflict surface is tiny. Dense layers are the opposite — every worker writes every element — so they are often kept closer to synchronous (or handled with all-reduce) even inside an otherwise async job. Common guardrails: staleness-aware learning-rate scaling (down-weight a gradient by how stale it is), gradient clipping, and bounded staleness to cap the worst case.
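Two of those guardrails, staleness-aware scaling and a hard staleness bound, can be sketched on a single versioned parameter. The 1 / (1 + staleness) factor and the drop-above-a-bound rule are illustrative choices (published staleness-aware schemes differ in the exact rule), and `VersionedParam` is hypothetical. The usage replays the sequence diagram: both workers pull the same version, and the second gradient lands one version late.

```python
class VersionedParam:
    """One scalar parameter on a PS shard, with a version bumped on every update."""

    def __init__(self, value, base_lr=0.1, max_staleness=4):
        self.value, self.version = value, 0
        self.base_lr, self.max_staleness = base_lr, max_staleness

    def pull(self):
        return self.value, self.version

    def push(self, grad, pulled_version):
        """Apply a gradient computed on `pulled_version`; return False if dropped."""
        staleness = self.version - pulled_version
        if staleness > self.max_staleness:
            return False                        # bounded staleness: reject the update
        lr = self.base_lr / (1 + staleness)     # down-weight stale gradients
        self.value -= lr * grad
        self.version += 1
        return True


p = VersionedParam(1.0)
_, v_w0 = p.pull()          # Worker 0 pulls version 0
_, v_w1 = p.pull()          # Worker 1 pulls version 0
p.push(1.0, v_w0)           # fresh: full step, value about 0.9
p.push(1.0, v_w1)           # stale by one: half step, value about 0.85
```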

Sharding & consistency

The parameter store is partitioned so that no single server holds the whole model and so that push/pull load spreads across the fleet. How you partition, replicate, and compress determines both fault tolerance and the network bill.

Partitioning

```mermaid
flowchart TD
  K["Param key = hash(feature id)"] --> R{"Consistent hash ring"}
  R --> S0["PS shard 0"]
  R --> S1["PS shard 1"]
  R --> S2["PS shard 2"]
  S0 --> S0R["Replica 0 (follower)"]
  S1 --> S1R["Replica 1 (follower)"]
  S2 --> S2R["Replica 2 (follower)"]
  HOT["Hot key (popular embedding)"] -.->|"replicate + cache"| S1
```

Keys hash onto a ring of shards; each shard has a replica for fault tolerance. Hot keys get extra replicas or worker-side caching so one popular embedding doesn't overload a single shard.
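A minimal consistent-hash ring, assuming 64 virtual nodes per shard to even out load and BLAKE2b as the hash; `HashRing` and its method names are hypothetical. The property that makes the store elastic is that adding a shard only moves keys onto the new shard (roughly 1/(M+1) of them), instead of reshuffling almost everything the way `hash % M` would.

```python
import bisect
import hashlib


def _hash(s):
    return int.from_bytes(hashlib.blake2b(s.encode(), digest_size=8).digest(), "big")


class HashRing:
    """Consistent hashing with virtual nodes: key -> first shard point clockwise."""

    def __init__(self, shards, vnodes=64):
        self.vnodes = vnodes
        self._points = []   # sorted hash positions on the ring
        self._owners = []   # shard owning each position
        for shard in shards:
            self.add(shard)

    def add(self, shard):
        for i in range(self.vnodes):
            h = _hash(f"{shard}#{i}")
            idx = bisect.bisect(self._points, h)
            self._points.insert(idx, h)
            self._owners.insert(idx, shard)

    def remove(self, shard):
        kept = [(h, o) for h, o in zip(self._points, self._owners) if o != shard]
        self._points = [h for h, _ in kept]
        self._owners = [o for _, o in kept]

    def owner(self, key):
        # First point strictly after the key's hash, wrapping past the end.
        idx = bisect.bisect(self._points, _hash(str(key))) % len(self._points)
        return self._owners[idx]


ring = HashRing([f"ps-{i}" for i in range(4)])
before = {k: ring.owner(k) for k in range(10_000)}
ring.add("ps-4")
after = {k: ring.owner(k) for k in range(10_000)}
moved = [k for k in before if before[k] != after[k]]
print(f"moved {len(moved) / len(before):.1%} of keys")
```

Going from four shards to five should move roughly a fifth of the keys, and every moved key lands on `ps-4`; removing the shard again restores the original mapping.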

Replication & consistency
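Chain replication is one way to keep a shard's replicas identical: every push enters at the head and is forwarded down the chain, and pulls are served by the tail. Because all replicas apply the same pushes in the same order, they stay identical, and losing one is a matter of relinking its neighbours. This is a simplified, synchronous sketch that ignores pushes in flight during a failure; `Replica` and `ChainShard` are hypothetical.

```python
class Replica:
    """One copy of a shard's parameters, updated with plain SGD."""

    def __init__(self, lr):
        self.lr, self.w = lr, {}

    def apply(self, grads):
        for k, g in grads.items():
            self.w[k] = self.w.get(k, 0.0) - self.lr * g


class ChainShard:
    """Chain replication for one PS shard (hypothetical, synchronous sketch).

    A push enters at the head and is forwarded replica by replica; it is
    acknowledged only after the tail applies it. Pulls read from the tail, so
    a worker never sees an update that could be lost with a failed replica.
    """

    def __init__(self, num_replicas=3, lr=0.1):
        self.chain = [Replica(lr) for _ in range(num_replicas)]

    def push(self, grads):
        for replica in self.chain:          # head -> ... -> tail, same order everywhere
            replica.apply(grads)
        return True                         # ack sent by the tail

    def pull(self, keys):
        tail = self.chain[-1]
        return {k: tail.w.get(k, 0.0) for k in keys}

    def fail(self, index):
        """Remove a dead replica; its neighbours are linked directly."""
        del self.chain[index]


shard = ChainShard()
shard.push({"item:7": 1.0})
shard.fail(0)                               # the head dies; the next replica becomes head
shard.push({"item:7": 1.0})
print(shard.pull(["item:7"]))               # {'item:7': -0.2}
```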

Hot keys & skew

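Real feature distributions are power-law: a few items (a viral video, a top advertiser, a celebrity user) appear in a huge fraction of examples, so their embedding rows are pulled and pushed far more than average and concentrate load on one shard. The usual mitigations are to replicate hot keys across several shards, cache hot rows on workers, accumulate their gradients locally before pushing, and split a hot key's embedding dimensions across shards.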
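Two common worker-side mitigations, spotting hot keys from sampled traffic and accumulating their gradients locally before pushing, can be sketched as follows. `top_hot_keys` and `HotKeyBuffer` are hypothetical. Flushing every few steps trades a little extra staleness on hot rows for far fewer pushes to the shard that owns them, and summing gradients before one optimizer step only approximates applying them one by one.

```python
from collections import Counter, defaultdict


def top_hot_keys(sampled_keys, k):
    """Pick the k most frequent keys from a sample of recent pull/push traffic."""
    return [key for key, _ in Counter(sampled_keys).most_common(k)]


class HotKeyBuffer:
    """Worker-side local accumulation for hot keys (hypothetical sketch).

    Cold-key gradients are pushed every step. Hot-key gradients are summed locally
    and flushed as one combined update every `flush_every` steps.
    """

    def __init__(self, hot_keys, flush_every=4):
        self.hot = set(hot_keys)
        self.flush_every = flush_every
        self.pending = defaultdict(float)
        self.steps = 0

    def outgoing(self, grads):
        """Return the gradients this worker should actually push this step."""
        out = {}
        for key, g in grads.items():
            if key in self.hot:
                self.pending[key] += g
            else:
                out[key] = g
        self.steps += 1
        if self.steps % self.flush_every == 0:
            out.update(self.pending)
            self.pending.clear()
        return out


hot = top_hot_keys(["viral", "rare", "viral", "viral", "niche"], k=1)   # ["viral"]
buf = HotKeyBuffer(hot, flush_every=2)
print(buf.outgoing({"viral": 1.0, "rare": 0.5}))   # {'rare': 0.5}
print(buf.outgoing({"viral": 2.0}))                # {'viral': 3.0}
```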

Cutting the network: gradient compression

The servers are an incast point — many workers, few shards — so shrinking each message directly raises the throughput ceiling.
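One common way to shrink pushes is top-k sparsification with error feedback: send only the k largest-magnitude gradient entries and carry the remainder forward locally, so small updates are delayed rather than lost. This is a sketch of that general technique, not necessarily what a given parameter server uses; `TopKCompressor` is hypothetical, and low-bit quantization is an alternative or complement.

```python
import heapq


class TopKCompressor:
    """Top-k gradient sparsification with error feedback (one common scheme)."""

    def __init__(self, k):
        self.k = k
        self.residual = {}      # entries not yet sent, carried into the next step

    def compress(self, grad):
        """grad: dict index -> value. Returns the sparse dict to push this step."""
        acc = dict(self.residual)
        for i, g in grad.items():
            acc[i] = acc.get(i, 0.0) + g
        keep = heapq.nlargest(self.k, acc, key=lambda i: abs(acc[i]))
        sent = {i: acc[i] for i in keep}
        self.residual = {i: v for i, v in acc.items() if i not in sent}
        return sent


c = TopKCompressor(k=1)
print(c.compress({0: 0.1, 1: -0.5, 2: 0.2}))   # {1: -0.5}
print(c.compress({0: 0.25}))                   # entry 0 now carries 0.1 + 0.25
```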

Parameter server vs all-reduce

These are the two dominant paradigms for distributed training, and choosing correctly is a classic staff-level judgement call. All-reduce is decentralized: every worker holds a full replica of the model and they collectively sum gradients via a bandwidth-optimal ring or tree (NCCL), with no central server. Parameter server is centralized: a sharded store owns the weights and workers pull only what they need. They make opposite bets about sparsity, symmetry, and synchrony.

| Dimension | Parameter server | All-reduce |
|---|---|---|
| Topology | Centralized: workers talk to servers | Decentralized: workers talk peer-to-peer (ring/tree) |
| Where the model lives | Sharded across servers; can far exceed one node | Full replica on every worker; must fit (or be sharded by FSDP) |
| Best for | Huge sparse embeddings (recsys/ads) | Dense models (LLMs, CNNs): every param updated each step |
| Data moved / step | Only the touched slice (KB-MB) | Full gradient (~2x params), every step |
| Synchrony | Sync, async, or bounded-stale: flexible | Inherently synchronous (the collective is a barrier) |
| Elastic / heterogeneous | Natural: stateless, preemptible workers | Harder: ring must re-form; assumes homogeneous GPUs |
| Bottleneck | Server incast, hot keys, staleness | Slowest link / straggler; no central point but full sync |
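To make the "Full gradient (~2x params)" row concrete, this simulates the textbook ring all-reduce (reduce-scatter, then all-gather) in plain Python and counts what each worker sends: 2(N-1)/N of the vector, regardless of how many entries were actually nonzero. It illustrates the algorithm only; it is not how NCCL is implemented.

```python
def ring_allreduce(vectors):
    """Simulate ring all-reduce over N workers.

    vectors: N equal-length lists, one gradient per worker.
    Returns (the summed vector as held by each worker, elements each worker sent).
    """
    n, length = len(vectors), len(vectors[0])
    bounds = [(c * length // n, (c + 1) * length // n) for c in range(n)]  # n chunks
    data = [list(v) for v in vectors]
    sent = [0] * n

    # Reduce-scatter: in step s, worker r sends chunk (r - s) % n to worker r + 1,
    # which adds it in. Afterwards worker r holds the full sum of chunk (r + 1) % n.
    for s in range(n - 1):
        msgs = []
        for r in range(n):
            lo, hi = bounds[(r - s) % n]
            msgs.append(((r + 1) % n, lo, data[r][lo:hi]))   # slice copy, taken before any receive
            sent[r] += hi - lo
        for dst, lo, payload in msgs:
            for j, x in enumerate(payload):
                data[dst][lo + j] += x

    # All-gather: in step s, worker r forwards chunk (r + 1 - s) % n, which it now
    # holds fully summed, and the receiver overwrites its copy.
    for s in range(n - 1):
        msgs = []
        for r in range(n):
            lo, hi = bounds[(r + 1 - s) % n]
            msgs.append(((r + 1) % n, lo, data[r][lo:hi]))
            sent[r] += hi - lo
        for dst, lo, payload in msgs:
            data[dst][lo:lo + len(payload)] = payload
    return data, sent


grads = [[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60], [100, 200, 300, 400, 500, 600]]
result, sent = ring_allreduce(grads)
print(result[0])   # [111, 222, 333, 444, 555, 666]
print(sent)        # [8, 8, 8]: 2 * (3 - 1) / 3 * 6 elements each
```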

Choosing — and combining

The deciding question: how dense is a step?

If a single step touches a tiny fraction of parameters, the parameter server's "move only what you touch" wins decisively — all-reduce would waste bandwidth synchronizing untouched weights. If a step touches nearly all parameters, the server becomes a needless central choke point and all-reduce's peer-to-peer bandwidth optimality wins. Sparsity, not model size alone, is the real discriminator.
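The density test can be put into rough numbers. This model assumes fp32 values, a ring all-reduce, and load spread evenly across PS shards, and it ignores that a full replica of the 2.6 TB table could not fit on one worker anyway. The dense case uses a hypothetical 7-billion-parameter model purely for illustration.

```python
def step_traffic(total_params, touched_params, workers, shards, bytes_per_value=4):
    """Rough bytes moved in one training step (fp32, even load across shards)."""
    ps_worker = 2 * touched_params * bytes_per_value     # pull + push of the touched slice
    ps_per_shard = workers * ps_worker / shards          # in + out bytes landing on each shard
    ring_worker = 2 * (workers - 1) / workers * total_params * bytes_per_value
    return {"ps_worker": ps_worker, "ps_per_shard": ps_per_shard, "ring_worker": ring_worker}


sparse = step_traffic(10**10 * 64, 25_000 * 64, workers=1_000, shards=200)
dense = step_traffic(7 * 10**9, 7 * 10**9, workers=1_000, shards=200)   # hypothetical dense model
for name, r in [("sparse recsys", sparse), ("dense 7B", dense)]:
    print(name, {k: f"{v / 1e9:.3g} GB" for k, v in r.items()})
```

For the sparse model the PS moves about 13 MB per worker per step against roughly 5 TB for all-reduce. For the dense model both move about 56 GB per worker, but each PS shard must also absorb about 280 GB per step, which is the central choke point.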

Bottlenecks & scaling

Scaling a parameter server is a march through a familiar set of pressure points. Each shows up as workers waiting (throughput drops) or the loss curve worsening (convergence drops); the mitigation usually trades one against the other.

| Bottleneck | Symptom | Mitigation |
|---|---|---|
| Server network (incast) | Many workers push/pull to few shards; NICs saturate | More shards, gradient compression/quantization, worker-side caching, hierarchical aggregation |
| Hot keys / skew | A few popular embeddings overload one shard | Replicate hot keys, cache on workers, local accumulation, sub-shard the key's dims |
| Stragglers | Slow worker stalls BSP, or emits very stale grads in ASP | Async / SSP, backup workers, drop or down-weight late gradients |
| Staleness vs convergence | Async is fast but needs more steps / can diverge | Bounded staleness (SSP), staleness-aware LR, gradient clipping |
| Server fault tolerance | A shard dies -> its slice of the model is lost | Replication (primary-backup/chain), sharded async checkpoints, fast failover |
| Param store memory | TB-scale tables exceed aggregate RAM | More shards, host-memory / SSD tiering, feature hashing & pruning, smaller dim |
| Checkpoint I/O | Writing multi-TB state stalls training | Sharded parallel writes, async / in-memory snapshot, tuned interval |
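The "sharded parallel writes" and "async / in-memory snapshot" mitigations combine naturally: each shard copies its state in memory (a brief pause), then writes its own file from a background thread while training continues. A minimal sketch, assuming JSON files in a local directory where a real system would use a binary format and durable storage; `ShardCheckpointer` is hypothetical.

```python
import json
import os
import tempfile
import threading


class ShardCheckpointer:
    """Snapshot-then-write checkpointing for one PS shard (hypothetical sketch)."""

    def __init__(self, directory, shard_id):
        # One file per shard, so every shard can write in parallel.
        self.path = os.path.join(directory, f"shard-{shard_id:05d}.json")

    def save_async(self, state):
        """Copy `state` (key -> list of floats) now; write it in the background."""
        snapshot = {str(k): list(v) for k, v in state.items()}   # short in-memory pause

        def write():
            tmp = self.path + ".tmp"
            with open(tmp, "w") as f:
                json.dump(snapshot, f)
            os.replace(tmp, self.path)   # atomic rename on POSIX: no half-written file

        thread = threading.Thread(target=write)
        thread.start()
        return thread


ckpt = ShardCheckpointer(tempfile.mkdtemp(), shard_id=3)
live_state = {"user:42": [0.1, -0.2]}
writer = ckpt.save_async(live_state)
live_state["user:42"][0] = 9.9     # training keeps updating the live table
writer.join()
```

The checkpoint holds the values as of the snapshot, not the later update, so training only pauses for the in-memory copy.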

Summary

The parameter server splits training into stateless workers that push gradients and sharded servers that own and update the parameters — a decoupling that lets a multi-terabyte sparse model span a fleet while workers stay cheap and elastic. Its defining lever is the synchronization model: BSP for clean convergence, ASP for raw throughput, and SSP to dial bounded staleness in between — and sparse embeddings tolerate staleness because concurrent updates rarely collide. Consistent hashing makes the store elastic, replication + sharded checkpoints make it durable, and hot-key replication + gradient compression tame the incast at the servers. Reach for a parameter server when a step touches a tiny slice of a huge model; reach for all-reduce when training is dense and synchronous — and for recsys, do both: embeddings on the PS, dense layers on all-reduce.