Mooncake EP & Mooncake Backend (PG)#
Overview#
Mooncake provides two closely related components for fault-tolerant MoE inference:
- Mooncake Backend (PG) is a torch.distributed ProcessGroup backend. It registers the mooncake accelerator backend and the mooncake-cpu backend, implements common collective and point-to-point APIs, tracks active ranks, and exposes elastic recovery helpers.
- Mooncake EP is an expert-parallel dispatch/combine runtime for latency-sensitive MoE inference. It follows the DeepEP low-latency programming model while adding rank activeness awareness and Mooncake transport support.
The usual integration pattern is to initialize a Mooncake process group first,
then construct a Mooncake EP Buffer from that group. The process group is used
both for regular collectives and for exchanging EP bootstrap metadata.
For implementation details, see the Mooncake Backend (PG) design guide and the Mooncake EP design guide.
Installation and build notes#
Mooncake EP and PG are included in CUDA-enabled Mooncake wheels. When building from source, enable the EP/PG extensions with:
cmake .. -DWITH_EP=ON
The extensions are compiled against a specific PyTorch version. At import time,
mooncake.pg and mooncake.ep load version-suffixed extension modules that
match the active torch.__version__. If the current PyTorch version does not
match a built extension, import will fail with a message such as
Mooncake PG was not built against torch==....
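The version check described above can be pictured with a small sketch. The helper name check_built_against and the built_versions set are hypothetical, and stripping a local build tag such as +cu128 before comparing is an assumption; the real loader's module naming and exact matching rules are not shown in this guide.

```python
def check_built_against(active_version: str, built_versions: set) -> str:
    """Return the matching base version or raise ImportError.

    Hypothetical helper: it only mirrors the failure mode described in the
    text, not Mooncake's actual import logic.
    """
    # Assumption: a local build tag ("+cu128") is ignored when matching.
    base = active_version.split("+", 1)[0]
    if base not in built_versions:
        raise ImportError(f"Mooncake PG was not built against torch=={base}")
    return base


if __name__ == "__main__":
    # The version strings below are made-up inputs for illustration.
    print(check_built_against("2.8.0+cu128", {"2.8.0"}))
```

Running a check like this before importing mooncake.pg can turn a late import failure into an early, readable error in a launcher script.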
Mooncake Backend (PG) quick start#
CUDA backend#
import os
import torch
import torch.distributed as dist
from mooncake import pg
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
# Backend-level active-rank mask. Use int32 and place it on the backend device.
active_ranks = torch.ones(world_size, dtype=torch.int32, device=device)
dist.init_process_group(
    backend="mooncake",
    rank=rank,
    world_size=world_size,
    pg_options=pg.MooncakeBackendOptions(active_ranks),
)
x = torch.tensor([rank + 1], dtype=torch.int32, device=device)
dist.all_reduce(x, op=dist.ReduceOp.SUM)
print(f"rank={rank}, all_reduce={int(x.cpu())}")
Run it with the usual PyTorch launcher, for example:
torchrun --nproc-per-node=2 pg_quickstart.py
CPU backend#
Use backend="mooncake-cpu" and put active_ranks on CPU:
active_ranks = torch.ones(world_size, dtype=torch.int32)
dist.init_process_group(
    backend="mooncake-cpu",
    rank=rank,
    world_size=world_size,
    pg_options=pg.MooncakeBackendOptions(active_ranks),
)
Selecting network devices#
To explicitly restrict Mooncake to a list of NIC / HCA devices, call
pg.set_device_filter(...) before init_process_group():
from mooncake import pg
pg.set_device_filter(["mlx5_1", "mlx5_2"])
For test and benchmark commands, the same setting is commonly passed through
MOONCAKE_PGTEST_DEVICE_FILTERS=mlx5_1,mlx5_2.
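An application that wants to honor the same environment variable could parse it roughly as follows. This is a sketch: the helper name device_filters_from_env is hypothetical, and whether the test harness tolerates whitespace or empty entries in the same way is an assumption.

```python
import os


def device_filters_from_env(name: str = "MOONCAKE_PGTEST_DEVICE_FILTERS") -> list:
    """Split a comma-separated device list, dropping blanks and whitespace."""
    raw = os.environ.get(name, "")
    return [item.strip() for item in raw.split(",") if item.strip()]


if __name__ == "__main__":
    filters = device_filters_from_env()
    print(filters)
    # With Mooncake installed, the list would be applied before
    # init_process_group():
    #     from mooncake import pg
    #     if filters:
    #         pg.set_device_filter(filters)
```

Returning an empty list when the variable is unset lets the caller skip set_device_filter() and keep the backend's default device selection.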
Mooncake Backend (PG) API reference#
MooncakeBackendOptions#
pg.MooncakeBackendOptions(active_ranks)
pg.MooncakeBackendOptions(active_ranks, is_extension)
pg.MooncakeBackendOptions(active_ranks, is_extension, max_world_size)
Arguments:
- active_ranks: torch.int32 tensor used as the backend-level rank-health mask. For mooncake, it must be on the accelerator device; for mooncake-cpu, it must be on CPU. When max_world_size is set, size this tensor to max_world_size, not the current visible world size.
- is_extension: set to True for a replacement or joining process that will enter an existing group through join_group().
- max_world_size: optional upper bound for reserved rank slots. It lets healthy ranks reserve inactive future ranks while keeping dist.get_world_size() equal to the current active size.
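The sizing rule for active_ranks on the healthy-rank side can be sketched with plain lists. The helper name initial_active_mask is hypothetical; marking reserved slots with 0 follows the [1, 1, 0] example in the elastic recovery section of this guide.

```python
from typing import Optional


def initial_active_mask(world_size: int, max_world_size: Optional[int] = None) -> list:
    """Build a 0/1 mask with one entry per reserved rank slot."""
    slots = world_size if max_world_size is None else max_world_size
    if slots < world_size:
        raise ValueError("max_world_size must be >= world_size")
    # Currently visible ranks are active (1); reserved future ranks are not (0).
    return [1] * world_size + [0] * (slots - world_size)


if __name__ == "__main__":
    mask = initial_active_mask(2, 3)
    print(mask)
    # With PyTorch available, the list becomes the options tensor:
    #     torch.tensor(mask, dtype=torch.int32, device=device)
```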
Utility functions#
| Function | Purpose | Notes |
|---|---|---|
| | Override the host IP used by the backend. | Call before |
| pg.set_device_filter(...) | Restrict NIC/HCA selection. | Call before init_process_group(). |
| | Reuse an external | The engine must outlive all process groups. |
| | Query topology-preferred HCA for a location. | Useful for topology-aware placement/debugging. |
| | Return the backend active-rank tensor. | Used by EP fallback and recovery paths. |
| | Return the number of ranks synchronized by the backend. | Diagnostic helper. |
| extend_group_size_to() | Reserve additional inactive ranks. | Newly extended ranks do not participate until recovered. |
| pg.get_peer_state(...) | Check whether candidate ranks have published peer metadata. | Collective among healthy ranks. |
| pg.recover_ranks(...) | Activate ready ranks and publish extension state. | Requires peer metadata to be ready. |
| pg.join_group(...) | Joiner-side blocking call for extension ranks. | Used after |
Supported distributed operations#
Mooncake Backend implements the following torch.distributed APIs. Support may
depend on device type, dtype, PyTorch version, and whether the current backend is
mooncake or mooncake-cpu; run the PG tests on the target environment before
production use.
| API family | Examples | Notes |
|---|---|---|
| Collectives | | Active ranks participate; inactive ranks are skipped by backend internals. |
| Async work | | Wait on the returned work object, then synchronize the device stream as needed. |
| P2P | | Single-tensor P2P is routed through the Mooncake P2P shim. |
Elastic recovery protocol#
Mooncake PG supports a two-sided recovery protocol. Existing healthy ranks poll for replacement rank readiness, then activate those ranks. Replacement ranks start in extension mode, publish metadata, and wait until healthy ranks recover them.
Healthy-rank side#
import torch
import torch.distributed as dist
from mooncake import pg
# rank and device are set up as in the CUDA quick start.
# Two live ranks plus one reserved, inactive slot (max_world_size=3).
active_ranks = torch.tensor([1, 1, 0], dtype=torch.int32, device=device)
dist.init_process_group(
    backend="mooncake",
    rank=rank,
    world_size=2,
    pg_options=pg.MooncakeBackendOptions(
        active_ranks,
        False,  # is_extension
        3,  # max_world_size
    ),
)
backend = dist.group.WORLD
join_ranks = [2]
while not all(pg.get_peer_state(backend, join_ranks)):
    # Continue serving, back off, or poll according to your scheduler policy.
    pass
pg.recover_ranks(backend, join_ranks)
Joining-rank side#
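import torch
import torch.distributed as dist
from mooncake import pg
# device is set up as in the CUDA quick start.
active_ranks = torch.tensor([1, 1, 1], dtype=torch.int32, device=device)
dist.init_process_group(
    backend="mooncake",
    rank=2,
    world_size=3,
    pg_options=pg.MooncakeBackendOptions(
        active_ranks,
        True,  # is_extension
        3,  # max_world_size
    ),
)
backend = dist.group.WORLD
# Blocks until the healthy ranks call recover_ranks() for this rank.
pg.join_group(backend)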
Important semantics:
- get_peer_state() is collective among the current healthy ranks. Call it in a consistent order across those ranks.
- New ranks are inactive after extend_group_size_to() and become collective participants only after recover_ranks().
- A joining rank initialized with is_extension=True starts with local-only behavior and blocks in join_group() until the corresponding healthy ranks publish recovery state.
- Subgroups must be created in the same order on healthy and joining processes, following PyTorch new_group() ordering rules.
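The healthy-rank polling loop can be given a bounded, backed-off shape. The sketch below is one possible policy, not part of the Mooncake API: wait_for_peers and its parameters are hypothetical, and poll stands in for a call such as pg.get_peer_state(backend, ranks). It assumes every healthy rank sees the same poll result, so all of them leave the loop on the same iteration and the collective call order stays consistent.

```python
import time
from typing import Callable, Sequence


def wait_for_peers(
    poll: Callable[[Sequence[int]], Sequence[bool]],
    join_ranks: Sequence[int],
    max_attempts: int = 100,
    initial_delay_s: float = 0.01,
    max_delay_s: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll until every joining rank is ready, doubling the delay each time."""
    delay = initial_delay_s
    for _ in range(max_attempts):
        if all(poll(join_ranks)):
            return True
        sleep(delay)
        delay = min(delay * 2, max_delay_s)
    return False


if __name__ == "__main__":
    # Fake poll for demonstration: the joining rank becomes ready on call 3.
    state = {"calls": 0}

    def fake_poll(ranks):
        state["calls"] += 1
        return [state["calls"] >= 3 for _ in ranks]

    print(wait_for_peers(fake_poll, [2]))
```

With Mooncake installed, the call would look like wait_for_peers(lambda ranks: pg.get_peer_state(backend, ranks), join_ranks), followed by pg.recover_ranks(backend, join_ranks) only when it returns True.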
Mooncake EP quick start#
Mooncake EP exposes Buffer from mooncake.mooncake_ep_buffer. Initialize it
with a Mooncake process group and a workspace size computed from the expected
dispatch shape.
import torch
import torch.distributed as dist
from mooncake import pg
from mooncake.mooncake_ep_buffer import Buffer
# Assume dist.init_process_group(..., backend="mooncake", ...) has completed.
group = dist.group.WORLD
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
num_tokens = 128
hidden = 7168
num_experts = 288
top_k = 8
max_tokens_per_rank = 128
x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
scores = torch.randn(num_tokens, num_experts, dtype=torch.float32, device="cuda")
topk_idx = torch.topk(scores, top_k, dim=-1).indices
topk_weights = torch.softmax(
    torch.randn(num_tokens, top_k, dtype=torch.float32, device="cuda"), dim=-1
)
num_ep_buffer_bytes = Buffer.get_ep_buffer_size_hint(
    max_tokens_per_rank,
    hidden,
    world_size,
    num_experts,
)
buffer = Buffer(group, num_ep_buffer_bytes)
# EP-level rank-health tensor. Kernels may update it to 0 when timeout_us
# detects a failed source rank.
active_ranks = torch.ones(world_size, dtype=torch.int32, device="cuda")
recv_x, recv_count, handle, event, hook = buffer.dispatch(
    x,
    topk_idx,
    active_ranks,
    num_max_dispatch_tokens_per_rank=max_tokens_per_rank,
    num_experts=num_experts,
    timeout_us=-1,
    use_fp8=True,
    async_finish=False,
    return_recv_hook=False,
)
event.current_stream_wait()
# Run local experts on recv_x here. If use_fp8=True, recv_x is a
# (data, scales) tuple; dequantize or feed it into an FP8-aware expert kernel.
# run_local_experts is a placeholder for your own expert computation.
expert_out = run_local_experts(recv_x, recv_count)
combined_x, event, hook = buffer.combine(
    expert_out,
    topk_idx,
    topk_weights,
    active_ranks,
    timeout_us=-1,
    handle=handle,
)
event.current_stream_wait()
Mooncake EP API reference#
Buffer.get_ep_buffer_size_hint(...)#
Buffer.get_ep_buffer_size_hint(
    num_max_dispatch_tokens_per_rank: int,
    hidden: int,
    num_ranks: int,
    num_experts: int,
) -> int
Returns the workspace size in bytes for the EP buffer. Use the maximum number of tokens a rank may dispatch in one step. Underestimating this value can cause buffer overflow or incorrect dispatch results.
Buffer(group, num_ep_buffer_bytes=0)#
Creates the EP runtime for a Mooncake process group. The constructor exchanges RDMA and IPC metadata through the group, initializes fast-path transports when available, and falls back to the Python implementation if the fast path is not usable.
Buffer.dispatch(...)#
recv_x, recv_count, handle, event, hook = buffer.dispatch(
    x,
    topk_idx,
    active_ranks,
    num_max_dispatch_tokens_per_rank,
    num_experts,
    timeout_us,
    use_fp8=True,
    async_finish=False,
    return_recv_hook=False,
)
Arguments:
- x: local token hidden states, shape [num_tokens, hidden], typically BF16 on CUDA.
- topk_idx: selected expert IDs, shape [num_tokens, top_k]. Use -1 to mark masked selections.
- active_ranks: EP-level rank-health tensor, shape [num_ranks], dtype torch.int32. Timeout detection may set failed source ranks to 0.
- num_max_dispatch_tokens_per_rank: workspace capacity per rank. It should be at least the maximum local num_tokens across ranks for the current step.
- num_experts: global expert count. It must be divisible by num_ranks.
- timeout_us: timeout in microseconds. Use -1 to disable timeout detection.
- use_fp8: when True, dispatch returns FP8 data plus scales.
- async_finish: when True, returned tensors are associated with the returned event for stream-lifetime management.
- return_recv_hook: when True, call the returned hook() to complete receive synchronization; otherwise use event.current_stream_wait().
Returns:
- recv_x: packed local-expert inputs. If use_fp8=True, this is (packed_data, packed_scales); otherwise it is a BF16 tensor.
- recv_count: number of tokens received by each local expert.
- handle: opaque metadata required by combine() and get_next_combine_buffer().
- event: EventOverlap helper; call event.current_stream_wait() before using outputs when no hook is used.
- hook: optional synchronization hook used when return_recv_hook=True.
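The shape constraints on dispatch() arguments can be checked on the host before launching a step. The sketch below is illustrative only: check_dispatch_config is a hypothetical helper, and treating num_experts // num_ranks as the number of local experts per rank is an inference from the divisibility requirement, not a documented return value.

```python
def check_dispatch_config(
    num_tokens: int,
    top_k: int,
    num_experts: int,
    num_ranks: int,
    num_max_dispatch_tokens_per_rank: int,
) -> int:
    """Validate dispatch sizes and return the experts-per-rank count."""
    if num_experts % num_ranks != 0:
        raise ValueError("num_experts must be divisible by num_ranks")
    if top_k > num_experts:
        raise ValueError("top_k cannot exceed num_experts")
    if num_tokens > num_max_dispatch_tokens_per_rank:
        raise ValueError("num_tokens exceeds num_max_dispatch_tokens_per_rank")
    # Assumption: experts are split evenly across ranks.
    return num_experts // num_ranks


if __name__ == "__main__":
    # Quick-start sizes with an assumed world size of 8.
    print(check_dispatch_config(128, 8, 288, 8, 128))
```

Because the capacity must cover the largest local num_tokens on any rank, a check like this is most useful when it runs with the same per-step maximum on every rank.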
Buffer.combine(...)#
combined_x, event, hook = buffer.combine(
    x,
    topk_idx,
    topk_weights,
    active_ranks,
    timeout_us,
    handle,
    zero_copy=False,
    async_finish=False,
    return_recv_hook=False,
    out=None,
)
Arguments:
- x: local expert outputs packed in the layout returned by dispatch().
- topk_idx and topk_weights: routing metadata for combining expert outputs back to local tokens.
- active_ranks: same EP-level rank-health tensor used by dispatch().
- timeout_us: timeout in microseconds; use -1 to disable timeout detection.
- handle: the handle returned by the matching dispatch() call.
- zero_copy: when True, write expert outputs into buffer.get_next_combine_buffer(handle) and pass that tensor to combine().
- out: optional output tensor for the combined result.
Buffer.get_next_combine_buffer(handle)#
Returns the next combine buffer for zero-copy expert output. Use it only with the
matching dispatch handle and pass the resulting tensor back to combine() with
zero_copy=True.
Buffer.update_ep_member()#
Reconnects EP peers after backend membership changes. Call it after PG recovery updates rank activeness so EP transport metadata and QPs can be refreshed.
Active-rank tensors: PG vs EP#
There are two active-rank tensors in the API surface:
- PG active-rank mask: passed to pg.MooncakeBackendOptions. This is the backend-level health mask used by collective and recovery logic.
- EP active-rank tensor: passed to Buffer.dispatch() and Buffer.combine(). It is also rank-level ([num_ranks], torch.int32) and may be updated by EP kernels when timeout detection marks a peer as failed.
In simple integrations these tensors often carry the same health information,
but they are passed through different API layers. Keep their dtype, device, and
shape consistent with the process group world size or reserved max_world_size.
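Since EP kernels may flip entries of the EP tensor to 0, a scheduler can compare a snapshot taken before a step with the values after it to find ranks that were just marked failed. The sketch below works on plain lists (for example, the result of active_ranks.tolist()); newly_failed_ranks is a hypothetical helper, not a Mooncake API.

```python
from typing import Sequence


def newly_failed_ranks(before: Sequence[int], after: Sequence[int]) -> list:
    """Return ranks whose mask entry changed from 1 to 0 between snapshots."""
    if len(before) != len(after):
        raise ValueError("masks must have the same length")
    return [
        rank
        for rank, (old, new) in enumerate(zip(before, after))
        if old == 1 and new == 0
    ]


if __name__ == "__main__":
    # Example snapshots: rank 1 was marked failed during the step, rank 3
    # was already inactive beforehand.
    print(newly_failed_ranks([1, 1, 1, 0], [1, 0, 1, 0]))
```

The returned ranks are natural candidates for replacement through the elastic recovery protocol, followed by Buffer.update_ep_member() once the PG has recovered them.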
Tests and examples#
- PG collectives: mooncake-pg/tests/test_pg_collectives.py
- PG elastic recovery and subgroup extension: mooncake-pg/tests/test_pg_elastic.py
- PG benchmark harness: mooncake-pg/benchmark/README.md
- EP correctness and failure simulation: mooncake-ep/tests/test_ep_grid.py
- Wheel-level EP example: mooncake-wheel/tests/test_mooncake_ep.py
See PG/EP troubleshooting for common setup and runtime issues.