Distributed Parameter Server
The parameter server (PS) splits a training job into
two kinds of process: stateless workers that compute
gradients on their slice of data, and stateful
servers that hold the model's parameters, sharded
across machines, and apply the optimizer update. Workers
push gradients and pull the latest weights;
servers own the source of truth. This decoupling is what lets you
train a multi-terabyte recommendation model —
billions of sparse embedding features that no single GPU can hold —
across thousands of workers, and it is the natural home for
asynchronous and elastic training.
The whole design is a negotiation between three forces:
throughput (workers never wait),
convergence (updates can't be too stale), and
network (a few servers must absorb the push/pull storm).
Requirements
A parameter server is fundamentally a distributed, sharded key-value store specialized for machine learning: the keys are parameter blocks (e.g. one embedding row per feature ID, or a slice of a dense weight matrix), the values are the weights plus their optimizer state, and the only write operation is "add this gradient and re-run the optimizer". Because the values are numeric and the writes are (approximately) commutative additions, the store can relax consistency far more aggressively than a general database — that relaxation is the entire performance story.
The paradigm
- Workers are stateless compute. Each pulls the parameters it needs, runs forward + backward on a local mini-batch to produce a gradient, pushes that gradient to the servers, and repeats. A worker holds no authoritative state, so it can crash, restart, or be preempted at any time.
- Servers own the parameters. They are partitioned by key, so shard i holds a disjoint slice of the model. On receiving a gradient, a server applies the optimizer update (SGD, Adagrad, Adam) in place and bumps the parameter's version.
- Coordinator / scheduler tracks membership, assigns key ranges to servers, drives checkpointing, and re-shards on elastic join/leave.
Functional requirements
- push(gradients) — a worker sends a sparse or dense gradient for a set of keys; the server applies the configured optimizer and updates the weights.
- pull(keys) -> params — a worker fetches the current value of a set of keys before its next step. For sparse models it pulls only the keys in its mini-batch, not the whole model.
- Sharded key-value parameter store — parameters partitioned across many servers so the model can vastly exceed one machine's memory, and throughput scales with the number of shards.
- Sparse and dense parameters — support both giant sparse embedding tables (billions of rows, only a handful touched per step) and small dense layers (every element updated every step). They have opposite access patterns and are often handled differently.
- Server-side optimizer — the update rule and its per-parameter state (momentum, Adagrad accumulators) live on the server, so workers only ship raw gradients.
- Checkpoint / restore — persist the sharded parameters + optimizer state so a run survives failure and can resume from the last good step.
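To make the push/pull contract concrete, here is a minimal in-memory sketch of a single shard. The class name `PSShard`, the zero initialization and the Adagrad constants (`lr`, `eps`, `init_accum`) are illustrative assumptions, not the API of any particular framework. It shows the three properties listed above: rows are created lazily for sparse keys, `pull` returns only the requested keys, and the optimizer state lives next to the weights on the server.

```python
import math
from dataclasses import dataclass, field


@dataclass
class PSShard:
    """Toy in-memory parameter-server shard (hypothetical; not a real framework API).

    Owns a disjoint set of keys. Each key maps to an embedding row plus its
    Adagrad accumulator; the only write is "add a gradient, run the optimizer".
    """
    dim: int
    lr: float = 0.1
    eps: float = 1e-8
    init_accum: float = 0.1                       # assumed Adagrad initial accumulator
    weights: dict = field(default_factory=dict)   # key -> list[float] weights
    accum: dict = field(default_factory=dict)     # key -> list[float] sum of squared grads
    version: int = 0                              # bumped once per applied push

    def _row(self, key):
        # Sparse tables are created lazily: a row exists once some batch touches it.
        # Zero init keeps the sketch deterministic; real tables usually use small random init.
        if key not in self.weights:
            self.weights[key] = [0.0] * self.dim
            self.accum[key] = [self.init_accum] * self.dim
        return self.weights[key], self.accum[key]

    def pull(self, keys):
        # Return copies so a worker's view cannot alias the server's state.
        return {k: list(self._row(k)[0]) for k in keys}

    def push(self, grads):
        # grads: {key: gradient row}. Apply Adagrad element-wise, in place.
        for key, g in grads.items():
            w, acc = self._row(key)
            for i, gi in enumerate(g):
                acc[i] += gi * gi
                w[i] -= self.lr * gi / (math.sqrt(acc[i]) + self.eps)
        self.version += 1
        return self.version


# Usage: one push, then a pull of just that key.
shard = PSShard(dim=2)
shard.push({42: [1.0, -2.0]})
row = shard.pull([42])[42]  # moved against the gradient: row[0] < 0 < row[1]
```

Because the worker ships only raw gradients, swapping Adagrad for Adam would change the server's per-key state (two moment vectors instead of one accumulator) without touching the worker code.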
Non-functional requirements
- High throughput — the metric is global samples/second. Workers should spend their time computing, not blocked on the network or on each other.
- Independent scalability — add servers to grow model capacity and push/pull bandwidth; add workers to grow compute. The two axes scale separately, unlike all-reduce where every node holds the full model.
- Fault tolerance — with thousands of machines, failure is routine. A dead worker must not stall the job; a dead server must not lose its shard of the model.
- Bounded staleness / tunable consistency — asynchronous updates are fast but read slightly stale weights. The system must let you trade convergence quality against throughput by bounding how stale a read may be.
- Elasticity — tolerate heterogeneous and preemptible workers (spot CPUs/GPUs) joining and leaving without a full restart.
Where the parameter server shines (and where it doesn't)
PS is the right tool when the model is huge and sparse — recommendation and ads models with billion-scale embedding tables, where a single mini-batch touches only a tiny fraction of the parameters, so a worker pulls kilobytes-to-megabytes instead of terabytes. It also wins when you need asynchrony, elasticity, or heterogeneous hardware. For dense synchronous training — LLMs and CNNs where every parameter is updated on every step — all-reduce (covered in its own section) is almost always better, because there is no sparsity to exploit and a central server becomes a bottleneck.
Scale & back-of-the-envelope
Make it concrete with a production-style recsys / ads ranking model: a wide-and-deep architecture dominated by sparse embedding tables. The numbers below show why the model can't live on one machine, and why sparsity is the property that makes the push/pull traffic survivable.
| Quantity | Figure | Math / why it matters |
|---|---|---|
| Sparse feature vocab | ~10^10 feature IDs | User IDs, item IDs, and cross features: billions of rows in the embedding tables. |
| Embedding dim | 64 floats/row | Each feature maps to a small dense vector; dim trades capacity vs cost. |
| Weight bytes | ~2.6 TB | 1e10 x 64 x 4 B (fp32). Already ~30x a single 80 GB GPU. |
| + optimizer state | ~5 TB total | Adagrad accumulator adds ~4 B/param (another ~2.6 TB); Adam (m+v) would add ~8 B/param. |
| Dense MLP params | ~50 M | Tiny next to the embeddings, but updated every step (no sparsity). |
| Workers | ~1,000 | CPU or GPU trainers, each on its own data shard. |
| Server shards | ~200 | 5 TB / 200 ~ 25 GB of params per shard, comfortably in RAM. |
| Touched keys / worker / step | ~25,000 | Batch 256 x ~100 nonzero features, deduplicated: a sliver of 10^10. |
| Pull (or push) per worker / step | ~6 MB | 25k x 64 x 4 B. A dense all-reduce would instead sync a gradient covering all ~2.6 TB of weights (~5 TB of ring traffic per worker). |
| Per-worker param bandwidth | ~250 MB/s | ~6 MB push + ~6 MB pull at ~20 steps/s. |
| Aggregate to server fleet | ~250 GB/s | 1,000 workers x 250 MB/s, spread over 200 shards ~ 1.25 GB/s per shard. |
| RPC rate | ~40k push+pull RPCs/s | 1,000 workers x 20 steps/s x 2 batched calls; each call fans out to many shards, and underneath sit ~1 billion key lookups/s (25k keys x 20 steps/s x 1,000 workers x 2). |
| Checkpoint size | ~5 TB | Naive single-stream write = hours -> needs sharded, async I/O. |
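The arithmetic behind the table can be reproduced in a few lines. Every constant below is one of the table's own assumptions (vocabulary, dim, fp32, one Adagrad accumulator per parameter, 200 shards, 1,000 workers, 25k touched keys, 20 steps/s); the script just multiplies them out so you can change one knob and see what moves.

```python
# Back-of-the-envelope for the recsys scale table; inputs are the table's assumptions.
FEATURES = 10**10        # sparse feature vocabulary (embedding rows)
DIM = 64                 # floats per embedding row
BYTES = 4                # fp32
ADAGRAD_BYTES = 4        # one accumulator float per parameter
SHARDS = 200
WORKERS = 1_000
TOUCHED_KEYS = 25_000    # deduplicated keys per worker per step
STEPS_PER_SEC = 20

weights_tb = FEATURES * DIM * BYTES / 1e12
total_tb = FEATURES * DIM * (BYTES + ADAGRAD_BYTES) / 1e12
per_shard_gb = total_tb * 1e3 / SHARDS
slice_mb = TOUCHED_KEYS * DIM * BYTES / 1e6              # one pull (or one push)
per_worker_mb_s = 2 * slice_mb * STEPS_PER_SEC           # push + pull
fleet_gb_s = WORKERS * per_worker_mb_s / 1e3
per_shard_gb_s = fleet_gb_s / SHARDS
rpcs_per_s = WORKERS * STEPS_PER_SEC * 2                 # one batched push + one batched pull
key_lookups_per_s = TOUCHED_KEYS * STEPS_PER_SEC * WORKERS * 2
touched_fraction = TOUCHED_KEYS / FEATURES

print(f"weights            {weights_tb:.2f} TB")
print(f"weights + Adagrad  {total_tb:.2f} TB  ({per_shard_gb:.1f} GB per shard)")
print(f"pull or push/step  {slice_mb:.1f} MB")
print(f"per worker         {per_worker_mb_s:.0f} MB/s")
print(f"fleet              {fleet_gb_s:.0f} GB/s  ({per_shard_gb_s:.2f} GB/s per shard)")
print(f"RPCs               {rpcs_per_s:,}/s, key lookups {key_lookups_per_s:.1e}/s")
print(f"touched fraction   {touched_fraction:.6%}")
```

The unrounded results (2.56 TB, 5.12 TB, 25.6 GB, 6.4 MB, 256 MB/s, 256 GB/s, 1.28 GB/s) round to the figures in the table.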
Sparsity is the whole point
A mini-batch touches ~25,000 of
10,000,000,000 embedding rows, about
0.00025% of the model. The parameter server moves
only that touched slice (megabytes), while a dense all-reduce would
synchronize a gradient covering all ~2.6 TB of weights (roughly 5 TB
of ring traffic per worker) every step. That ratio is
why PS is the default for huge-embedding recsys/ads models and why
the same design is a poor fit for a dense LLM, where every parameter
is live on every step.
High-level design
The fleet is two tiers plus a control plane. Workers read data shards, pull the parameters for their mini-batch, compute gradients, and push them back. Parameter servers are a key-partitioned store: each shard owns a disjoint key range, applies the optimizer on push, and answers pulls. A coordinator assigns key ranges, watches health, and drives checkpointing. Crucially, workers talk to servers but not to each other — there is no all-reduce ring.
```mermaid
flowchart LR
    D["Training data shards"] --> W0["Worker 0"]
    D --> W1["Worker 1"]
    D --> WN["Worker N"]
    subgraph WK["Worker fleet (stateless compute)"]
        W0
        W1
        WN
    end
    subgraph PS["Parameter servers (key-sharded store)"]
        S0["PS shard 0"]
        S1["PS shard 1"]
        SM["PS shard M"]
    end
    W0 -->|"push gradients"| S0
    W1 -->|"push gradients"| S1
    WN -->|"push gradients"| SM
    S0 -->|"pull params"| W0
    S1 -->|"pull params"| W1
    SM -->|"pull params"| WN
    S0 --> U0["Apply optimizer (Adam / Adagrad)"]
    S1 --> U1["Apply optimizer"]
    SM --> UM["Apply optimizer"]
    C["Coordinator / scheduler"] -.-> WK
    C -.-> PS
```
Workers pull the keys for their batch, compute gradients, and push them to the server shard that owns each key; the server applies the optimizer in place. A real worker fans out to many shards per step (one per key range it touches).
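The fan-out step can be sketched as a routing function: deduplicate the batch's keys, hash each one to its owning shard, and group them so the worker sends one batched RPC per shard. The helper names (`mix64`, `shard_of`, `route`) are hypothetical, and the hash is the SplitMix64 finalizer, chosen here only so that adjacent feature IDs spread across shards deterministically (Python's built-in `hash` is randomized for strings).

```python
from collections import defaultdict

MASK64 = (1 << 64) - 1


def mix64(x: int) -> int:
    """SplitMix64 finalizer: scrambles a 64-bit feature ID so neighbours land on different shards."""
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)


def shard_of(key: int, num_shards: int) -> int:
    # Static hash partitioning: shard i owns every key whose hash maps to i.
    return mix64(key) % num_shards


def route(keys, num_shards):
    """Group a batch's deduplicated keys by owning shard: one batched push/pull per shard."""
    by_shard = defaultdict(list)
    for k in sorted(set(keys)):
        by_shard[shard_of(k, num_shards)].append(k)
    return dict(by_shard)


# Usage: a batch with repeated IDs becomes a few per-shard key lists.
plan = route([5, 17, 5, 1_000_000_007, 3, 17], num_shards=8)
```

Plain modulo placement moves most keys whenever the shard count changes, which is why elastic deployments put the mapping under the coordinator's control (or use consistent hashing) instead of hard-coding it.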
The worker loop
- 1. Pull — gather the parameter rows for the keys in the next mini-batch (sparse) plus all dense weights.
- 2. Compute — forward + backward pass to produce gradients for exactly those keys.
- 3. Push — send each key's gradient to the shard that owns it; the server applies the optimizer and bumps the version.
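- 4. Repeat — fetch the next batch. The pull for step t+1 can overlap the compute of step t to hide latency, at the cost of reading rows that do not yet include step t's own push (one step of staleness).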
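The four-step loop, including the overlapped pull, fits in a short single-process sketch. Everything here is an assumption for illustration: `store` stands in for the sharded servers, the loss is a toy quadratic, and the server runs plain SGD. The prefetch for step t+1 runs in a background thread while step t computes, and the loop waits for it before pushing, so the prefetched snapshot deterministically excludes step t's own update.

```python
from concurrent.futures import ThreadPoolExecutor

DIM, LR = 4, 0.1
store = {}  # hypothetical single-process stand-in for the sharded servers


def pull(keys):
    return {k: list(store.setdefault(k, [0.0] * DIM)) for k in keys}


def push(grads):
    # Plain SGD on the "server" side.
    for k, g in grads.items():
        row = store.setdefault(k, [0.0] * DIM)
        for i, gi in enumerate(g):
            row[i] -= LR * gi


def keys_of(batch):
    # A batch is a list of examples; each example is a list of feature IDs.
    return sorted({k for example in batch for k in example})


def compute_grads(params, batch):
    # Toy loss 0.5 * ||w_k - 1||^2 per feature occurrence, so dL/dw_k = w_k - 1.
    grads = {k: [0.0] * DIM for k in params}
    for example in batch:
        for k in example:
            for i in range(DIM):
                grads[k][i] += params[k][i] - 1.0
    return grads


def train(batches):
    with ThreadPoolExecutor(max_workers=1) as pool:
        params = pull(keys_of(batches[0]))                 # 1. pull for step 0
        for t, batch in enumerate(batches):
            nxt = None
            if t + 1 < len(batches):                       # prefetch step t+1's rows
                nxt = pool.submit(pull, keys_of(batches[t + 1]))
            grads = compute_grads(params, batch)           # 2. compute, overlapping the prefetch
            if nxt is not None:
                params = nxt.result()                      # snapshot taken before this step's push
            push(grads)                                    # 3. push; 4. loop to the next batch
    return store
```

Running `train([[[1, 2]], [[1]]])` leaves key 1 at 0.2 in every dimension: the second step computed on the prefetched zeros rather than on the 0.1 written by the first push. With a fresh pull it would have ended at 0.19, which is exactly the one-step staleness traded for hidden latency.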
Why split state from compute
- Capacity — the model lives in the aggregate RAM of the server fleet, so it can dwarf any single worker's memory.
- Elastic, cheap workers — because workers are stateless they can run on preemptible/spot machines and be added or removed freely.
- Independent scaling — scale servers for model size and bandwidth; scale workers for throughput. Each knob moves on its own.
- Asynchrony — a central store can accept updates as they arrive, with no global barrier — the foundation of the next section.
Synchronization: the core deep dive
The single most important design choice in a parameter server is
when a worker is allowed to read parameters relative to other
workers' writes. This is the synchronization model, and it directly trades
statistical efficiency (steps to converge) against
hardware efficiency (wall-clock per step) and
straggler tolerance. There are three classic points
on the spectrum: BSP, ASP, and
SSP.
The three models
- BSP — Bulk Synchronous Parallel (sync SGD). A global barrier every step: all workers pull the same version, compute on it, push, and only then advance. The aggregate update is mathematically identical to large-batch SGD, so convergence is the cleanest and most reproducible. The cost: the step runs at the speed of the slowest worker — one straggler stalls everyone.
- ASP — Asynchronous Parallel (async SGD). No barrier at all. Each worker pulls the latest weights, computes, and pushes whenever it is ready; the server applies gradients in arrival order. Throughput is maximal and stragglers are irrelevant, but a gradient computed on version 10 may land after the weights have already moved to version 14: a stale gradient that points in a slightly wrong direction and can slow or destabilize convergence.
- SSP — Stale Synchronous Parallel (bounded staleness). The pragmatic middle ground: workers run asynchronously, but the fastest worker may be at most s steps ahead of the slowest. If it tries to pull while further ahead, it blocks until the laggard catches up. With small s you get near-BSP convergence; with larger s you get near-ASP throughput, all from a single tunable dial.
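The SSP rule can be sketched as a per-worker clock table guarded by a condition variable. This is one common formulation with illustrative names (`SSPClock`, `wait_to_pull`, `tick`); real systems differ on off-by-one details and usually keep the clocks on the servers or coordinator rather than in one process. Setting `staleness=0` recovers a BSP barrier and `staleness=float("inf")` recovers ASP, which is why SSP is described as a dial.

```python
import threading


class SSPClock:
    """Bounded-staleness gate (one common SSP formulation; names are illustrative).

    clocks[w] counts the steps worker w has finished. Worker w may start its next
    step only while it is at most `staleness` steps ahead of the slowest worker.
    staleness=0 behaves like BSP; staleness=float("inf") behaves like ASP.
    """

    def __init__(self, num_workers, staleness):
        self.clocks = [0] * num_workers
        self.staleness = staleness
        self._cv = threading.Condition()

    def can_proceed(self, w):
        return self.clocks[w] - min(self.clocks) <= self.staleness

    def wait_to_pull(self, w, timeout=None):
        # Block only the runaway worker until the laggard catches up (or timeout).
        with self._cv:
            return self._cv.wait_for(lambda: self.can_proceed(w), timeout=timeout)

    def tick(self, w):
        # Called after worker w pushes its step; may unblock waiting workers.
        with self._cv:
            self.clocks[w] += 1
            self._cv.notify_all()
```

A worker calls `wait_to_pull(w)` before each pull and `tick(w)` after each push. With `staleness=1`, worker 0 may finish one step ahead of worker 1 but blocks before starting a third step until worker 1 ticks.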
| Model | Barrier | Convergence (per step) | Throughput | Straggler tolerance | Use when |
|---|---|---|---|---|---|
| BSP (sync) | Every step | Best — equals large-batch SGD; reproducible | Lowest — gated by slowest worker | Poor — one straggler stalls all | Homogeneous fleet; need stability/reproducibility |
| ASP (async) | None | Worst — stale gradients can slow or destabilize | Highest — no waiting | Excellent — slow workers don't block | Huge sparse models; heterogeneous / preemptible workers |
| SSP (bounded) | Only when bound s exceeded | Tunable — near-BSP for small s | High — blocks only the runaway worker | Good — absorbs transient slowdowns | Real-world middle ground; mild heterogeneity |
The diagram below shows the asynchronous case: two workers pull the
same version, but because there is no barrier, W1's
gradient is applied after the weights have already advanced —
it is stale by one update.
```mermaid
sequenceDiagram
    participant W0 as Worker 0
    participant W1 as Worker 1
    participant PS as Param Server (shard)
    W0->>PS: pull params (version 10)
    W1->>PS: pull params (version 10)
    Note over W0,W1: compute gradients independently
    W0->>PS: push grad g0
    PS->>PS: apply update (version 11)
    Note over PS: W1 still on v10 (stale)
    W1->>PS: push grad g1 (stale)
    PS->>PS: apply update (version 12)
    PS-->>W0: params (version 12)
    Note over W0,PS: no global barrier between workers
```
Asynchronous push/pull: W1 computed on version 10 but its
gradient lands after version 11, so it is applied "stale". Bounded
staleness (SSP) caps how far this gap may grow before a worker is
forced to wait.
Why sparse models tolerate staleness well
Staleness hurts least exactly where the parameter server is used most.
In a billion-feature embedding table, two workers usually touch
disjoint rows, so their "concurrent" updates rarely
conflict — an async push to feature A doesn't make a pull
of feature B any more stale. This is why async/SSP is the
norm for recsys/ads: the conflict surface is tiny. Dense layers are
the opposite — every worker writes every element — so they are often
kept closer to synchronous (or handled with all-reduce) even inside an
otherwise async job. Common guardrails:
staleness-aware learning-rate scaling (down-weight a
gradient by how stale it is), gradient clipping, and
bounded staleness to cap the worst case.
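The three guardrails combine naturally on the server once each push carries the version the worker read. The sketch below is a dense toy shard with hypothetical names and thresholds: staleness is measured as `version - read_version`, a gradient that is too stale is dropped, the rest are clipped by L2 norm, and the step size is scaled by `1 / (1 + staleness)`, which is one simple choice of staleness-aware learning rate rather than the only one.

```python
import math


class StalenessAwareShard:
    """Toy dense shard: SGD with lr / (1 + staleness), norm clipping, and a hard staleness cap.

    All names and default thresholds are illustrative assumptions.
    """

    def __init__(self, dim, lr=0.1, max_staleness=8, clip_norm=5.0):
        self.w = [0.0] * dim
        self.version = 0
        self.lr = lr
        self.max_staleness = max_staleness
        self.clip_norm = clip_norm

    def pull(self):
        # Workers remember the version they read and send it back with the push.
        return list(self.w), self.version

    def push(self, grad, read_version):
        staleness = self.version - read_version
        if staleness > self.max_staleness:
            return False                                  # bounded staleness: drop it
        norm = math.sqrt(sum(g * g for g in grad))
        scale = min(1.0, self.clip_norm / norm) if norm > 0 else 1.0  # gradient clipping
        step = self.lr / (1 + staleness)                  # staleness-aware learning rate
        self.w = [wi - step * scale * gi for wi, gi in zip(self.w, grad)]
        self.version += 1
        return True
```

Replaying the sequence diagram with this shard: both workers read the same version, the first push is applied with the full step (0.1), and the second, now one update stale, is applied with half the step (0.05), so a weight that both gradients push in the same direction ends at -0.15 instead of -0.2.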
Parameter server vs all-reduce
These are the two dominant paradigms for distributed training, and choosing correctly is a classic staff-level judgement call. All-reduce is decentralized: every worker holds a full replica of the model and they collectively sum gradients via a bandwidth-optimal ring or tree (NCCL), with no central server. Parameter server is centralized: a sharded store owns the weights and workers pull only what they need. They make opposite bets about sparsity, symmetry, and synchrony.
| Dimension | Parameter server | All-reduce |
|---|---|---|
| Topology | Centralized — workers talk to servers | Decentralized — workers talk peer-to-peer (ring/tree) |
| Where the model lives | Sharded across servers; can far exceed one node | Full replica on every worker; must fit (or be sharded by FSDP) |
| Best for | Huge sparse embeddings (recsys/ads) | Dense models (LLMs, CNNs) — every param updated each step |
| Data moved / step | Only the touched slice (KB-MB) | Full gradient (~2x params), every step |
| Synchrony | Sync, async, or bounded-stale — flexible | Inherently synchronous (the collective is a barrier) |
| Elastic / heterogeneous | Natural — stateless, preemptible workers | Harder — ring must re-form; assumes homogeneous GPUs |
| Bottleneck | Server incast, hot keys, staleness | Slowest link / straggler; no central point but full sync |
Choosing — and combining
- Use a parameter server when the model is sparse and bigger than a worker (billion-feature embeddings), when you need asynchrony or bounded staleness, or when the fleet is elastic/heterogeneous (preemptible CPUs+GPUs).
- Use all-reduce when the model is dense and every parameter is updated each step (LLMs, vision), the fleet is a homogeneous GPU cluster, and synchronous SGD is desired — its ring algorithm is bandwidth-optimal and has no central bottleneck.
- Hybrid is the production answer for recsys. A model like DLRM splits cleanly: keep the giant sparse embedding tables on a parameter server (move only touched rows) and synchronize the small dense MLP with all-reduce (every element is live anyway). Each half runs on the paradigm that fits its access pattern.
The deciding question: how dense is a step?
If a single step touches a tiny fraction of parameters, the parameter server's "move only what you touch" wins decisively — all-reduce would waste bandwidth synchronizing untouched weights. If a step touches nearly all parameters, the server becomes a needless central choke point and all-reduce's peer-to-peer bandwidth optimality wins. Sparsity, not model size alone, is the real discriminator.
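The "how dense is a step" test can be turned into rough per-step traffic numbers. The helper below is a hypothetical estimate that ignores headers, compression and dense-layer traffic, and assumes keys spread evenly across servers. It compares what each PS worker moves, what each PS server absorbs (the incast), and what each worker moves in a ring all-reduce, where every worker sends and receives about 2(N-1)/N times the full gradient.

```python
def traffic_per_step(params, touched, workers, servers, bytes_per=4):
    """Rough bytes moved per training step (hypothetical helper).

    params:  total parameter count (individual floats)
    touched: parameters a worker reads/writes in one step
    """
    ps_per_worker = 2 * touched * bytes_per                  # push grads + pull rows, touched only
    ps_per_server = workers * ps_per_worker / servers        # incast, keys spread evenly
    ring_per_worker = 2 * (workers - 1) / workers * params * bytes_per  # ~2x the full gradient
    return ps_per_worker, ps_per_server, ring_per_worker


# Sparse recsys step (the scale table's numbers): PS wins by about five orders of magnitude.
sparse = traffic_per_step(params=10**10 * 64, touched=25_000 * 64, workers=1_000, servers=200)
# Hypothetical dense 1B-parameter model on 64 workers and 8 servers: every parameter is touched.
dense = traffic_per_step(params=10**9, touched=10**9, workers=64, servers=8)
```

For the sparse case a PS worker moves ~12.8 MB and each server absorbs ~64 MB per step, against ~5.1 TB per worker for a ring. For the dense case the per-worker volumes are nearly equal (~8 GB vs ~7.9 GB), but each of the 8 servers must absorb ~64 GB per step; that concentration is the central choke point, unless the server tier grows as large as the worker tier.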
Bottlenecks & scaling
Scaling a parameter server is a march through a familiar set of pressure points. Each shows up as workers waiting (throughput drops) or the loss curve worsening (convergence drops); the mitigation usually trades one against the other.
| Bottleneck | Symptom | Mitigation |
|---|---|---|
| Server network (incast) | Many workers push/pull to few shards; NICs saturate | More shards, gradient compression/quantization, worker-side caching, hierarchical aggregation |
| Hot keys / skew | A few popular embeddings overload one shard | Replicate hot keys, cache on workers, local accumulation, sub-shard the key's dims |
| Stragglers | Slow worker stalls BSP; or emits very stale grads in ASP | Async / SSP, backup workers, drop or down-weight late gradients |
| Staleness vs convergence | Async is fast but needs more steps / can diverge | Bounded staleness (SSP), staleness-aware LR, gradient clipping |
| Server fault tolerance | A shard dies -> its slice of the model is lost | Replication (primary-backup/chain), sharded async checkpoints, fast failover |
| Param store memory | TB-scale tables exceed aggregate RAM | More shards, host-memory / SSD tiering, feature hashing & pruning, smaller dim |
| Checkpoint I/O | Writing multi-TB state stalls training | Sharded parallel writes, async / in-memory snapshot, tuned interval |
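Of the hot-key mitigations in the table, local accumulation is the simplest to sketch: the worker sums gradients per key and pushes the combined buffer every few steps, so a popular embedding touched thousands of times costs one entry per flush instead of one per touch. The class and its parameters are hypothetical; `push_fn` stands in for whatever batched push RPC the system uses. The trade-off mirrors the staleness discussion: buffered updates reach the server later.

```python
from collections import defaultdict


class LocalAccumulator:
    """Worker-side gradient buffer (hypothetical helper).

    Sums gradients per key and pushes once every `flush_every` steps, trading a
    little extra staleness for far fewer updates on hot keys.
    """

    def __init__(self, dim, flush_every, push_fn):
        self.dim = dim
        self.flush_every = flush_every
        self.push_fn = push_fn
        self.steps = 0
        self.buf = defaultdict(lambda: [0.0] * dim)

    def add(self, key, grad):
        # Repeated keys (within a batch or across steps) collapse into one row.
        row = self.buf[key]
        for i, gi in enumerate(grad):
            row[i] += gi

    def step(self):
        self.steps += 1
        if self.steps % self.flush_every == 0:
            self.flush()

    def flush(self):
        if self.buf:
            self.push_fn(dict(self.buf))   # one batched push per flush
            self.buf = defaultdict(lambda: [0.0] * self.dim)
```

With `flush_every=1` this reduces to per-step deduplication, which is already what the "Batch 256 x ~100 nonzero features, deduplicated" row of the scale table assumes; larger values cut server load further at the cost of staler updates.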
Summary
The parameter server splits training into
stateless workers that push gradients and
sharded servers that own and update the parameters
— a decoupling that lets a
multi-terabyte sparse model
span a fleet while workers stay cheap and elastic. Its defining
lever is the
synchronization model: BSP for clean
convergence, ASP for raw throughput, and
SSP to dial bounded staleness in between — and sparse
embeddings tolerate staleness because concurrent updates rarely
collide. Consistent hashing makes the store
elastic, replication + sharded checkpoints make it
durable, and
hot-key replication + gradient compression tame the
incast at the servers. Reach for a parameter server when a step
touches a tiny slice of a huge model; reach for
all-reduce when training is dense and synchronous —
and for recsys, do both: embeddings on the PS,
dense layers on all-reduce.