Speculative Decoding in the Wild
Eagle speculation predicts what your target model will say, then verifies all its guesses in one forward pass. The engineering challenges: building tree-shaped attention masks, juggling KV cache across speculation branches, and making CUDA graphs work when speculation depth varies. vLLM does this in Python (flexible, slower), SGLang bakes it into the model worker with CUDA graphs (fast, less flexible), TensorRT-LLM compiles the whole loop into one engine (fastest, requires rebuild to change anything).
The problem isn’t your code. It’s physics. Autoregressive generation is sequential: each token depends on all previous tokens. Generate token 1, wait. Generate token 2, wait. The GPU spends most of its time reading the KV cache from memory, not doing math. You bought a 2 petaFLOP machine and you’re using it as a very expensive memory controller.
Eagle speculative decoding works. The research has evolved through three generations: from feature-level drafting with static trees to dynamic context-aware trees to training-time test with multi-layer fusion, each achieving higher speedups with lower overhead. The math is clean. But papers don’t serve production traffic.
This post traces how vLLM, SGLang, and TensorRT-LLM actually implement Eagle. We’ll follow the code paths: how they build tree attention masks, manage KV cache for speculative branches, handle rejected tokens, and optimize with CUDA graphs. We’ll also cover the parts the papers gloss over: what happens when you fine-tune your model and your Eagle head stops working, why LoRA adapters break speculation, and when you should fall back to vanilla draft models.
I spent a few weeks reading through these codebases. The goal: understand not just what Eagle does, but how it’s built, why the frameworks made different implementation choices, and what breaks in production.
The Decode Bottleneck
Here’s what that looks like concretely. At batch size 1, decoding a single token from Llama-70B requires:
- Reading ~140GB of model weights from HBM
- Reading the KV cache (grows with sequence length)
- Performing a tiny amount of compute: one matrix-vector multiply per layer
The arithmetic intensity (ratio of compute to memory access) is pathetic. You’re doing maybe 2 FLOPs per byte read. The H100’s HBM can deliver 3.35 TB/s, so you’re bottlenecked at roughly 6.7 TFLOPS. The card can do ~2,000 TFLOPS at FP16. You’re at 0.34% utilization. This is exactly the roofline model (Williams, Waterman, Patterson 2009): autoregressive decode sits deep in the memory-bound region, far left of the ridge point. Prefill, by contrast, is compute-bound. Speculative decoding moves decode closer to prefill on the roofline by increasing the arithmetic intensity of each kernel launch.
Speculative decoding attacks this by parallelizing verification. You can’t generate tokens in parallel (each depends on previous ones), but you can verify them in parallel. Guess K tokens, verify all K in one forward pass. If your guesses are good, you’ve generated K tokens for the cost of ~1.
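The "guess K, verify K" idea is easiest to see in the greedy (temperature 0) case. The sketch below is an illustration under that assumption, not any framework's code: `accept_greedy` is a hypothetical helper, and `target_tokens` stands in for the argmax tokens the target would produce at each draft position plus one extra position.

```python
def accept_greedy(draft_tokens, target_tokens):
    """Return the tokens emitted by one verification pass under greedy decoding.

    target_tokens[i] is the target's greedy choice at draft position i, so it
    must hold exactly one more entry than draft_tokens (the position after
    the last draft token).
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("target_tokens needs len(draft_tokens) + 1 entries")
    emitted = []
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            # First mismatch: keep the target's own token and stop.
            emitted.append(target)
            return emitted
        emitted.append(draft)
    # Every draft matched, so the extra target position is a free bonus token.
    emitted.append(target_tokens[len(draft_tokens)])
    return emitted


partial = accept_greedy([5, 7, 9], [5, 7, 2, 4])
full = accept_greedy([5, 7, 9], [5, 7, 9, 4])
print(partial, full)
```

Even a complete miss still yields one token per target call, which is why the worst case is "vanilla decode plus draft overhead" rather than no progress at all.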
Vanilla Draft Model Speculation
The original approach is simple.
Setup: You have a large target model (say, 70B parameters) and a small draft model (say, 7B parameters). The draft model is 10x smaller, so it runs 10x faster per token. This 10x rule only holds at batch size 1 where you’re memory-bandwidth-bound and runtime scales with parameter count. At higher batch sizes, decoding shifts toward compute-bound and the draft speedup shrinks, which is also why speculation benefits diminish at large batch sizes.
Algorithm:
- The draft model generates candidate tokens autoregressively (fast)
- The target model processes all tokens in one forward pass (parallel)
- For each draft token $x$ sampled from the draft distribution $q(x)$: draw $r \sim U(0, 1)$ and accept when $r$ falls below $\min(1, p(x)/q(x))$, where $p$ is the target model’s distribution
- On first rejection, resample from the adjusted distribution $p'(x) = \mathrm{norm}(\max(0, p(x) - q(x)))$, which reassigns exactly the probability mass the draft model over-allocated to the tokens it under-allocated. The proof follows from standard rejection sampling. When $q(x) \le p(x)$, the token is always accepted. When $q(x)$ exceeds $p(x)$, acceptance probability is $p(x)/q(x)$ and the residual distribution covers exactly the remaining mass. In expectation, the combined process samples from $p$.
- The output distribution is provably identical to running the target model alone. You’re trading draft compute for verification parallelism without changing a single output probability
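The accept/resample rule can be checked empirically on a toy vocabulary. The following is a minimal sketch, assuming the standard formulation (accept with probability min(1, p/q), otherwise resample from the normalized positive part of p - q). The helper names `sample`, `residual`, and `verify_token` and the two distributions are made up for illustration.

```python
import random


def sample(dist, rng):
    """Inverse-CDF sample of an index from a list of probabilities."""
    r = rng.random()
    acc = 0.0
    for token, prob in enumerate(dist):
        acc += prob
        if r < acc:
            return token
    return len(dist) - 1  # guard against floating-point round-off


def residual(p, q):
    """Normalized max(0, p - q): where the target wants more mass than the draft gave."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    # If p == q everywhere, rejection never happens; fall back to p.
    return [d / total for d in diff] if total > 0 else list(p)


def verify_token(p, q, rng):
    """Draft one token from q, then accept it or resample so the result follows p."""
    x = sample(q, rng)
    if rng.random() < min(1.0, p[x] / q[x]):
        return x, True
    return sample(residual(p, q), rng), False


rng = random.Random(0)
p = [0.6, 0.3, 0.1]  # target distribution (illustrative)
q = [0.3, 0.5, 0.2]  # draft distribution (illustrative)
trials = 20000
counts = [0, 0, 0]
for _ in range(trials):
    token, _accepted = verify_token(p, q, rng)
    counts[token] += 1
freqs = [c / trials for c in counts]
print(freqs)
```

The empirical frequencies should track `p`, not `q`, even though every candidate token was proposed by the draft.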
Why it works (bandwidth, not compute): Single-token decode reads all model weights from HBM once and does one matrix-vector multiply per layer, an arithmetic intensity of roughly 2 FLOPs per byte, deep in memory-bound territory (recall the H100 sitting at 6.7 vs. 2,000 TFLOPS from the section above). Verifying $K$ draft tokens reads those same weights once but does $K$ multiplies, pushing arithmetic intensity to ~$2K$ FLOPs/byte. You’re spending compute that was already paid for in memory bandwidth; the ALUs were idle anyway. This is why the target’s forward pass over $K$ tokens costs almost the same wall-clock time as generating 1 token.
If the draft model has a high acceptance rate $\alpha$ and you draft $K$ tokens, the expected number of tokens produced per target call is $(1 - \alpha^{K+1}) / (1 - \alpha)$. This formula from Leviathan et al. assumes i.i.d. acceptance at each position, which doesn’t hold in practice. The formula you’ll see elsewhere, $1 / (1 - \alpha)$, is the limit as $K \to \infty$; for finite draft lengths it overestimates. With $\alpha = 0.8$ and $K = 5$, the exact formula gives $\approx 3.7$ tokens, not $5$. Acceptance probability also decays with position because errors compound through the autoregressive chain. Eagle 2’s dynamic tree width is precisely an adaptation to this per-position variance.
The math (simplified): If $\alpha = 0.8$ and $K = 5$, you generate about 3.7 tokens per target call. If the draft model adds 10% overhead, you get $3.7 / 1.1 \approx 3.4$x speedup. In practice, acceptance rates vary by position and context, so real speedups are typically 2-4x for well-matched draft models. The original speculative decoding paper (and the independent Chen et al. paper arriving at the same technique) along with Eagle benchmarks show 2.5-4x on various tasks.
Critical batch size: The bandwidth→compute argument has a ceiling. At batch size $B$, even without speculation, arithmetic intensity is already ~$2B$ FLOPs/byte: you’re doing $B$ multiplies per weight read. The critical batch size is where standard decode crosses the roofline ridge point and becomes compute-bound. On an H100, the theoretical ridge point is around 600 FLOPs/byte (2,000 TFLOPS / 3.35 TB/s), or $B \approx 300$, but in practice the crossover happens much earlier because KV cache reads, attention computation, and quantization all consume additional bandwidth. Real-world speculation benefits diminish above batch size ~8-16 for vanilla draft models, because at that point the verification pass is no longer “free” bandwidth you were wasting; the GPU is already busy. This is why production serving systems that batch aggressively (high-throughput scenarios) often disable speculation entirely. The speedup is a latency optimization for the memory-bound regime. When you’re saturating compute with large batches, speculation just adds overhead. Each Eagle variant pushes this critical batch size differently, as we’ll see.
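The roofline numbers used so far fit in a few lines. The sketch below takes the post's round figures as assumptions (2,000 TFLOPS peak, 3.35 TB/s, about 2 FLOPs per byte per sequence) and ignores KV cache and attention traffic, so it gives the theoretical ceiling rather than the practical crossover. The function names are illustrative.

```python
H100_PEAK_TFLOPS = 2000.0           # round FP16 figure used in this post
H100_BANDWIDTH_TB_S = 3.35          # HBM bandwidth
FLOPS_PER_BYTE_PER_SEQUENCE = 2.0   # weight-read intensity at batch size 1


def attainable_tflops(batch_size: int) -> float:
    """Roofline bound: min(peak compute, bandwidth * arithmetic intensity)."""
    intensity = FLOPS_PER_BYTE_PER_SEQUENCE * batch_size
    return min(H100_PEAK_TFLOPS, H100_BANDWIDTH_TB_S * intensity)


def ridge_batch_size() -> float:
    """Batch size at which weight reads alone stop being the bottleneck."""
    ridge_intensity = H100_PEAK_TFLOPS / H100_BANDWIDTH_TB_S
    return ridge_intensity / FLOPS_PER_BYTE_PER_SEQUENCE


for b in (1, 8, 64, 512):
    print(b, round(attainable_tflops(b), 1))
print("theoretical ridge batch size:", round(ridge_batch_size()))
```

Verifying `K` draft tokens at batch size `B` behaves roughly like a batch of `B * K` on this model, which is one way to see why speculation runs out of headroom as `B` grows.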
The cost of a good draft model: Where do you get one? You can’t just grab a smaller model off the shelf. A well-matched draft model requires pretraining on large datasets (CommonCrawl-scale) and then distillation from the target model to align distributions. That’s months of work per target model, and every time the target gets updated you’re back to square one. The time-to-market cost is brutal.
I led my team to build one of the earliest custom-trained speculation implementations, covering both Llama 2 and Llama 3. We were one of the fastest out there for about nine months. But every new target checkpoint meant retraining the draft model, re-validating acceptance rates, and re-tuning speculation depth. If the draft diverges from the target, acceptance rates plummet and you’re worse off than vanilla decoding.
Eagle 1: Feature-Level Drafting
Eagle 1 noticed something: you don’t need a separate draft model. The target model already contains a compressed representation of its own behavior in its hidden states.
Key insight: The target’s second-to-top-layer features (the hidden states just before the LM head) predict what the target will output. Why the second-to-top layer? The residual stream accumulates information hierarchically: early layers capture syntax, late layers capture generation intent. The second-to-top-layer features are the most “decision-ready” representation: they’re the direct input to the LM head. Eagle 3 goes further by fusing features from multiple layers (the combine_hidden_states call does this), improving draft accuracy by combining low-, mid-, and high-level representations. Instead of training a full draft model, train a lightweight head that predicts the next hidden state from the current one. Medusa (Cai et al., 2024) takes a related but different path: it adds multiple LM heads that directly predict tokens at future positions (t+1, t+2, …, t+k) from the current hidden state, all in parallel. The tradeoff: Medusa heads don’t condition on their own previous predictions, so they miss sequential correlations, making them less accurate at distant positions. Eagle’s autoregressive drafting captures these correlations at the cost of sequential draft generation.
Architecture:
Target model: embeddings → layers 1...N → LM head → logits
Eagle draft head: [feature[t], embed[t+1]] → FC layer → transformer decoder layer → feature[t+1]_draft → shared LM head → logits_draft
The Eagle head takes two inputs: the target’s second-to-top-layer feature at position $t$ and the token embedding of the sampled token at position $t+1$ (what the paper calls “the token shifted by one time step”). These are concatenated and passed through an FC layer that reduces dimensionality, then a single transformer decoder layer that predicts the next feature. The predicted feature is passed through the target’s shared LM head to produce draft logits. Why the shifted token? Feature-level autoregression faces an inherent uncertainty: the same hidden state could lead to different sampled tokens. By feeding the embedding of the actually-sampled token as input, Eagle resolves this ambiguity. During multi-step drafting, the draft head uses its own previously sampled tokens as the shifted input for subsequent steps.
The payoffs:
No distribution shift. The LM head is shared, so the draft logits are in the same space as the target’s. Acceptance rates are naturally high.
Minimal parameters. The Eagle head ranges from 240M parameters (7B target) to 990M (70B target). Training is cheap, inference overhead is negligible.
Feature reuse. You’re not running a separate model. You’re extending the target’s own computation graph. Memory overhead is minimal.
Eagle 1 uses static tree-structured speculation: a fixed tree of candidate tokens is generated and verified in a single target forward pass using tree attention. On LLaMA2-Chat 70B, this achieves ~3.0-3.5x speedup with acceptance rates around 75%, outperforming vanilla draft model approaches that used 7B draft models.
The LM head sharing isn’t just weight tying for memory savings. It’s the reason Eagle’s acceptance rate is high. The draft head predicts a hidden state, but that hidden state is only useful because it gets mapped to logits through the same LM head the target uses. If the heads differed, you’d reintroduce the distribution shift problem that plagues vanilla draft models. How each framework implements this sharing is covered in the framework sections below.
Training: The Eagle head is trained by freezing the target model and collecting hidden state trajectories over a training dataset (the original paper uses ShareGPT 68K conversations). The FC layer and decoder layer are then trained with a combined loss: feature regression (MSE between predicted and actual hidden states) and token prediction (cross-entropy on next tokens). The bottleneck is the initial collection pass: running the full target model over the training set to harvest hidden states. After that, training the lightweight head takes hours, not days, on a single node. The result is a head that costs a fraction of the target model to train yet captures enough of its behavior for 75%+ acceptance rates.
Critical batch size shift: Eagle’s tiny draft overhead (240M-990M parameters vs. a 7B vanilla draft model) means the critical batch size where speculation stops helping is higher than with vanilla speculation. The draft phase consumes negligible memory bandwidth relative to the target forward pass, so the incremental arithmetic intensity from speculation barely moves you along the roofline. Where a 7B draft model might push you compute-bound at batch size 8, an Eagle head extends the useful range to ~16-32.
Eagle 2: Dynamic Draft Trees
Eagle 1 already uses tree-structured speculation with static trees: a fixed tree shape chosen offline. But a static tree wastes compute — token acceptance rates vary by context, not just position. A confident continuation needs one branch; an uncertain one needs several.
Eagle 2 makes the tree dynamic. No additional training is required; it reuses Eagle 1’s trained draft head and adds a training-free tree construction algorithm on top.
The key finding: Eagle’s draft model is well-calibrated. Its confidence scores (softmax probabilities) closely approximate actual acceptance rates. Tokens with draft confidence below 0.05 have ~4% acceptance rate; tokens above 0.95 have ~98%. This calibration enables context-aware tree construction.
Dynamic tree construction uses two phases:
Expansion: Select the top-$k$ nodes with highest global acceptance probability, i.e. the product of confidence scores along the path from root to node. Why the product? A token is accepted only if all its prefix tokens are also accepted. A deep node with 90% local confidence but a 50%-confidence ancestor has only 45% global acceptance probability. The product captures this cascading dependency. This is why Eagle 2 allocates more width at uncertain positions: a single low-confidence node in the path kills all its descendants. Expand these nodes via the draft model to generate the next layer.
Reranking: After expansion, globally sort all candidate tokens by their global acceptance probability. Select the top-$m$ tokens that maintain tree connectivity. This ensures the compute budget goes to the most promising paths regardless of depth.
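The two phases reduce to a cumulative product and a sort. Below is a toy sketch, assuming the tree is stored as a flat list with parents listed before children. `global_probs`, `rerank`, and the example tree are hypothetical and not taken from the Eagle 2 codebase.

```python
def global_probs(parents, confidences):
    """Global acceptance probability per node: product of confidences on its path.

    parents[i] is the index of node i's parent, or -1 for children of the root.
    Nodes must be ordered so that every parent precedes its children.
    """
    probs = []
    for parent, conf in zip(parents, confidences):
        probs.append(conf if parent < 0 else probs[parent] * conf)
    return probs


def rerank(parents, confidences, m):
    """Keep the m nodes with the highest global acceptance probability.

    A child's score can never exceed its parent's, and ties are broken in
    favor of the earlier (shallower) node, so the kept set stays connected.
    """
    probs = global_probs(parents, confidences)
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    return sorted(order[:m])


# Two first-level candidates (0, 1); node 0 has children 2 and 3; node 1 has child 4.
parents = [-1, -1, 0, 0, 1]
confidences = [0.5, 0.4, 0.9, 0.1, 0.8]
scores = global_probs(parents, confidences)
kept = rerank(parents, confidences, m=3)
print([round(s, 2) for s in scores], kept)
```

Node 2 reproduces the example from the text: 90% local confidence under a 50% parent gives a 45% global score, which still beats the second first-level candidate at 40%.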
Tree attention: The challenge is efficient verification. You can’t just run one forward pass per branch; that defeats the purpose. Eagle 2 uses masked attention so the target model processes the entire tree in one pass, with each node attending only to its ancestors. Standard causal attention is the special case where the tree is a single chain. Tree attention generalizes this: the mask is still lower-triangular in topological ordering, but with branching. Flatten the tree, build the mask from ancestor relationships, and the existing FlashAttention kernel works unmodified. The innovation is in building and packing the mask efficiently, not in the attention math itself.
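Building the mask from ancestor relationships takes only a parent-pointer walk. This is a minimal sketch over a flattened tree, using plain Python lists instead of tensors. `tree_attention_mask` is a hypothetical helper, and a real implementation would also let every node attend to the full committed prefix.

```python
def tree_attention_mask(parents):
    """mask[i][j] is True when draft node i may attend to draft node j.

    parents[i] is the index of node i's parent, or -1 for children of the
    root. A node attends to itself and to every ancestor on its path.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j >= 0:
            mask[i][j] = True
            j = parents[j]
    return mask


# Node 0 has children 1 and 2; node 1 has child 3.
branching = tree_attention_mask([-1, 0, 0, 1])
# A single chain degenerates to the ordinary causal mask.
chain = tree_attention_mask([-1, 0, 1])
for row in branching:
    print("".join("x" if allowed else "." for allowed in row))
```

Sibling nodes 1 and 2 cannot see each other, which is what lets several alternative continuations share one forward pass without contaminating each other's context.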
Results: Eagle 2 achieves 3.05x-4.26x speedup, outperforming Eagle 1 by 20-40% on the same models. The dynamic tree is especially helpful for uncertain generations (creative writing, code completion with multiple valid paths).
The dynamic trees also extend speculation’s useful range at moderate batch sizes. When acceptance is high, Eagle 2 allocates fewer total tree nodes (the confident path doesn’t need width), reducing wasted verification compute. At batch size 8-16 where vanilla speculation is marginal, this tighter compute budget can keep speculation net-positive.
Eagle 3: Training-Time Test
Eagle 1 and 2 train the draft head with a feature regression loss: predict the target model’s hidden states as accurately as possible. This works, but the paper identifies a scaling bottleneck — increasing training data provides diminishing returns. The feature prediction constraint limits the draft model’s expressiveness.
Eagle 3 addresses this with three architectural changes that remove the bottleneck and let performance scale with data.
1. Remove the feature regression loss. Eagle 1/2 trained with $\mathcal{L}_{\text{feat}} + \mathcal{L}_{\text{token}}$ (feature matching plus token prediction). Eagle 3 drops $\mathcal{L}_{\text{feat}}$ entirely and trains only on $\mathcal{L}_{\text{token}}$ (cross-entropy on next tokens). The draft head no longer needs to match hidden states; it just needs to produce tokens the target will accept. The draft head can learn intermediate representations optimized for acceptance, not for mimicking a specific layer. Feature regression is a means to an end, not the end itself. Forcing the draft head to match hidden states exactly is an overly restrictive proxy for what actually matters (token acceptance). Removing it lets the draft head learn features that correlate with acceptance rather than features that minimize MSE against a specific layer’s activations.
2. Multi-layer feature fusion. Eagle 1/2 use only the second-to-top-layer features. Eagle 3 extracts features from three levels of the target model: low (syntax/structure), mid (relationships), and high (semantic/generation intent). These are concatenated into a $3k$-dimensional vector ($k$ = hidden size) and compressed through an FC layer back to $k$ dimensions:
$$g = \mathrm{FC}\big([h_{\text{low}};\, h_{\text{mid}};\, h_{\text{high}}]\big)$$
This provides richer input for multi-step prediction than a single layer optimized for single-token prediction. In the frameworks, this is the combine_hidden_states call you see in the vLLM Eagle 3 code path.
3. Training-time test. During training, Eagle 3 simulates multi-step autoregressive drafting. At step 1, the draft head receives ground-truth target features and produces output $a_1$. At step 2 and beyond, it receives its own previous output $a_1$ as input and produces $a_2$, and so on. Custom causal attention masks control what each position can attend to, preventing simulated positions from seeing ground-truth features that wouldn’t be available at inference time.
Why this matters: Eagle 1/2 train with ground-truth features at every step, but at inference time, the draft head receives its own (potentially imperfect) predictions starting from step 2. This train-test distribution mismatch causes accuracy to degrade at deeper speculation steps. Training-time test closes the mismatch: the draft head sees its own errors during training, not just ground-truth inputs.
Results: Eagle 3 achieves 4.0-6.5x speedup on benchmarks, roughly 1.4x improvement over Eagle 2. The draft head architecture is the same as Eagle 1/2 (FC layer + single transformer decoder layer), just trained differently. It inherits Eagle 2’s dynamic tree construction, with tree depth increased from 6 to 8.
Training details: Eagle 3 still requires offline hidden state collection from the target model, but now from multiple layers (low, mid, high) rather than just the second-to-top layer. Training uses AdamW ($\beta_1 = 0.9$, $\beta_2 = 0.95$), learning rate $5 \times 10^{-5}$ with gradient clipping at 0.5, and takes roughly one day on a single node. The training-time test simulation adds overhead per step (the draft head runs multi-step autoregressive generation during training, with custom masks preventing information leakage), but the payoff is better data efficiency: each training example teaches the head how errors compound across steps, not just single-step accuracy.
Scaling: Eagle 3 trains on ~8x more data than Eagle 1/2 (ShareGPT + UltraChat-200K + OpenThoughts-114k) and shows linear speedup improvement with data, whereas Eagle 1/2 plateau. This is the direct consequence of removing the feature regression loss: without the MSE bottleneck, the head can absorb more data without saturating. The training-time test approach also makes heads more robust to distribution shift, since they’ve been exposed to their own imperfect predictions during training.
Critical batch size: LMSYS/SGLang benchmarks report 1.38x throughput improvement with Eagle 3 even at batch size 64 — well beyond where vanilla speculation breaks even. The higher acceptance rate means fewer wasted verification tokens per batch slot, extending the break-even point furthest of any Eagle variant. All three major frameworks support Eagle 3, and the framework-specific details are covered below.
vLLM: The Proposer Architecture
vLLM implements speculative decoding through a proposer that lives inside the GPU model runner. No cross-worker communication, no serialization overhead. The draft model runs in the same process as the target.
Core Abstraction
The class hierarchy is SpecDecodeBaseProposer → EagleProposer:

    class SpecDecodeBaseProposer:
        """Base class for speculative decoding proposers."""

        def __init__(
            self,
            vllm_config: VllmConfig,
            device: torch.device,
            pass_hidden_states_to_model: bool,
            runner=None,
        ):
            self.vllm_config = vllm_config
            self.speculative_config = vllm_config.speculative_config
            self.num_speculative_tokens = self.speculative_config.num_speculative_tokens


    class EagleProposer(SpecDecodeBaseProposer):
        def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
            super().__init__(vllm_config, device, pass_hidden_states_to_model=True, runner=runner)

Eagle-Specific Implementation
The propose() method on SpecDecodeBaseProposer drives the speculation loop. It takes hidden states from the target model’s last forward pass and generates draft tokens:

    def propose(
        self,
        target_token_ids: torch.Tensor,
        target_positions: torch.Tensor,
        target_hidden_states: torch.Tensor,
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        ...
    ) -> torch.Tensor:
        batch_size = common_attn_metadata.batch_size()
        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(target_hidden_states)

Weight sharing happens in load_model() via direct model attribute access, checking whether the draft model owns its own embedding layer:

    def load_model(self, target_model: nn.Module) -> None:
        self.model = get_model(vllm_config=self.vllm_config, model_config=draft_model_config)
        # share embed_tokens with the target model if needed
        # (simplified; actual code also guards on PP world_size == 1,
        # handles an `embedding` fallback attribute, and checks
        # weight equality before sharing)
        if hasattr(self.model, "has_own_embed_tokens"):
            if not self.model.has_own_embed_tokens:
                self.model.model.embed_tokens = target_language_model.model.embed_tokens

Tree Attention in vLLM
Tree-structured verification uses a dedicated attention backend in vllm/v1/attention/backends/tree_attn.py. The TreeAttentionMetadataBuilder constructs the mask from a configurable tree specification parsed at proposer initialization:

    # Tree structure parsed from speculative_token_tree config
    self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
    tree_depth = len(self.tree_choices[-1])
    # Precompute per-level properties
    num_drafts_per_level = [0] * tree_depth
    for node in self.tree_choices:
        num_drafts_per_level[len(node) - 1] += 1

The tree metadata is built by TreeAttentionMetadataBuilder and passed to the attention kernel, allowing parallel verification of all tree paths in a single forward pass.
KV Cache Strategy
vLLM uses CUDA graph-compatible slot mapping with persistent buffers. Cache slots for draft tokens are managed through a pre-allocated buffer that’s filled via Triton kernels:

    # Persistent buffer for slot mapping across CUDA graph captures
    self._slot_mapping_buffer = torch.zeros(
        self.max_num_tokens, dtype=torch.int64, device=device
    )

    def _get_slot_mapping(self, num_tokens: int, slot_mapping: torch.Tensor | None = None):
        if slot_mapping is not None:
            self._slot_mapping_buffer[:slot_mapping.shape[0]].copy_(slot_mapping)
            if num_tokens > slot_mapping.shape[0]:
                self._slot_mapping_buffer[slot_mapping.shape[0]:num_tokens].fill_(PADDING_SLOT_ID)
        return {name: self._slot_mapping_buffer[:num_tokens] for name in self.attn_layer_names}

When tokens are rejected, the padding slot ID (-1) ensures rejected cache writes are discarded without affecting the allocator state.
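The padding behavior is easier to see without tensors. The sketch below is a plain-Python stand-in, not vLLM code: `SlotMappingBuffer` is a hypothetical class, list slices are copies rather than views, and only the sentinel value -1 is taken from the text above.

```python
PADDING_SLOT_ID = -1  # sentinel: KV writes aimed at this slot are dropped


class SlotMappingBuffer:
    """Fixed-size buffer reused across steps so its address never changes."""

    def __init__(self, max_num_tokens: int):
        self.buffer = [0] * max_num_tokens

    def fill(self, num_tokens: int, slot_mapping: list) -> list:
        """Copy the real slots in, pad the tail up to num_tokens, return that prefix."""
        if num_tokens > len(self.buffer):
            raise ValueError("num_tokens exceeds buffer capacity")
        n = len(slot_mapping)
        self.buffer[:n] = slot_mapping
        if num_tokens > n:
            self.buffer[n:num_tokens] = [PADDING_SLOT_ID] * (num_tokens - n)
        return self.buffer[:num_tokens]


buf = SlotMappingBuffer(max_num_tokens=8)
padded = buf.fill(6, [10, 11, 12])  # 3 real slots, padded to a graph size of 6
print(padded)
```

Because the buffer is allocated once, a captured CUDA graph can keep reading from the same memory while only its contents change between replays.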
SGLang: Integrated Speculation with CUDA Graphs
SGLang takes a different approach. Where vLLM bolts speculation onto the model runner, SGLang bakes it in. The draft and target share a single TpModelWorker. Hidden states never leave GPU memory between draft and verify phases. This coupling enables aggressive CUDA graph optimization at the cost of flexibility.
EAGLEWorker Architecture
The core implementation is in python/sglang/srt/speculative/eagle_worker.py. Key design: EAGLEWorker inherits from TpModelWorker and owns both draft and target execution:

    class EAGLEWorker(TpModelWorker):
        def __init__(
            self,
            server_args: ServerArgs,
            gpu_id: int,
            tp_rank: int,
            dp_rank: Optional[int],
            moe_ep_rank: int,
            nccl_port: int,
            target_worker: TpModelWorker,
        ):
            self.target_worker = target_worker
            self.topk = server_args.speculative_eagle_topk
            self.speculative_num_steps = server_args.speculative_num_steps
            self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
            # Share embeddings and LM head with target
            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
            # (simplified; Eagle 3 path conditionally shares only embed)
            self.draft_model_runner.model.set_embed_and_head(embed, head)

The Speculation Loop
SGLang’s forward_batch_generation method shows the full flow. It returns a GenerationBatchResult with acceptance counts:

    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(batch)
            self.forward_draft_extend(batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu)
            return GenerationBatchResult(logits_output=logits_output, next_token_ids=next_token_ids, ...)
        else:
            spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = self.verify(batch, spec_info)
            if batch.spec_info.verified_id.shape[0] > 0:
                self.forward_draft_extend_after_decode(batch)
            return GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=verify_output.verified_id,
                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu), ...)

Tree Construction
SGLang builds the speculation tree using a CUDA kernel compiled in sgl-kernel and called from eagle_utils.py. The kernel supports multiple mask modes for different backends:

    class TreeMaskMode(IntEnum):
        FULL_MASK = 0
        QLEN_ONLY = 1
        QLEN_ONLY_BITPACKING = 2


    def build_tree_kernel_efficient(
        verified_id: torch.Tensor,
        parent_list: List[torch.Tensor],
        top_scores_index: torch.Tensor,
        draft_tokens: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        topk: int,
        spec_steps: int,
        num_verify_tokens: int,
        tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
        tree_mask_buf: Optional[torch.Tensor] = None,
        position_buf: Optional[torch.Tensor] = None,
    ):
        """Build tree structure for verification.

        Returns:
            tree_mask: Attention mask for tree structure
            position: Position IDs for each tree node
            retrive_index: Indices to gather accepted tokens
        """
        ...

CUDA Graph Optimization
SGLang captures CUDA graphs for the draft forward passes. The key challenge: speculation has variable batch sizes. CUDA graphs capture a fixed sequence of kernel launches and replay them without CPU intervention, eliminating launch overhead that dominates small kernels. The standard trick for variable inputs: capture graphs at a set of “bucket” sizes and pad to the nearest bucket. The wasted compute from padding is far cheaper than the saved launch latency. vLLM’s “piecewise” mode takes a different approach, capturing graphs for sub-operations rather than the full loop, trading some launch overhead for more flexibility. SGLang solves this with EAGLEDraftCudaGraphRunner and a separate EAGLEDraftExtendCudaGraphRunner for the extend phase:

    class EAGLEDraftCudaGraphRunner:
        """Captures CUDA graphs for draft model inference."""

        def __init__(self, eagle_worker: EAGLEWorker):
            # Bind a local alias so the server_args lookups below resolve
            self.model_runner = model_runner = eagle_worker.model_runner
            self.topk = model_runner.server_args.speculative_eagle_topk
            self.speculative_num_steps = model_runner.server_args.speculative_num_steps
            # Batch sizes determined by get_batch_sizes_to_capture()
            self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
            self.num_tokens_per_bs = self.topk
            self.max_bs = max(self.capture_bs)

KV Cache Allocation with Paged Attention
SGLang’s KV cache management for speculation is in the draft preprocessing. It supports both standard and paged allocation depending on page_size:

    def _draft_preprocess_decode(self, batch: ScheduleBatch):
        num_seqs = batch.batch_size()
        if self.page_size == 1:
            # Standard allocation: one slot per token
            out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
                batch.tree_cache,
                num_seqs * self.speculative_num_steps * self.topk,
                backup_state=True,  # For rollback on rejection
            )
        else:
            # Paged allocation: handles partial page duplication for top-k > 1
            out_cache_loc, token_to_kv_pool_state_backup = alloc_paged_token_slots_extend(
                batch.tree_cache, prefix_lens, prefix_lens_cpu,
                seq_lens, seq_lens_cpu, last_loc, extend_num_tokens,
                backup_state=True,
            )

The backup_state=True is critical: it saves the allocator state so rejected branches can be rolled back efficiently. This is the same idea as arena allocation in systems programming. Instead of iterating through each rejected slot and returning it to the free pool (O(k) for k rejected tokens), you restore the allocator’s pointer to its pre-speculation position: O(1) regardless of speculation depth. The tradeoff: no selective rollback, so the accepted prefix must be contiguous. With large page sizes and top-k > 1, SGLang also duplicates the last partial page across branches to avoid cache corruption.
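The arena analogy can be made concrete with a toy bump allocator. The sketch below is purely illustrative: `BumpAllocator` and its methods are hypothetical, and SGLang's real token pool is more involved than a single cursor. The part it demonstrates is the O(1) restore.

```python
class BumpAllocator:
    """Toy slot allocator that hands out consecutive slot ids from one cursor."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cursor = 0

    def backup(self) -> int:
        """Snapshot the allocator state (here, just the cursor)."""
        return self.cursor

    def alloc(self, n: int) -> list:
        if self.cursor + n > self.capacity:
            raise MemoryError("out of KV slots")
        start = self.cursor
        self.cursor += n
        return list(range(start, start + n))

    def restore(self, state: int, keep: int = 0) -> None:
        """Roll back to a snapshot, keeping the first `keep` slots allocated after it."""
        self.cursor = state + keep


pool = BumpAllocator(capacity=64)
prefix_slots = pool.alloc(10)         # committed context
snapshot = pool.backup()
draft_slots = pool.alloc(12)          # speculative branches
pool.restore(snapshot, keep=3)        # 3 tokens accepted, 9 rejected
print(pool.cursor, draft_slots[:3])
```

The rollback costs the same whether 1 or 100 slots were rejected, but it can only keep a contiguous run starting at the snapshot, which is the "no selective rollback" tradeoff described above.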
TensorRT-LLM: Compiled Speculation
TensorRT-LLM goes furthest: compile the entire speculation loop into a single TensorRT engine. No Python in the hot path. No kernel launch overhead. Just a compiled graph that runs flat out on the GPU.
Architecture Overview
Unlike vLLM and SGLang which orchestrate speculation in Python, TensorRT-LLM compiles draft generation, tree construction, and verification into custom TensorRT plugins. The compiled model implementation lives in tensorrt_llm/models/eagle/, while a newer PyTorch execution path lives in tensorrt_llm/_torch/speculative/.
Eagle Model Definition
The compiled Eagle model uses TensorRT plugins for sampling, draft decoding, and input preparation. EagleForCausalLM extends LLaMAForCausalLM and contains multiple EagleNet modules:

    # tensorrt_llm/models/eagle/model.py
    class EagleNet(Module):
        def __init__(self, config, logits_dtype):
            self.drafter = LLaMAModel(config)
            self.lm_head = ColumnLinear(config.hidden_size, vocab_size_padded, ...)

        def forward(self, input_ids, position_ids=None, hidden_states=None,
                    last_token_indices=None, spec_decoding_params=None,
                    kv_cache_params=None, attention_params=None):
            hidden_states, cache = self.drafter(input_ids, position_ids=position_ids,
                                                hidden_states_for_embed=hidden_states, ...)
            return cast(self.lm_head(hidden_states), dtype=self.logits_dtype), hidden_states, cache


    class EagleForCausalLM(LLaMAForCausalLM):
        def __init__(self, config: EagleConfig):
            super().__init__(config)
            self.num_eagle_layers = config.num_eagle_layers
            self.eagle_nets = ModuleList([
                EagleNet(config=eagle_net_config, logits_dtype=config.logits_dtype)
                for _ in range(self.num_eagle_layers)
            ])

Sampling and acceptance are handled by dedicated TensorRT plugins (EagleSampleAndAcceptDraftTokens, EagleDecodeDraftTokens, EaglePrepareDrafterInputs) that run entirely on GPU without Python overhead.
Tree Verification in TensorRT
The verification happens inside the compiled engine using custom attention masks. The eagle_prepare_drafter_inputs_plugin outputs the spec decoding parameters:
# Outputs from eagle_prepare_drafter_inputs_plugin
spec_decoding_generation_lengths,  # [batch_size] - draft lengths per request
spec_decoding_position_offsets,    # [batch_size, max_decoding_tokens] - position offsets
spec_decoding_packed_mask,         # [batch_size, max_decoding_tokens, ceil(max_decoding_tokens / 32)] - uint32 packed masks

The spec_decoding_packed_mask tensor uses uint32 bitpacking to encode which draft tokens can attend to which other tokens during verification. The bitpacking isn’t just a memory optimization. With 64 tree nodes, a dense boolean mask is 4096 bytes (64 × 64 one-byte flags); packed into uint32, it’s 512 bytes (64 rows of two 4-byte words), an 8x reduction. But the real win is that checking attendance becomes a shift and a bitwise AND on a register, rather than a per-element load and branch. SGLang’s QLEN_ONLY_BITPACKING mode exploits the same trick.
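To make the packing concrete, here is a minimal pure-Python sketch of the idea, not TensorRT-LLM's actual kernel. The function names and the bit order (least-significant bit first within each word) are assumptions chosen for illustration; the real plugin's layout may differ.

```python
def tree_mask(parents):
    """Dense mask for a draft tree: each node attends to itself and its ancestors.

    parents[i] is the index of node i's parent, or -1 for the root.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        node = i
        while node != -1:
            mask[i][node] = True
            node = parents[node]
    return mask


def pack_mask(mask):
    """Pack a square boolean mask into rows of 32-bit words.

    Assumed layout: bit (j % 32) of word (j // 32) in row i is mask[i][j].
    """
    n = len(mask)
    words_per_row = (n + 31) // 32  # ceil(n / 32)
    packed = [[0] * words_per_row for _ in range(n)]
    for i, row in enumerate(mask):
        for j, allowed in enumerate(row):
            if allowed:
                packed[i][j // 32] |= 1 << (j % 32)
    return packed


def can_attend(packed, i, j):
    """Check one attendance flag with a shift and a bitwise AND."""
    return ((packed[i][j // 32] >> (j % 32)) & 1) == 1


# Toy tree: node 0 is the root, 1 and 2 are its children, 3 is a child of 1.
parents = [-1, 0, 0, 1]
packed = pack_mask(tree_mask(parents))

# Size comparison for a 64-node tree (1 byte per dense flag, 4 bytes per word).
dense_bytes = 64 * 64
packed_bytes = 64 * ((64 + 31) // 32) * 4
```

In the toy tree, node 3 can attend to nodes 0, 1, and 3 but not to its sibling branch at node 2, which is the property verification relies on.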
EAGLE-2 Dynamic Trees and EAGLE-3
TensorRT-LLM supports EAGLE-1 (fixed trees), EAGLE-2 (dynamic trees), and EAGLE-3. The eagle_draft_decoder_plugin handles dynamic tree construction at runtime with a use_dynamic_tree flag. For EAGLE-3, a separate PyTorch execution path with Eagle3ResourceManager manages hidden states for the draft model, supporting both compiled TensorRT engines and PyTorch-based inference.
Implementation Comparison
| Aspect | vLLM | SGLang | TensorRT-LLM |
|---|---|---|---|
| Orchestration | Integrated EagleProposer | Integrated in TpModelWorker | Compiled TensorRT plugins |
| CUDA Graphs | Piecewise via CudagraphDispatcher | Full support with EAGLEDraftCudaGraphRunner | Implicit via TensorRT |
| Tree Building | TreeAttentionMetadataBuilder | sgl-kernel CUDA kernels | TensorRT plugins |
| KV Cache | Padding slot with buffer | Paged attention with backup state | Static allocation in engine |
| Eagle-3 Support | Yes | Yes | Yes |
The Design Space
These three implementations represent different points in a fundamental tradeoff space: flexibility vs. performance and iteration speed vs. execution speed.
                 Execution Speed
                        ↑
                        │
   TensorRT-LLM ────────┼──── TRT plugins, compiled
                        │     Minutes to change
                        │
         SGLang ────────┼──── CUDA graphs, sgl-kernel
                        │     Seconds to change
                        │
           vLLM ────────┼──── Integrated proposer, tree attn backend
                        │     Config change to switch
                        │
                        └─────────────────────→ Flexibility

vLLM’s integrated proposer runs draft and target models within the same GPU model runner. The EagleProposer shares the runner’s memory space directly, with no cross-worker communication. Tree attention uses a dedicated TreeAttentionMetadataBuilder backend, and CUDA graphs are managed via CudagraphDispatcher in piecewise mode.
The proposer is configured via SpeculativeConfig, so you can swap between Eagle, Medusa, or n-gram speculation at server startup. If you’re evaluating speculation methods and don’t want to commit yet, this is where you start.
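As a rough illustration of what swapping at startup can look like, the sketch below builds the keyword arguments for an engine constructor. The dictionary keys follow my reading of vLLM's speculative decoding documentation and should be checked against the version you run; the model names are examples, and engine_kwargs is a hypothetical helper, not part of vLLM.

```python
# Example checkpoints only; substitute the models you actually serve.
TARGET_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

# Keys are assumptions based on vLLM's documented speculative_config dict.
SPECULATIVE_CONFIGS = {
    "eagle": {
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
        "num_speculative_tokens": 3,
    },
    "ngram": {
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 4,
    },
}


def engine_kwargs(strategy):
    """Return constructor kwargs for a strategy, or plain decoding for 'off'."""
    if strategy == "off":
        return {"model": TARGET_MODEL}
    return {
        "model": TARGET_MODEL,
        "speculative_config": SPECULATIVE_CONFIGS[strategy],
    }


# With vLLM installed, the result would be used along the lines of:
#   from vllm import LLM
#   llm = LLM(**engine_kwargs("eagle"))
```

Keeping the strategies in one table makes the fallback a one-word change at launch, which is the property the paragraph above is pointing at.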
SGLang’s integrated worker has the draft and target models sharing the same TpModelWorker, so hidden states never leave GPU memory between draft and verify phases. Dedicated EAGLEDraftCudaGraphRunner and EAGLEDraftExtendCudaGraphRunner capture CUDA graphs for both the draft decode and extend phases. The tree construction kernel lives in sgl-kernel (compiled CUDA), and SGLang supports multi-layer Eagle configurations for deeper speculation trees.
The tradeoff: the speculation logic is deeply coupled to the model worker, making it harder to experiment with alternative speculation strategies. SGLang occupies the middle ground: most of the performance of a compiled solution, but you can still modify Python code and restart the server when something breaks.
TensorRT-LLM’s compiled approach uses custom TensorRT plugins (EagleSampleAndAcceptDraftTokens, EagleDecodeDraftTokens, EaglePrepareDrafterInputs) that run sampling, acceptance, and input preparation entirely on GPU. This eliminates all Python overhead and enables TensorRT’s layer fusion optimizations.
A separate PyTorch execution path in _torch/speculative/ offers easier experimentation, especially for Eagle-3. But the compiled path still requires minutes to hours for engine building, and debugging requires TensorRT profiling tools rather than Python debuggers. This is where you end up when you’ve settled on a strategy and want every last microsecond.
The memory allocation strategies reflect the same tradeoffs. When draft tokens are rejected, frameworks must roll back KV cache state, return cache slots to the free pool, and continue from the last accepted token. vLLM uses padding slot IDs (-1) so rejected writes are silently discarded: simple and CUDA-graph-friendly. SGLang’s backup_state pattern snapshots allocator state before speculation, enabling O(1) rollback regardless of how many tokens are rejected, with additional support for paged allocation across top-k branches. TensorRT-LLM’s static allocation is fastest but requires knowing the maximum speculation depth at compile time: maximum performance, minimum flexibility.
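The two rollback styles can be sketched with a toy allocator. This is an illustration only: the method names backup_state and restore_state echo the pattern described above, but the class, its stack-style free list, and the helper slot_mapping_with_padding are assumptions, not SGLang's or vLLM's real code.

```python
class ToySlotAllocator:
    """Stack-style free list: allocation advances a cursor, rollback resets it."""

    def __init__(self, num_slots):
        self.free_slots = list(range(num_slots))
        self.cursor = 0  # free_slots[cursor:] are still available

    def alloc(self, n):
        if self.cursor + n > len(self.free_slots):
            raise MemoryError("out of KV cache slots")
        slots = self.free_slots[self.cursor:self.cursor + n]
        self.cursor += n
        return slots

    def backup_state(self):
        # The whole allocator state is one integer, so the snapshot is O(1).
        return self.cursor

    def restore_state(self, state):
        # Frees everything allocated since the snapshot, however many slots.
        self.cursor = state


PADDING_SLOT_ID = -1


def slot_mapping_with_padding(slots, keep_mask):
    """Point rejected tokens at the padding slot so their KV writes are dropped."""
    return [s if keep else PADDING_SLOT_ID for s, keep in zip(slots, keep_mask)]


allocator = ToySlotAllocator(16)
prompt_slots = allocator.alloc(4)        # committed context
snapshot = allocator.backup_state()
draft_slots = allocator.alloc(5)         # speculative branch
allocator.restore_state(snapshot)        # roll back all five draft slots at once
accepted_slots = allocator.alloc(2)      # re-take only what verification accepted
```

Because the toy allocator is a stack, the accepted tokens land in the same slots the draft used, so their already-written KV entries stay valid in this simplified model.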
Interaction with Continuous Batching
Speculation complicates continuous batching. In standard continuous batching, requests join and leave the batch dynamically. But speculation requires coordinating draft and verify phases across all sequences in the batch.
The tension: different sequences may have different acceptance rates. One sequence might accept 5 tokens while another accepts only 1. The frameworks handle this differently:
- vLLM: Uses num_rejected_tokens_gpu tensors to track rejections per sequence, adjusting position offsets and slot mappings for the next draft round.
- SGLang: Uses accept_length_per_req_cpu from the verify output. The draft extend phase after decode runs conditionally only when there are verified tokens remaining.
- TensorRT-LLM: Handles this inside the compiled plugins with dynamic indexing via accepted_lens and accepted_path_ids tensors.
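The bookkeeping all three share can be sketched in a few lines. This is a simplified illustration assuming greedy verification of a single draft chain per request (no tree); the function names are hypothetical and do not correspond to any framework's API.

```python
def accept_lengths(draft_tokens, target_tokens):
    """Per request, count leading draft tokens that match the target's greedy choice."""
    lengths = []
    for draft, target in zip(draft_tokens, target_tokens):
        n = 0
        for d, t in zip(draft, target):
            if d != t:
                break  # everything after the first mismatch is rejected
            n += 1
        lengths.append(n)
    return lengths


def next_round_positions(positions, lengths):
    """Advance each request by its accepted tokens plus the one token verification yields."""
    return [p + n + 1 for p, n in zip(positions, lengths)]


# Two requests in one batch, five draft tokens each.
draft = [[5, 9, 2, 7, 1], [5, 3, 3, 3, 3]]
target = [[5, 9, 2, 7, 1], [5, 8, 3, 3, 3]]
lengths = accept_lengths(draft, target)
positions = next_round_positions([10, 20], lengths)
```

The first request accepts all five tokens and the second only one, so the two sequences advance by different amounts and the next draft round has to start from per-request offsets.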
At high batch sizes (>8), speculation benefits diminish anyway because you’re no longer memory-bound. Most deployments enable speculation only for low-batch-size scenarios.
Everything above assumes a well-matched Eagle head. Fine-tune your target model and that assumption is gone.
Distribution Shift and the Fine-Tuning Problem
Eagle’s core assumption is that the draft head can predict what the target model will generate. This works well when the Eagle head is trained on the same distribution as the target model. But what happens when you fine-tune the target model?
The Problem
Consider a base Llama-70B model with a pre-trained Eagle head achieving 85% acceptance rate. You fine-tune Llama-70B on your domain-specific data. Now the model’s output distribution has shifted. It prefers different tokens, different phrasings, different structures. But the Eagle head was trained on the base model’s distribution.
The result: acceptance rates plummet. Instead of 85%, you might see 40-50%. At that acceptance rate, speculation overhead exceeds the benefit, and you’re better off running vanilla autoregressive decoding.
This is the distribution shift problem: the draft model (or Eagle head) and target model have diverged, and the draft’s predictions no longer match what the target will accept.
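A back-of-the-envelope model helps locate the break-even point. The sketch below uses the expected-speedup formula from Leviathan et al. (2023), which assumes each draft token is accepted independently with probability alpha and ignores the extra cost of verifying several tokens at once. The draft_cost values are illustrative assumptions, not measurements.

```python
def expected_tokens_per_step(alpha, k):
    """Expected tokens produced per verification pass with k draft tokens.

    Assumes each draft token is accepted independently with probability alpha.
    """
    if alpha >= 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def estimated_speedup(alpha, k, draft_cost):
    """Speedup over plain decoding.

    draft_cost is the time of one draft step relative to one target step.
    """
    return expected_tokens_per_step(alpha, k) / (k * draft_cost + 1.0)


# Compare a well-matched head with a shifted one at two relative draft costs.
for cost in (0.05, 0.3):
    matched = estimated_speedup(0.85, 5, cost)
    shifted = estimated_speedup(0.45, 5, cost)
    print(f"draft_cost={cost}: matched={matched:.2f}x shifted={shifted:.2f}x")
```

Under this idealized model, where the break-even falls depends heavily on how cheap a draft step is, and real deployments also pay verification and scheduling overhead that the formula leaves out. Treat it as a way to reason about the trend, not as a predictor.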
Why Sharing the LM Head Isn’t Enough
Eagle’s weight sharing helps but doesn’t solve the problem. Yes, the Eagle head uses the target’s LM head to produce logits, so the final token probabilities are computed identically. But the Eagle head’s job is to predict the hidden state that will produce those logits. After fine-tuning, the target model’s hidden state trajectories have changed: different attention patterns, different residual stream content, different feature representations.
The Eagle head was trained to predict: “given hidden state H at position t, the next hidden state will be H'”. After fine-tuning, the target model now produces H'' instead of H'. The Eagle head’s prediction of H' is wrong, the logits computed from H' don’t match what the target would produce from H'', and the token gets rejected.
Knowledge Distillation as the Solution
The fix is to re-train the Eagle head on the fine-tuned model’s hidden states. This is knowledge distillation: the fine-tuned target model is the teacher, the Eagle head is the student learning to predict the teacher’s hidden state transitions.
The training process:
# Pseudocode for Eagle head distillation
import torch
import torch.nn.functional as F

for batch in training_data:
    input_ids = batch.input_ids  # [batch_size, seq_len]

    # Get hidden states and token embeddings from the frozen, fine-tuned target model
    with torch.no_grad():
        outputs = target_model(input_ids, output_hidden_states=True)
        # hidden_states: [batch_size, seq_len, hidden_dim]
        target_hidden = outputs.hidden_states[-1]  # Last layer
        # Embeddings of the next tokens: [batch_size, seq_len-1, hidden_dim]
        token_embeds = target_model.get_input_embeddings()(input_ids[:, 1:])

    # Eagle head predicts next hidden state from current + token embedding
    # Input: hidden[t] concat embed[t+1] -> Output: predicted hidden[t+1]
    eagle_input = torch.cat([target_hidden[:, :-1, :], token_embeds], dim=-1)
    predicted_hidden = eagle_head(eagle_input)  # [batch_size, seq_len-1, hidden_dim]

    # Loss: match the target's actual hidden states
    actual_hidden = target_hidden[:, 1:, :]  # [batch_size, seq_len-1, hidden_dim]
    loss = F.mse_loss(predicted_hidden, actual_hidden)

    # Update only the Eagle head; optimizer is assumed to wrap eagle_head.parameters()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

You don’t need to train the full target model, just the lightweight Eagle head. For a 70B target model, the Eagle head might be 500M-1B parameters. Distillation takes hours, not days.
Vanilla Speculative Decoding as Distillation
You can also use vanilla speculative decoding itself as a distillation mechanism.
The setup: you have a fine-tuned target model and a base Eagle head with poor acceptance rates. Instead of explicitly training the Eagle head on hidden states, run speculative decoding and use the acceptance/rejection signal as training data.
# Online distillation via speculative decoding (pseudocode)
for prompt in training_prompts:
    # Run speculation with the mismatched Eagle head
    with torch.no_grad():
        context_hidden = target_model.get_hidden_states(prompt)
    # Assumes draft_hidden[i] is the predicted state that produced draft_tokens[i]
    draft_tokens, draft_hidden = eagle_head.speculate(context_hidden, k=5)

    # Verify against target model; the same pass yields the target's hidden states
    with torch.no_grad():
        target_logits = target_model(prompt + draft_tokens)
        target_hidden = target_model.get_hidden_states(prompt + draft_tokens)
    accepted_mask = verify_tokens(draft_tokens, target_logits)  # [k] bool

    # Align target states with draft positions: the state at position
    # len(prompt) - 1 + i is the one that predicts draft token i
    start = len(prompt) - 1
    target_hidden_at_draft = target_hidden[start:start + len(draft_tokens)]

    # Train only on rejected positions (where Eagle was wrong)
    rejected = ~accepted_mask
    if rejected.any():
        loss = F.mse_loss(draft_hidden[rejected], target_hidden_at_draft[rejected])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

The advantage: you’re training on exactly the distribution where the Eagle head fails. Standard distillation trains on random samples; this focuses compute on the hard cases.
The downside: you need to run the full target model during training, which is expensive. But you’re already paying that cost if you’re serving the model, so you could distill online during production traffic (with appropriate safeguards).

Online distillation during serving is tempting but dangerous. A gradient update that transiently spikes the rejection rate means you eat the overhead of both speculation and verification for tokens that don’t get accepted. Production traffic may also not be representative, causing the Eagle head to overfit to recent query patterns and degrade on tail queries. Shadow the training to a replica, A/B test before swapping, and monitor acceptance rate with circuit breakers.
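One way to read "circuit breaker" here is a sliding-window monitor that turns speculation off when the recent acceptance rate falls below a floor. The class below is a hypothetical sketch; its name, thresholds, and window size are assumptions to tune against your own traffic, and none of it is taken from a serving framework.

```python
from collections import deque


class AcceptanceCircuitBreaker:
    """Disable speculation when the recent acceptance rate drops too low."""

    def __init__(self, window=200, min_rate=0.6, min_samples=50):
        # Each event is (accepted_tokens, proposed_tokens) for one verify step.
        self.events = deque(maxlen=window)
        self.min_rate = min_rate
        self.min_samples = min_samples

    def record(self, accepted, proposed):
        self.events.append((accepted, proposed))

    def acceptance_rate(self):
        proposed = sum(p for _, p in self.events)
        if proposed == 0:
            return None
        return sum(a for a, _ in self.events) / proposed

    def speculation_enabled(self):
        # Stay enabled until there is enough evidence to judge.
        if len(self.events) < self.min_samples:
            return True
        rate = self.acceptance_rate()
        return rate is None or rate >= self.min_rate
```

A serving loop would call record after every verification step and consult speculation_enabled before drafting. Using a bounded deque means old traffic ages out, so the breaker can close again once a better head is swapped in.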
LoRA Adapters and Eagle Heads
LoRA (Low-Rank Adaptation) creates a tricky case for Eagle speculation. With LoRA, you’re not fine-tuning the full model. You’re adding small adapter matrices that modify the model’s behavior while keeping base weights frozen.
The problem with LoRA + Eagle:
The Eagle head was trained on the base model’s hidden states. LoRA modifies those hidden states. Even though the modification is low-rank, it can still shift the distribution enough to hurt acceptance rates.
The severity depends on LoRA rank and alpha:
- Low-rank LoRA (r=8, alpha=16): Small ΔH, Eagle head often still works okay
- High-rank LoRA (r=64, alpha=128): Large ΔH, acceptance rates degrade significantly
- Task-specific LoRA (code, math, specific domain): Even low-rank can cause distribution shift if the task is very different from base model training
Solutions:
- Train LoRA adapters for the Eagle head too. If you’re adding LoRA to the target, add corresponding LoRA to the Eagle head. The Eagle LoRA learns to predict how the target LoRA modifies hidden states.

# Eagle head with LoRA adaptation
import torch.nn as nn

class EagleHeadWithLoRA(nn.Module):
    def __init__(self, base_eagle_head, hidden_size, lora_rank=8, lora_alpha=16):
        super().__init__()
        self.base = base_eagle_head
        self.lora_A = nn.Linear(hidden_size, lora_rank, bias=False)
        self.lora_B = nn.Linear(lora_rank, hidden_size, bias=False)
        self.scaling = lora_alpha / lora_rank
        # Initialize LoRA weights (A: Kaiming uniform, B: zero so the delta starts at 0)
        nn.init.kaiming_uniform_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, hidden_size]
        base_pred = self.base(hidden_states)
        lora_delta = self.lora_B(self.lora_A(hidden_states)) * self.scaling
        return base_pred + lora_delta

- Use the same LoRA weights. If the target’s LoRA primarily affects the LM head layers, you might be able to share those exact LoRA weights with the Eagle head’s LM head (which is already shared). This doesn’t help if LoRA is applied to attention/MLP layers that affect hidden state trajectories.
- Accept the degradation for light LoRA. If your LoRA is low-rank and your base acceptance rate is high (>85%), you’ll likely hold 70%+ after LoRA, which still justifies the overhead.
Framework support for LoRA + speculation:
vLLM supports LoRA with speculative decoding, including Eagle-3 with LoRA adapters on models like Qwen3. The acceptance rate concern remains: if your LoRA significantly shifts the output distribution, the Eagle head may produce more rejected tokens. SGLang and TensorRT-LLM have more limited LoRA + speculation support.
Eagle-3’s Training-Time Test Approach
Eagle-3 comes at this differently. Instead of training the Eagle head with a feature regression loss that forces it to match hidden states exactly, Eagle-3 drops that loss entirely and trains only on token prediction. The draft head learns whatever intermediate representations correlate with acceptance, rather than representations that minimize MSE against a specific layer’s activations.
Eagle-3 also trains with its “training-time test” technique: during training, the draft head receives its own predictions as inputs at step 2+, simulating the distribution mismatch that occurs at inference time. This makes Eagle-3 heads more robust to input perturbations generally.
In practice, Eagle-3 heads transfer better across fine-tuned models because they’ve learned features that correlate with acceptance rather than exact hidden state matching. But they still degrade on heavily fine-tuned models, just less severely than Eagle-1/2.
Falling Back to Vanilla Draft Models
When Eagle heads don’t transfer well to your fine-tuned model and you don’t have resources to retrain them, vanilla draft model speculation is still an option.
The approach: use a smaller model from the same family as your draft model. For a fine-tuned Llama-70B, use Llama-7B fine-tuned on the same data. The draft model sees the same distribution shift as the target, so they stay aligned.
Advantages over broken Eagle heads:
- No Eagle head training required
- Draft model naturally adapts if you fine-tune both together
- Works with LoRA if you apply the same LoRA to both models
Disadvantages:
- Draft model is much larger than an Eagle head (7B vs 500M parameters)
- Requires running two separate models, increasing memory and orchestration complexity
- Draft model inference is slower than Eagle head inference
Hybrid approach: Use Eagle for base model traffic, fall back to vanilla draft model for heavily fine-tuned or LoRA-adapted requests. vLLM’s SpeculativeConfig lets you swap speculation strategies at startup.
Practical Implications
Retraining costs in practice: Every target model fine-tune requires Eagle head retraining — the head won’t transfer across model families, and even within the same family, a heavy fine-tune will degrade acceptance rates below the break-even point. Budget hours (not days) per head retrain, dominated by the hidden state collection pass over your training set. Eagle-3’s training-time test approach produces heads that are more robust to distribution shift but still degrade on heavy fine-tunes. Build head retraining into your fine-tuning pipeline as a standard step, not an afterthought.
The frameworks don’t solve distribution shift for you. They assume you’re providing a well-matched Eagle head. SGLang’s documentation notes that Eagle heads are available on HuggingFace for popular base models, but fine-tuned variants require training your own.
Speculative decoding is one of those ideas that sounds like it shouldn’t work. You’re betting that a smaller, dumber model can predict what a larger, smarter model will say.
And yet, when the draft model is well-matched to the target, it works.
The frameworks have made Eagle practical. The research contributed the core insight (predict hidden states, not tokens), and the engineering contributed the infrastructure (tree attention, KV cache management, CUDA graphs). Neither is sufficient alone. Understanding both is how you deploy this, and how you debug it when acceptance rates drop for no obvious reason.

When acceptance rates degrade silently, instrument in order: (1) per-position acceptance rate: sharp drops at position 2+ mean the Eagle head is losing context; (2) per-category acceptance rate: code tokens have higher acceptance than natural language, so a traffic mix shift can look like degradation; (3) hidden state cosine similarity between Eagle’s prediction and the target’s actual state: below 0.9, it is time to retrain. vLLM exposes num_rejected_tokens_gpu; SGLang exposes accept_length_per_req_cpu. Build dashboards from these before you need them.
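Two of those diagnostics are cheap to compute once accept lengths and hidden states have been logged. The helpers below are a minimal stdlib sketch with hypothetical names; they assume you have already exported per-step accept lengths and paired hidden-state vectors from your serving stack.

```python
import math


def per_position_acceptance(accept_lengths, k):
    """Fraction of verify steps in which draft position i (0-based) was accepted."""
    steps = len(accept_lengths)
    if steps == 0:
        return [0.0] * k
    return [sum(1 for n in accept_lengths if n > i) / steps for i in range(k)]


def cosine_similarity(a, b):
    """Cosine similarity between a predicted and an actual hidden-state vector."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


# Four verify steps with k=3 draft tokens each.
rates = per_position_acceptance([3, 1, 0, 2], k=3)
```

A per-position curve that falls off much faster than it did at deployment time is the signal to look at; the absolute numbers depend on the model and the traffic.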
The 4x speedup holds up. So do the operational costs. Choose your framework based on how often you ship, measure acceptance rates before you deploy, and have a fallback plan for when Eagle heads stop transferring. I’ve been running Eagle-3 on vLLM for inference workloads and the results justify the operational complexity. But I also keep a vanilla draft model config one flag away.
References
Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. International Conference on Machine Learning. arXiv:2211.17192
Chen, C., Borgeaud, S., Irving, G., Lespiau, J.-B., Sifre, L., & Jumper, J. (2023). Accelerating Large Language Model Decoding with Speculative Sampling. arXiv:2302.01318
Li, Y., Wei, F., Zhang, C., & Zhang, H. (2024). EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty. International Conference on Machine Learning. arXiv:2401.15077
Li, Y., Wei, F., Zhang, C., & Zhang, H. (2024). EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees. Conference on Empirical Methods in Natural Language Processing. arXiv:2406.16858
Li, Y., Wei, F., Zhang, C., & Zhang, H. (2025). EAGLE-3: Scaling up Inference Acceleration of Large Language Models via Training-Time Test. arXiv:2503.01840
Kwon, W., et al. (2023). Efficient Memory Management for Large Language Model Serving with PagedAttention. ACM SIGOPS Symposium on Operating Systems Principles. arXiv:2309.06180
vLLM Project. (2025). Speculative Decoding Implementation. github.com/vllm-project/vllm/tree/main/vllm/v1/spec_decode
SGLang Project. (2024). EAGLE Worker Implementation. github.com/sgl-project/sglang/blob/main/python/sglang/srt/speculative/eagle_worker.py
NVIDIA. (2024). TensorRT-LLM EAGLE Implementation. github.com/NVIDIA/TensorRT-LLM/tree/main/tensorrt_llm/models/eagle
vLLM Documentation. Speculative Decoding. docs.vllm.ai/en/latest/features/spec_decode.html
NVIDIA TensorRT-LLM Documentation. Speculative Sampling. nvidia.github.io/TensorRT-LLM/advanced/speculative-decoding.html
Hu, E. J., et al. (2021). LoRA: Low-Rank Adaptation of Large Language Models. arXiv:2106.09685