Scaling Structural Bias - Pre-training Custom Qwen3 on TPU v6e
An end-to-end guide to orchestrating Custom Qwen3 pre-training on Google Cloud's Trillium TPUs. I dive into modifying the Qwen3 architecture for structured JSON outputs, leveraging XPK for orchestration, and serving the final artifacts with vLLM's high-performance OpenXLA backend.

For AI infrastructure, 2025 has been defined by the synergy between high-performance custom silicon and modular open-source frameworks. This post provides a definitive guide to building a pre-training pipeline on Google’s TPU v6e (Trillium) using MaxText—the JAX-native stack—and the cutting-edge Qwen3 architecture.
1. Orchestration: Setting up the XPK Cluster
The Accelerated Processing Kit (XPK) is the orchestration layer that sits on top of GKE. It manages Kubernetes “JobSets”, ensuring that if a TPU host fails, the entire training group is rescheduled together rather than drifting out of sync.
# Provision a GKE cluster optimized for AI Hypercomputer workloads
# --tpu-type v6e-8 specifies the single-host Trillium architecture
# --on-demand ensures capacity for long-running pre-training jobs
python3 xpk.py cluster create \
--cluster qwen3-training-cluster \
--tpu-type v6e-8 \
--num-slices 1 \
--project ${PROJECT_ID} \
--zone us-east1-d \
--on-demand
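With the cluster provisioned, the pre-training job itself is submitted as an XPK workload. The sketch below is a hedged example: the run name and training command are placeholders, and flag details can vary slightly across xpk releases.
# Submit the MaxText pre-training job onto the cluster created above
python3 xpk.py workload create \
--cluster qwen3-training-cluster \
--workload qwen3-pretrain \
--tpu-type v6e-8 \
--num-slices 1 \
--project ${PROJECT_ID} \
--zone us-east1-d \
--command "python3 MaxText/train.py MaxText/configs/base.yml run_name=qwen3-pretrain"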
2. Architecture: Designing a Structurally Biased Qwen3
We use Qwen3 as our reference architecture; it is known for its SwiGLU activation and Grouped Query Attention (GQA). Here, we add a “Structural Bias” layer that nudges the model toward JSON-like outputs.
# src/MaxText/layers/models.py
import flax.linen as nn

from . import attention, linears
from MaxText import common_types


class Qwen3StructuralBlock(nn.Module):
    """A Qwen3-style block with SwiGLU activation and GQA sharded for TPU."""

    config: common_types.Config

    @nn.compact
    def __call__(self, x, mask=None):
        # Pre-Norm design: RMSNorm -> Attention -> RMSNorm -> MLP
        res = x
        x = nn.RMSNorm(epsilon=1e-6)(x)
        # attention.Attention is co-designed with XLA for peak TFLOPS on v6e
        x = attention.Attention(self.config)(x, mask)
        x = x + res

        res = x
        x = nn.RMSNorm(epsilon=1e-6)(x)
        # MLP with intermediate_dim sharding (typically 4x hidden_dim)
        x = linears.MlpBlock(intermediate_dim=self.config.mlp_dim, config=self.config)(x)
        x = x + res
        return x


class MyQwen3Model(nn.Module):
    """Qwen3 architecture with a custom structural bias head."""

    config: common_types.Config

    @nn.compact
    def __call__(self, input_ids, positions, decoder_mask=None):
        x = nn.Embed(num_embeddings=self.config.vocab_size, features=self.config.emb_dim)(input_ids)
        for _ in range(self.config.num_decoder_layers):
            x = Qwen3StructuralBlock(self.config)(x, decoder_mask)

        # --- STRUCTURAL MODIFICATION ---
        # We insert a Tanh-activated Dense layer before the final projection.
        # By initializing this with specific weights, we bias the model to
        # generate structured markers (e.g., '{', '}', ':') with higher probability.
        x = nn.Dense(self.config.emb_dim, name="structural_head")(x)
        x = nn.tanh(x)
        logits = nn.Dense(self.config.vocab_size, use_bias=False)(x)
        return logits
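The code above leaves the actual bias to initialization. As a self-contained illustration of one way to seed it (independent of MaxText, and deviating from the use_bias=False projection above so the offset lives in a learnable bias), here is a minimal sketch; the token ids, dimensions, and the +2.0 offset are placeholders rather than real Qwen3 vocabulary values.
# Standalone sketch: seed the final projection's bias so structural tokens
# start with elevated probability. All constants here are illustrative.
import jax
import jax.numpy as jnp
import flax.linen as nn

STRUCTURAL_TOKEN_IDS = (90, 92, 58)  # placeholder ids for '{', '}', ':'

def structural_bias_init(key, shape, dtype=jnp.float32):
    """Zero bias everywhere, plus a small positive offset on structural tokens."""
    bias = jnp.zeros(shape, dtype)
    return bias.at[jnp.array(STRUCTURAL_TOKEN_IDS)].add(2.0)

class StructuralHead(nn.Module):
    emb_dim: int
    vocab_size: int

    @nn.compact
    def __call__(self, x):
        x = nn.tanh(nn.Dense(self.emb_dim, name="structural_head")(x))
        # Unlike the use_bias=False projection in the full model, we keep a bias
        # here so the structural offset is explicit (and can be learned away).
        return nn.Dense(self.vocab_size, bias_init=structural_bias_init)(x)

# Quick check: structural tokens receive higher initial probability mass.
head = StructuralHead(emb_dim=64, vocab_size=128)
params = head.init(jax.random.PRNGKey(0), jnp.ones((1, 8, 64)))
probs = jax.nn.softmax(head.apply(params, jnp.ones((1, 8, 64))))
print(probs[0, 0, jnp.array(STRUCTURAL_TOKEN_IDS)])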
3. Data: The Grain Pipeline
Grain is the recommended data loader for this stack. It provides deterministic streaming and multi-host resilience.
- The Dataset: We use the Qwen3-36T corpus or C4.
- Structure: Data is tokenized into ArrayRecord files. Each record is a serialized, JAX-friendly byte stream.
- Location: Sharded across GCS (e.g., gs://my-data/train/shard_*.arrayrecord).
- Format: You can use the MaxText/data_processing scripts to convert raw .jsonl from Hugging Face into ArrayRecord.
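Outside of MaxText's built-in wiring, the same layout can be read with Grain directly. The sketch below is illustrative: the shard path, batch size, and the byte-decoding in ParseTokens are placeholders (real records would be decoded according to however you serialized them).
# Minimal Grain pipeline over ArrayRecord shards (illustrative placeholders throughout)
import grain.python as grain
import numpy as np

# Deterministic record source over the sharded ArrayRecord files
data_source = grain.ArrayRecordDataSource(
    ["/data/train/shard_00000.arrayrecord"]  # placeholder path to a mounted/downloaded shard
)

# Index-based sampling keeps iteration reproducible and shardable across hosts
sampler = grain.IndexSampler(
    num_records=len(data_source),
    num_epochs=1,
    shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
)

class ParseTokens(grain.MapTransform):
    """Decode a serialized record into a token-id array (placeholder decoding)."""
    def map(self, record: bytes) -> np.ndarray:
        return np.frombuffer(record, dtype=np.int32)

loader = grain.DataLoader(
    data_source=data_source,
    sampler=sampler,
    operations=[ParseTokens(), grain.Batch(batch_size=8, drop_remainder=True)],
    worker_count=4,
)

for batch in loader:
    # batch is a (8, seq_len) int32 array of token ids, ready for the train step
    break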
4. Monitoring: Telemetry with XPK and GKE
Once the container is deployed, XPK provides a specialized observability suite.
Telemetry Channels
- MFU (Model FLOPs Utilization): High MFU (55%+) confirms that your batch size is large enough to saturate the TPU cores (a back-of-the-envelope check follows this list).
- Straggler Detection: XPK identifies if one chip is underperforming, allowing you to drain that node before it stalls your global sync.
- TPU Health: Integrated through Cluster Director, providing thermal and inter-chip interconnect (ICI) health metrics.
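To make the MFU number concrete, here is a quick back-of-the-envelope check. The parameter count and throughput are placeholder values; the per-chip peak uses the commonly cited ~918 bf16 TFLOP/s figure for Trillium, so treat the result as illustrative.
# Approximate MFU from observed throughput: a dense decoder needs roughly
# 6 FLOPs per parameter per token for the combined forward + backward pass.
PARAMS = 8e9                   # model size (placeholder)
TOKENS_PER_SEC = 7.5e4         # observed global training throughput (placeholder)
NUM_CHIPS = 8                  # v6e-8 single host
PEAK_FLOPS_PER_CHIP = 918e12   # ~918 bf16 TFLOP/s per Trillium chip

achieved_flops = 6 * PARAMS * TOKENS_PER_SEC
mfu = achieved_flops / (NUM_CHIPS * PEAK_FLOPS_PER_CHIP)
print(f"MFU ≈ {mfu:.1%}")      # ≈ 49% with these placeholder numbers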
5. Serving: Moving to vLLM
vLLM supports TPU v6e via the OpenXLA/Pallas backend.
- Unscan: Use generate_param_only_checkpoint.py --force_unroll=True to flatten MaxText's scanned training layers into a per-layer inference checkpoint.
- vLLM Launch:
# Deploying with vLLM on TPU
export VLLM_TARGET_DEVICE="tpu"
python3 -m vllm.entrypoints.openai.api_server \
--model gs://your-bucket/qwen3_hf \
--tensor-parallel-size 8 \
--max-model-len 32768 # High context support
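Once the server is up, it can be smoke-tested with the standard OpenAI-compatible client. The port, dummy API key, and prompt below are placeholders; the JSON-flavored prompt is just a nod to the structural bias trained in above.
# Query the OpenAI-compatible endpoint exposed by vLLM
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # vLLM ignores the key

resp = client.completions.create(
    model="gs://your-bucket/qwen3_hf",  # must match the --model value passed to the server
    prompt="Return the user profile as JSON: name=Ada, role=engineer\n",
    max_tokens=128,
    temperature=0.0,
)
print(resp.choices[0].text)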
Conclusion: Scaling with Multi-Slice and Cluster Director
When scaling to a Multi-Slice setup (e.g., two v6e-256 slices), communication happens over the Data Center Network (DCN).
- DCN Data Parallelism: Set dcn_data_parallelism: 2 in your config (a launch sketch follows below). This keeps weight-sharding (FSDP) within the fast local slice, while only gradient updates travel over the slower DCN.
- Cluster Director: Acts as the “Air Traffic Control.” It provides Topology-Aware Scheduling, ensuring your slices are physically co-located on the network fabric for minimum latency. It also handles Bill of Health checks, ensuring your job only starts once all 512 chips are verified healthy.
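As a rough sketch of how those knobs come together at launch time (flag names follow MaxText's sharding config; the run name, bucket, and dataset paths are placeholders, so verify them against your base.yml):
# Two v6e-256 slices: data parallelism across DCN, FSDP within each slice over ICI
python3 MaxText/train.py MaxText/configs/base.yml \
run_name=qwen3-multislice \
dcn_data_parallelism=2 \
ici_fsdp_parallelism=256 \
base_output_directory=gs://your-bucket/checkpoints \
dataset_path=gs://my-data/train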



