Update 2026-06-10: H100 single-user decode is now 63.0 tok/s (was 53 forward-only / 44.9 full-generate), 83.9 tok/s with speculative decoding on real text at ctx 1200+, and the production-context FP8 path went 14 → 61.4 tok/s (4.4x) via a split-KV attention kernel rewrite + cuBLASLt small-M GEMM routing. GPU tables below that are not otherwise marked were measured in April 2026. Full session record: v3/H100_MAXPERF_PLAN.md. The GPU PPL figure (14.75) is flagged for re-verification; see the perplexity note.
3 TPU models + 2 GPU baselines. TPU v6e-4 $5.20/hr, H100 SXM $1.92/hr. All at max-ctx 2048.
Full batch sweep -- all configurations
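| Batch | E4B TPU | 26B-A4B TPU | 31B TPU | 31B GPU (rvLLM) | 31B GPU (vLLM) |
|---|---|---|---|---|---|
| 1 | 78 | 53 | 44 | 63.0 | 69 |
| 8 | 542 | 390 | 318 | 434 | 515 |
| 32 | - | - | - | 1,743 | 1,748 |
| 64 | 3,661 | 2,662 | 2,112 | 3,265 | 3,130 |
| 128 | 6,298 | 4,915 | 3,853 | 5,802 | 4,689 |
| 256 | 10,214 | 8,192 | 6,246 | 7,808 | 7,077 |
| 512 | 13,773 | 12,390 | 8,550 | 8,786 | 8,243 |
| 768 | 15,514 | 14,899 | 9,600 | - | - |
| 1024 | 16,794 | - | - | - | - |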
All tok/s are total throughput (batch x per-element rate). "-" = not tested or OOM at that batch size.
Summary across all configurations
| Config | PPL | B=1 tok/s | Cached TTFT | Peak tok/s | $/hr | Peak tok/s/$ |
|---|---|---|---|---|---|---|
| E4B on TPU | 5.87 | 78.3 | 25.9 ms | 16,794 | $5.20 | 3,230 |
| 26B-A4B on TPU | 90.21 | 52.9 | 35.3 ms | 14,899 | $5.20 | 2,865 |
| 31B on TPU | 24.76 | 44.2 | 73.3 ms | 9,600 | $5.20 | 1,846 |
| 31B rvLLM GPU | 14.75* | 63.0 | 63 ms | 8,786 | $1.92 | 4,576 |
| 31B vLLM GPU | - | 69 | - | 3,848 | $1.92 | 2,004 |
PPL measured on 86-token passage. 26B-A4B PPL is high for both rvLLM (90.21) and HF reference (85.42) -- instruct MoE on raw prose. GPU rvLLM PPL 14.75* is flagged for re-verification: FP8 scoring 25% better than the HF BF16 reference (19.62) is implausible as a quantization effect and likely reflects an eval-config difference between harnesses (softcap path). Treat as unverified until the eval is unified.
Three models, one codebase. ~500 lines of JAX, zero custom kernels. XLA compiles the entire forward pass to a single fused while loop. E4B adds per-layer input injection + KV sharing. 26B-A4B adds MoE (128 experts, top-8 routing). 31B adds 128K context via dual-path architecture. All run on the same v6e-4 ($5.20/hr).
Single-user performance (B=1, max-ctx 2048)
| Metric | E4B (4B) | 26B-A4B (MoE) | 31B |
|---|---|---|---|
| Decode throughput | 78.3 tok/s | 52.9 tok/s | 44.2 tok/s |
| Decode latency | 12.8 ms | 18.9 ms | 22.6 ms |
| TTFT (first run, incl. compile) | 2,470 ms | 2,083 ms | 1,498 ms |
| TTFT (cached compile) | 25.9 ms | 35.3 ms | 73.3 ms |
| Perplexity | 5.87 | 90.21 | 24.76 |
| HF bf16 reference PPL | 3.28 | 85.42 | ~19 |
PPL is measured on an 86-token passage (John 1:1-14) with BOS prepend. 26B-A4B PPL is high for both rvLLM and HF -- instruct MoE models score poorly on raw prose. First-run TTFT includes XLA compilation (~1.5-2.5s). Cached TTFT is the steady-state latency of a single-token forward pass (1 prompt token).
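For readers comparing the PPL figures across harnesses, it may help to pin down the definition being assumed. The sketch below uses the standard formula (exponential of the mean negative log-likelihood over scored tokens). The function name and the sample input are illustrative only; the actual eval harness is not shown in this document, and details such as whether the BOS position is scored are assumptions that can shift the result.

```python
import math

def perplexity(token_logprobs):
    """Standard perplexity: exp(-mean natural-log probability per scored token).

    token_logprobs: one log p(token | prefix) per scored token.
    Which positions count as "scored" (e.g. skipping BOS) is harness-specific.
    """
    if not token_logprobs:
        raise ValueError("need at least one scored token")
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

# Hypothetical input: a model that assigns probability 1/4 to every token
# of an 86-token passage has a perplexity of about 4.
uniform = [math.log(0.25)] * 86
print(round(perplexity(uniform), 2))
```

With only 86 scored tokens, a handful of low-probability tokens moves the mean noticeably, which is one reason small differences between harnesses are hard to interpret.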
Batch scaling (max-ctx 2048)
Throughput vs batch size - all configurations
E4B, 26B-A4B, and 31B on TPU v6e-4 (int8, 2048 ctx) vs vLLM 31B on H100 SXM (FP8, 2048 ctx). All measured.
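| Batch | E4B TPU | 26B-A4B TPU | 31B TPU | vLLM H100 |
|---|---|---|---|---|
| 1 | 78 | 53 | 44 | 66.9 |
| 8 | 542 | 390 | 318 | 511.7 |
| 64 | 3,661 | 2,662 | 2,112 | 2,794 |
| 128 | 6,298 | 4,915 | 3,853 | 3,848 |
| 256 | 10,214 | 8,192 | 6,246 | 3,709 |
| 512 | 13,773 | 12,390 | 8,550 | 3,788 |
| 768 | 15,514 | 14,899 | 9,600 | 3,671 |
| 1024 | 16,794 | - | - | - |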
All tok/s are total throughput (batch x per-element rate). vLLM peaks at B=128 then saturates. All three TPU models keep scaling to B=768. E4B peaks at B=1024 (16,794 tok/s). 26B-A4B MoE matches E4B closely at high batch despite 6.5x more total parameters (sparse activation).
Cost efficiency
| Config | Peak tok/s | Batch | Cost/hr | tok/s/$ |
|---|---|---|---|---|
| E4B on TPU v6e-4 | 16,794 | 1024 | $5.20 | 3,230 |
| 26B-A4B on TPU v6e-4 | 14,899 | 768 | $5.20 | 2,865 |
| 31B on TPU v6e-4 | 9,600 | 768 | $5.20 | 1,846 |
| 31B vLLM on H100 (measured) | 3,848 | 128 | $1.92 | 2,004 |
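The tok/s/$ column is peak throughput divided by the hourly price. A minimal sketch of that arithmetic follows, using the table's own figures. The dollars-per-million-tokens helper is a derived convenience that is not reported in this document, and it assumes the accelerator is held at peak batch for the whole hour.

```python
def tok_per_s_per_dollar(peak_tok_s, dollars_per_hour):
    """Peak tokens/second bought per dollar of hourly rental."""
    return peak_tok_s / dollars_per_hour

def dollars_per_million_tokens(peak_tok_s, dollars_per_hour):
    """Derived metric (assumption: 100% utilization at peak batch)."""
    tokens_per_hour = peak_tok_s * 3600
    return dollars_per_hour / tokens_per_hour * 1e6

# (peak tok/s, $/hr) pairs copied from the cost efficiency table.
configs = {
    "E4B on TPU v6e-4": (16794, 5.20),
    "26B-A4B on TPU v6e-4": (14899, 5.20),
    "31B on TPU v6e-4": (9600, 5.20),
    "31B vLLM on H100": (3848, 1.92),
}
for name, (tok_s, price) in configs.items():
    print(f"{name}: {tok_per_s_per_dollar(tok_s, price):,.0f} tok/s/$, "
          f"${dollars_per_million_tokens(tok_s, price):.3f} per 1M tokens")
```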
TPU 31B vs vLLM GPU (2048 ctx, measured)
Apples-to-apples at max-ctx 2048. vLLM on H100 SXM 80GB with RedHatAI/gemma-4-31B-it-FP8-Dynamic ($1.92/hr on vast.ai). GPU wins at low batch due to lower per-step latency. TPU overtakes at B=128 and keeps scaling.
31B: rvLLM TPU vs rvLLM GPU vs vLLM GPU
All at max-ctx 2048. TPU: v6e-4, int8. GPU: H100 SXM, FP8. rvLLM GPU = raw CUDA graph decode, vLLM = server with scheduler.
| Batch | vLLM GPU tok/s | vLLM ms/step | rvLLM TPU tok/s | TPU ms/step | Winner |
|---|---|---|---|---|---|
| 1 | 66.9 | 14.95 | 44 | 22.6 | GPU +52% |
| 8 | 511.7 | 15.63 | 318 | 25.1 | GPU +61% |
| 64 | 2,794 | 22.90 | 2,112 | 30.3 | GPU +32% |
| 128 | 3,848 | 33.26 | 3,853 | 33.2 | Tie |
| 256 | 3,709 | 69.03 | 6,246 | 41.0 | TPU +68% |
| 512 | 3,788 | 135.18 | 8,550 | 59.7 | TPU +126% |
| 768 | 3,671 | 209.18 | 9,600 | 80.3 | TPU +161% |
Crossover at B=128. vLLM's mature CUDA graph pipeline dominates at low batch. TPU's advantage at high batch: XLA fuses the entire decode loop on-chip, and the v6e-4's aggregate bandwidth (~3.3 TB/s across 4 chips) keeps scaling where the single H100 saturates.
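The tok/s and ms/step columns are two views of the same measurement: total throughput is batch size times 1000 divided by the step time in milliseconds. The sketch below shows that conversion and the percentage used in the Winner column. The helper names are illustrative, and small mismatches against the table (for example at B=768 on TPU) are expected because the published ms/step values are rounded.

```python
def total_tok_s(batch, ms_per_step):
    """Total decode throughput: each step emits one token per sequence."""
    return batch * 1000.0 / ms_per_step

def lead_pct(winner_tok_s, loser_tok_s):
    """How far ahead the faster configuration is, in percent."""
    return (winner_tok_s / loser_tok_s - 1.0) * 100.0

# Rows taken from the table above.
print(round(total_tok_s(1, 14.95), 1))     # vLLM at B=1
print(round(total_tok_s(768, 209.18)))     # vLLM at B=768
print(round(lead_pct(6246, 3709)))         # TPU lead at B=256
```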
31B context scaling (B=1)
| Context | ms/step | tok/s | Architecture | KV type |
|---|---|---|---|---|
| 512 | 12.79 | 78.2 | Single-scan, 60-layer scan + cond | bf16 |
| 2,048 | 22.6 | 44.2 | Single-scan | bf16 |
| 32K | ~66 | ~15 | Single-scan | bf16 |
| 64K | ~91 | ~11 | Split-cache, 10 groups x 6 | int8 |
| 128K | 40.56 | 24.7 | Split-cache + blockwise global | int8 |
Dual-path auto-switching: contexts <= 32K use single-scan with bf16 KV (fastest); contexts > 32K use split-cache with int8 KV (up to 128K). 128K is 31B only -- E4B not yet tested at long context.
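The switching rule can be sketched as a simple dispatch on the configured maximum context. This is an illustration of the rule as stated, not the project's code: the names `DecodePath` and `select_decode_path` are hypothetical, and treating "32K" as 32,768 tokens is an assumption.

```python
from dataclasses import dataclass

# Assumption: the "32K" boundary means 32 * 1024 tokens.
SINGLE_SCAN_MAX_CTX = 32 * 1024

@dataclass(frozen=True)
class DecodePath:
    name: str       # which decode architecture to compile
    kv_dtype: str   # KV cache storage type

def select_decode_path(max_ctx: int) -> DecodePath:
    """Pick the decode path from the maximum context length."""
    if max_ctx <= 0:
        raise ValueError("max_ctx must be positive")
    if max_ctx <= SINGLE_SCAN_MAX_CTX:
        return DecodePath("single-scan", "bf16")
    return DecodePath("split-cache", "int8")

for ctx in (512, 2048, 32 * 1024, 64 * 1024, 128 * 1024):
    print(ctx, select_decode_path(ctx))
```

Choosing the path from the maximum context up front, rather than per step, keeps the compiled graph shape fixed for the whole session.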
31B perplexity progression
| Version | PPL | Notes |
|---|---|---|
| bf16 KV, single scan | 25.51 | Original baseline |
| int8 KV, single scan | 22.80 | Per-head scales |
| int8 KV, split-cache | 19.24 | Circular buffer + blockwise |
| int8 weights, bf16 KV (2048 ctx) | 24.76 | Current single-scan path |
Architecture comparison
| Property | E4B (4B) | 26B-A4B (MoE) | 31B |
|---|---|---|---|
| Total / active params | ~4B / 4B | 26B / ~4B | 31B / 31B |
| Layers | 42 | 30 | 60 |
| Hidden size | 2,560 | 2,816 | 5,376 |
| Q / KV heads (sliding) | 8 / 2 | 16 / 8 | 32 / 16 |
| Q / KV heads (global) | 8 / 2 | 16 / 2 (V=K) | 32 / 4 (V=K) |
| Head dim (sliding / global) | 256 / 512 | 256 / 512 | 256 / 512 |
| Sliding window | 512 | 1,024 | 1,024 |
| MoE | none | 128 experts, top-8 | none |
| KV-shared layers | 18 (of 42) | 0 | 0 |
| Per-layer input | 256-d gated (5.6 GB) | none | none |

All models: Activation / Softcap is GELU(tanh) / 30*tanh(x/30); Embeddings are tied (lm_head = embed_tokens.T).
XProf per-step breakdown (31B, B=1, 512 ctx)
| Component | Time | % |
|---|---|---|
| 60-layer scan (while loop on MXU) | 10.6 ms | 86% |
| incl. jax.lax.cond dispatch | 1.8 ms | 15% of scan |
| incl. ICI all-reduce (O + down proj) | 0.6 ms | 6% of scan |
| incl. KV cache dynamic updates | 1.3 ms | 12% of scan |
| Host + dispatch overhead | 1.7 ms | 14% |
| Total step | 12.3 ms | 100% |
| Theoretical BW limit (30 GB / 3.3 TB/s) | 9.1 ms | |
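To read the breakdown against the bandwidth limit, the short sketch below recomputes the top-level shares and the weight-streaming floor from the figures in the table. The dictionary keys are illustrative labels; the 30 GB and 3.3 TB/s inputs come from the last row.

```python
def step_shares(components_ms):
    """Return (total step time, share of each top-level component)."""
    total = sum(components_ms.values())
    return total, {name: ms / total for name, ms in components_ms.items()}

# Top-level components only; the "incl." rows are already inside the scan.
step = {"scan": 10.6, "host_dispatch": 1.7}
total_ms, shares = step_shares(step)

# GB divided by TB/s gives milliseconds directly (1 TB/s = 1 GB/ms).
floor_ms = 30.0 / 3.3

print(f"total {total_ms:.1f} ms, scan {shares['scan']:.0%}, "
      f"host {shares['host_dispatch']:.0%}")
print(f"bandwidth floor {floor_ms:.1f} ms, "
      f"floor / measured step = {floor_ms / total_ms:.0%}")
```

On these numbers the measured step sits at roughly three quarters of the streaming floor's efficiency, so the remaining headroom is bounded by the non-streaming work listed in the table.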
EAGLE-3 speculative decoding (31B, experimental)
Single-user latency optimization. 450M-param draft head proposes K=5 tokens per cycle; the full 31B verifies K+1=6 in one forward pass. Lossless for greedy decode. Projected ~145 tok/s (1.8x) at tau=3.5 with 50K+ training examples. Current: 2K examples, loss 7.1, pipeline validated end-to-end.
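The "lossless for greedy decode" property comes from the acceptance rule: a drafted token is kept only if it equals the token the full model would have chosen at that position. Below is a minimal sketch of that rule in plain Python. It is not the project's implementation; `accept_greedy` is a hypothetical name, and the inputs are assumed to be the K draft tokens plus the K+1 argmax tokens from the single verification pass.

```python
def accept_greedy(draft_tokens, target_tokens):
    """Greedy speculative acceptance.

    draft_tokens:  K tokens proposed by the draft head.
    target_tokens: K+1 argmax tokens from the verifier, where
                   target_tokens[i] is its choice after the prefix
                   plus draft_tokens[:i].
    Returns the tokens to emit this cycle (between 1 and K+1).
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("verifier must score K+1 positions")
    accepted = []
    for i, drafted in enumerate(draft_tokens):
        if drafted != target_tokens[i]:
            break  # first disagreement ends the accepted prefix
        accepted.append(drafted)
    # Always emit the verifier's own token at the stop position:
    # a correction on mismatch, or a bonus token if all K matched.
    accepted.append(target_tokens[len(accepted)])
    return accepted

print(accept_greedy([1, 2, 3, 4, 5], [1, 2, 9, 7, 7, 7]))  # [1, 2, 9]
print(accept_greedy([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]))  # [1, 2, 3, 4, 5, 6]
```

Because every emitted token is one the verifier chose, the output sequence matches plain greedy decoding; the draft only changes how many tokens are produced per verification pass.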
FP8 weights + F16 KV cache. Per-channel FP8 weights from the RedHatAI/gemma-4-31B-it-FP8-Dynamic checkpoint with channelscale fused into post-norm kernels. F16 KV cache + F16 paged attention (FA3 SM90 for sliding, SM89 for global). All 60 layers captured in a single CUDA graph. 63.0 tok/s single-user decode (2026-06; 83.9 with speculative decoding on real text), 8,786 tok/s peak (B=512, 2026-04). TTFT 63ms cached/short-prompt; prompts past the 1024-token sliding window still prefill per-token (known weakness, fix in progress).
June 2026 update (measured 2026-06-10, commit 544b1309e)
One week of profiling-driven work, deployed to production. Production-context decode went 14 → 61.4 tok/s (4.4x); single-user short-context 44.9 → 64.4; speculative decoding shipped at 83.9 tok/s on real text. What moved it: a split-KV GQA-grouped rewrite of the FP8 paged-attention kernel (551µs → 15.7µs per sliding call at a full 1024 window, 171x at 8K ctx, parity ≤ 4.9e-4), small-M FP8 GEMMs rerouted through cuBLASLt (the CUTLASS tile measured ~51% of HBM at M=1), and the speculative verify forward graph-captured per chunk size. One correction owned publicly: the April analysis attributed ~12 ms/step to "inter-kernel dispatch gap" -- node-level tracing shows the true gap is ~1.1 ms; the missing time was the GEMM route and the attention kernel, both now fixed. Full record: v3/H100_MAXPERF_PLAN.md; reproduce with v3/scripts/bench_spread.sh.
Fresh batch spread (B=1..256, both GEMM routes, defaults auto-pick the winner)
| Batch | tok/s (default) | ms/step | vs CUTLASS-only | route |
|---|---|---|---|---|
| 1 | 64.4 | 15.5 | +32% | cuBLASLt |
| 2 | 125.3 | 16.0 | +29% | cuBLASLt |
| 4 | 249.1 | 16.1 | +27% | cuBLASLt |
| 8 | 495.5 | 16.1 | +27% | cuBLASLt |
| 16 | 949.2 | 16.9 | +21% | cuBLASLt |
| 32 | 1,741 | 18.4 | +16% | cuBLASLt |
| 64 | 2,997 | 21.4 | +4% | cuBLASLt |
| 128 | 5,211 | 24.6 | CUTLASS wins +12% | CUTLASS |
| 256 | 7,607 | 33.7 | CUTLASS wins +20% | CUTLASS |
Crossover (RVLLM_FP8_GEMM_LT_MAX_M, default 64) calibrated by this sweep. 40 iters / 8 warmup; April-era B≥64 rows ran ~5-10% above these at 100 iters -- treat cross-date deltas as methodology noise until re-run matched.
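The routing decision described here amounts to a threshold on the GEMM's M dimension (the batch size during decode). The sketch below illustrates it in Python for readability; the real engine is Rust+CUDA and its code is not shown in this document. Only the variable name `RVLLM_FP8_GEMM_LT_MAX_M` and its default of 64 come from the text. The function name is hypothetical, and the inclusive comparison is an assumption that matches the B=64 row routing to cuBLASLt.

```python
import os

DEFAULT_LT_MAX_M = 64  # default stated in the text

def pick_gemm_route(m, env=os.environ):
    """Choose the FP8 GEMM backend for a given M (rows, i.e. batch size).

    Assumption: M values up to and including the threshold use cuBLASLt,
    larger values use the CUTLASS kernel.
    """
    lt_max_m = int(env.get("RVLLM_FP8_GEMM_LT_MAX_M", DEFAULT_LT_MAX_M))
    return "cuBLASLt" if m <= lt_max_m else "CUTLASS"

# Passing an empty mapping ignores the process environment.
for m in (1, 64, 128, 256):
    print(m, pick_gemm_route(m, env={}))

# Hypothetical override: force CUTLASS everywhere by setting the threshold to 0.
print(pick_gemm_route(1, env={"RVLLM_FP8_GEMM_LT_MAX_M": "0"}))
```

The sweep has no measurement between B=64 and B=128, so where the true crossover sits inside that interval is not determined by the table.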
was 20.5 s (30x); kernel parity-proven, default pending wiring gate (#58)
Speculative decoding: n-gram drafting, batched graphed verify; K=0 is bit-identical to plain decode by token hash (incl. past the sliding-window ring wrap). Open items, in public: per-token prefill TTFT on >1024-token prompts (20.5 s bench for 1200 tokens), serve graph lifecycle, GPU PPL re-eval. Fixes in flight.
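For readers unfamiliar with n-gram drafting, one common form looks up the most recent earlier occurrence of the last few tokens and proposes whatever followed it. The sketch below shows that idea only. The exact drafting scheme used here is not described in this document, so the function name, the match length `n`, and the draft length `k` are all assumptions for illustration.

```python
def ngram_draft(tokens, n=3, k=5):
    """Propose up to k tokens by matching the last n tokens earlier in the context.

    Scans from the most recent earlier position backwards and returns the
    tokens that followed the match; returns [] when there is no match.
    """
    if n <= 0 or k <= 0 or len(tokens) <= n:
        return []
    tail = tokens[-n:]
    # Last valid start is len - n - 1, so the tail never matches itself
    # and at least one following token exists.
    for start in range(len(tokens) - n - 1, -1, -1):
        if tokens[start:start + n] == tail:
            return tokens[start + n:start + n + k]
    return []

print(ngram_draft([1, 2, 3, 4, 1, 2, 3], n=3, k=2))  # [4, 1]
print(ngram_draft([1, 2, 3, 4, 5], n=3, k=2))        # []
```

An empty draft corresponds to the K=0 case, where decoding falls back to one token per step.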
Forward pass (14 launches per layer; April 2026 -- small-M GEMMs now route via cuBLASLt + scale/cast, see June update)
Batch scaling (April 2026; fresh June spread above)
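| Batch | tok/s | ms/step | Scaling |
|---|---|---|---|
| 1 | 53 | 18.7 | 1.0x |
| 4 | 221 | 18.1 | 4.2x |
| 8 | 434 | 18.4 | 8.2x |
| 16 | 893 | 17.9 | 16.8x |
| 32 | 1,743 | 18.4 | 32.9x |
| 64 | 3,265 | 19.6 | 61.6x |
| 128 | 5,802 | 22.1 | 109.5x |
| 256 | 7,808 | 32.8 | 147.3x |
| 512 | 8,786 | 58.3 | 165.8x |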
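The Scaling column is throughput at a given batch divided by the B=1 throughput. Dividing that again by the batch size gives a per-sequence efficiency, which makes the saturation point easier to see. A small sketch using the April rows follows; the efficiency figure is a derived quantity that is not reported in the table.

```python
def scaling(tok_s, base_tok_s):
    """Throughput multiple relative to B=1."""
    return tok_s / base_tok_s

def batch_efficiency(batch, tok_s, base_tok_s):
    """Scaling divided by batch size; 1.0 means perfectly linear scaling."""
    return scaling(tok_s, base_tok_s) / batch

BASE = 53  # B=1 tok/s from the April table
rows = [(4, 221), (32, 1743), (128, 5802), (256, 7808), (512, 8786)]
for batch, tok_s in rows:
    print(f"B={batch}: {scaling(tok_s, BASE):.1f}x, "
          f"efficiency {batch_efficiency(batch, tok_s, BASE):.0%}")
```

Values slightly above 100% at small batch are consistent with the ms/step column, where several batched steps were marginally faster than the B=1 step.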
Perplexity validation (April 2026; flagged for re-verification -- see note below)
| Weight path | KV cache | PPL | tok/s (B=1) |
|---|---|---|---|
| FP8-Dynamic + CUTLASS channelscale epilogue | F16 | 14.75 | 40.0 |
| BF16 split QKV per-tensor FP8 | F16 | 17.96 | 37.9 |
| F16 weights (no FP8) | F16 | 19.79 | 37.9 |
| HuggingFace BF16 reference | -- | 19.62 | -- |
rvLLM vs vLLM on H100 (vLLM measured 2026-04)
Same hardware, same model. H100 SXM 80GB, RedHatAI/gemma-4-31B-it-FP8-Dynamic. rvLLM: raw CUDA graph decode. vLLM 0.19: OpenAI-compatible server. vLLM numbers include server overhead.
| Batch | rvLLM tok/s | vLLM tok/s | Delta |
|---|---|---|---|
| 1 | 63.0 (83.9 spec) | 69 (2026-04) | -9% (+22% spec) |
| 8 | 434 | 515 | -16% |
| 32 | 1,743 | 1,748 | ~0% |
| 64 | 3,265 | 3,130 | +4% |
| 128 | 5,802 | 4,689 | +24% |
| 256 | 7,808 | 7,077 | +10% |
| 512 | 8,786 | 8,243 | +7% |
CUDA graph
| Mode | tok/s (B=1) | ms/step | Speedup |
|---|---|---|---|
| CUDA graph, June 2026 defaults | 64.4 | 15.5 | ~5.9x |
| CUDA graph, April 2026 (~935 nodes) | 53 | 18.7 | ~4.8x |
| Eager (no graph, April 2026) | 11 | ~91 | 1.0x |
Coming soon: E4B (4B) + 26B-A4B on GPU. FP8-Dynamic checkpoints downloaded. Per-layer input injection and KV sharing (E4B) and MoE expert routing (26B-A4B) are actively being implemented in the Rust+CUDA engine. Check back shortly.
Method
TPU benchmarks:
Hardware: Cloud TPU v6e-4 (4 chips, 128 GB HBM, ~3.3 TB/s), us-east5-b, $5.20/hr.
Models: google/gemma-4-E4B-it (42 layers, ~4B, per-layer input injection), google/gemma-4-26B-A4B-it (30 layers, 26B total / ~4B active, 128 experts top-8 MoE), google/gemma-4-31B-it (60 layers, 31B dense).
Quantization: int8 weights (jnp.int8), bf16 activations, bf16 KV cache.
Batch sweeps: all at max-ctx 2048. Timing via jax.block_until_ready(). Compile step discarded, measurement from first cached run.
LIBTPU flags: --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_scoped_vmem_limit_kib=131072
Perplexity: 86-token passage (John 1:1-14) with BOS prepend. E4B = 5.87 (HF ref 3.28), 26B-A4B = 90.21 (HF ref 85.42), 31B = 24.76 (split-cache = 19.24).
TTFT: measured from first block_until_ready() call on 2-token prompt. Includes XLA compile on first run, cached thereafter.
Context scaling (31B only): 512 to 128K tested. Auto-switches at 32K boundary.
vLLM GPU comparison:
Hardware: H100 SXM 80 GB on vast.ai ($1.92/hr).
Model: RedHatAI/gemma-4-31B-it-FP8-Dynamic, max_ctx=2048.
Engine: vLLM 0.19 with default settings.
vLLM TPU comparison: planned but not yet completed. vLLM TPU installation requires significant disk space beyond what our current 100 GB boot disk supports (vLLM + PyTorch + dependencies ~15 GB on top of model weights). A 200 GB disk instance is needed. We intend to run this comparison on the same v6e-4 hardware for an apples-to-apples TPU benchmark.
Small Model Parallelization (speculative, not yet benchmarked)
This section is speculation. We have not extensively tested single-GPU parallelization for small models, nor factored all potential strategies. The numbers below are back-of-envelope estimates based on hardware specs and known bottlenecks. Treat them as directional, not measured.
A 4B model like Gemma 4 E4B is heavily memory-bandwidth bound at B=1 decode. The weights are ~5 GB (int8). At B=1, decode is a pure weight-streaming problem: read every weight once, do one MAC per weight, write nothing. The arithmetic intensity is ~1 FLOP/byte -- deep in the bandwidth-bound regime. The GPU's compute (FP8 TOPS) is almost entirely idle.
This means the accelerator with the most memory bandwidth wins, regardless of compute capability. A single H100 with 3.35 TB/s could theoretically stream the 5 GB model in 1.5 ms -- implying ~670 tok/s at B=1 if the forward pass were purely bandwidth-limited. Real decode has overhead (kernel launch, attention, KV cache reads, softmax, RoPE), so practical throughput would be lower. But the gap between our measured 78 tok/s (TPU) and the theoretical ceiling is large.
Projected E4B (4B) scaling: measured + estimated
E4B throughput: measured (solid) vs projected (open)
Measured on TPU v6e-4. GPU projections from HBM bandwidth at 40-50% utilization. INT4 assumes 2.5 GB weights. Batch projections extrapolated from measured scaling curve.
Single-GPU estimates for 4B decode (B=1)
| GPU | HBM BW | Weight read time | Theoretical ceiling | Estimated practical |
|---|---|---|---|---|
| H100 SXM 80GB | 3.35 TB/s | 1.5 ms | ~670 tok/s | ~200-350 tok/s |
| H200 SXM 141GB | 4.8 TB/s | 1.0 ms | ~960 tok/s | ~300-500 tok/s |
| B200 SXM 192GB | 8.0 TB/s | 0.6 ms | ~1,600 tok/s | ~500-900 tok/s |
Assumptions: 5 GB int8 weights, 42 layers, no batch. "Practical" estimate accounts for ~40-50% bandwidth utilization (kernel launch overhead, attention, KV cache, non-GEMM ops). Real numbers could be higher with aggressive fusion or lower with long-context KV overhead.
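The ceiling figures follow from a one-line model: at B=1 every decoded token requires reading all weights once, so the token rate cannot exceed bandwidth divided by weight size. The sketch below reproduces the table's ceilings from its stated assumptions. The function names are illustrative, and the strict 40-50% band it prints is narrower than the rounded "Estimated practical" ranges in the table.

```python
def ceiling_tok_s(weight_gb, hbm_tb_s):
    """Bandwidth-only ceiling at B=1: one full weight read per token."""
    read_ms = weight_gb / hbm_tb_s  # GB / (TB/s) = ms
    return 1000.0 / read_ms

def practical_range(weight_gb, hbm_tb_s, lo=0.40, hi=0.50):
    """Scale the ceiling by an assumed bandwidth-utilization band."""
    ceiling = ceiling_tok_s(weight_gb, hbm_tb_s)
    return ceiling * lo, ceiling * hi

WEIGHT_GB = 5.0  # int8 E4B weights, per the stated assumptions
for gpu, bw in (("H100 SXM", 3.35), ("H200 SXM", 4.8), ("B200 SXM", 8.0)):
    lo, hi = practical_range(WEIGHT_GB, bw)
    print(f"{gpu}: ceiling {ceiling_tok_s(WEIGHT_GB, bw):.0f} tok/s, "
          f"40-50% band {lo:.0f}-{hi:.0f} tok/s")
```

The model ignores KV cache reads, so it becomes more optimistic as context length grows.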
No Python runtime = more VRAM for the model
rvLLM's GPU path is a compiled Rust binary with zero Python dependency. This matters because Python-based inference stacks consume significant GPU memory before a single weight is loaded:
| Component | Typical VRAM cost | rvLLM |
|---|---|---|
| PyTorch + CUDA context | ~1.5-2.0 GB | 0 |
| Python runtime + GC overhead | ~0.3-0.5 GB | 0 |
| vLLM scheduler / PagedAttention metadata | ~0.5-1.0 GB | 0 |
| Duplicate weight buffers (load then convert) | ~0.5-2.0 GB | 0 |
| NCCL / distributed comm buffers | ~0.3-0.5 GB | 0 |
| Total framework overhead | 3-6 GB | ~0.2 GB |
rvLLM's only VRAM overhead is the CUDA context itself (~200 MB) and pre-allocated scratch buffers. Measured via nvidia-smi before and after weight loading. Python stacks vary; numbers reflect vLLM 0.19 and HuggingFace transformers on H100.
On an H100 80GB, recovering 3-6 GB of framework overhead means the difference between fitting a model or not, or between B=256 and B=512 batch sizes. For small models where the weights are only 5 GB, framework bloat can exceed the model itself.
What the extra VRAM buys you
Higher batch sizes -- each KV cache entry for E4B at 2048 context is ~2 MB. Recovering 4 GB of framework overhead = ~2,000 additional concurrent sequences. On an H100, this could push peak batch from ~512 to ~2,048+, potentially 2-3x more throughput before OOM.
Longer context per sequence -- at B=128 with 4 GB extra VRAM, each sequence could extend context from 2K to 32K+ tokens without reducing batch size. Critical for agent workflows and long-document tasks.
Model tiling on one GPU -- a 4B INT4 model is ~2.5 GB. On an H100 with 80 GB and no Python overhead, you could tile multiple model replicas in VRAM and round-robin requests across them, effectively multiplying B=1 throughput by the number of replicas. 80 GB / 2.5 GB = 32 replicas theoretically, though KV cache memory limits this in practice.
FP4/INT4 weights -- halving precision to 4-bit halves bandwidth requirement, potentially doubling B=1 throughput. E4B at INT4 = ~2.5 GB, readable in 0.75 ms on H100. Combined with CUDA graph fusion and the VRAM savings from no Python runtime, B=1 decode at 400+ tok/s on a single H100 is plausible but unverified.
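The sequence and replica counts in the list above are simple capacity divisions. The sketch below makes the arithmetic explicit using the figures quoted in the list (~2 MB of KV per sequence, 4 GB recovered, 2.5 GB INT4 weights, 80 GB of VRAM). The helper names and the `kv_reserve_gb` parameter are hypothetical, decimal units (1 GB = 1000 MB) are assumed, and none of this has been measured.

```python
def extra_sequences(recovered_gb, kv_mb_per_seq):
    """Additional concurrent sequences that fit in recovered VRAM."""
    return int(recovered_gb * 1000 // kv_mb_per_seq)

def max_replicas(vram_gb, weights_gb, kv_reserve_gb=0.0):
    """Model replicas that fit after reserving space for KV cache."""
    return int((vram_gb - kv_reserve_gb) // weights_gb)

print(extra_sequences(4, 2))                      # 2000 sequences
print(max_replicas(80, 2.5))                      # 32 replicas, no KV reserved
print(max_replicas(80, 2.5, kv_reserve_gb=40))    # 16 replicas with half reserved
```

The last line illustrates the caveat in the text: any realistic KV reservation cuts the replica count well below the theoretical 32.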
Projected single-GPU 4B performance
With no framework overhead eating into VRAM and bandwidth-optimal weight streaming, small models can exploit the full hardware budget. The real lever is batch scaling: at B=128, the same weight read serves 128 sequences, pushing arithmetic intensity to ~128 FLOPs/byte and saturating tensor cores. Our measured E4B numbers show 81x scaling from B=1 to B=128.
Bottom line: a native binary running a 4B model on a B200 at B=1 could plausibly hit 500-900 tok/s -- fast enough for real-time speech synthesis or interactive agents. At B=128, the same B200 could push 50,000+ tok/s. The no-Python VRAM savings make higher batch sizes and longer contexts achievable where Python stacks would OOM. We haven't extensively tested single-GPU parallelization and these are estimates from hardware specs and measured TPU scaling curves. If you're interested in pushing this direction, we'd welcome collaboration.