
Scaling Structural Bias - Pre-training Custom Qwen3 on TPU v6e

An end-to-end guide to orchestrating custom Qwen3 pre-training on Google Cloud's Trillium TPUs. I dive into modifying the Qwen3 architecture for structured JSON outputs, leveraging XPK for orchestration, and serving the final artifacts with vLLM's high-performance OpenXLA backend.


For AI infrastructure, 2025 has been defined by the synergy between high-performance custom silicon and modular open-source frameworks. This post provides a definitive guide to building a pre-training pipeline on Google’s TPU v6e (Trillium) using MaxText—the JAX-native stack—and the cutting-edge Qwen3 architecture.

1. Orchestration: Setting up the XPK Cluster

The Accelerated Processing Kit (XPK) is the orchestration layer that sits on top of GKE. It manages training runs as Kubernetes JobSets, so a failure on a single TPU host is handled at the level of the whole training group instead of leaving the run silently desynchronized.

# Provision a GKE cluster optimized for AI Hypercomputer workloads
# --tpu-type v6e-8 specifies the single-host Trillium architecture
# --on-demand requests non-preemptible on-demand capacity (rather than spot) for long pre-training runs
python3 xpk.py cluster create \
    --cluster qwen3-training-cluster \
    --tpu-type v6e-8 \
    --num-slices 1 \
    --project ${PROJECT_ID} \
    --zone us-east1-d \
    --on-demand 
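
With the cluster provisioned, the training job itself is submitted as an XPK workload. The sketch below assumes a standard MaxText entrypoint; the workload name, bucket paths, and config overrides are illustrative rather than a definitive recipe.

# Submit the MaxText training job as an XPK workload (names and paths are illustrative)
python3 xpk.py workload create \
    --cluster qwen3-training-cluster \
    --workload qwen3-pretrain-v1 \
    --tpu-type v6e-8 \
    --num-slices 1 \
    --command "python3 MaxText/train.py MaxText/configs/base.yml \
        run_name=qwen3-pretrain-v1 \
        base_output_directory=gs://your-bucket/qwen3_train \
        dataset_path=gs://my-data/train"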

2. Architecture: Designing a Structurally Biased Qwen3

We are using Qwen3 as our reference. Qwen3 is known for its SwiGLU activation and Grouped Query Attention (GQA). Here, we add a “Structural Bias” layer that nudges the model toward JSON-like outputs.

# src/MaxText/layers/models.py
import flax.linen as nn
from . import attention, linears
from MaxText import common_types

class Qwen3StructuralBlock(nn.Module):
  """A Qwen3-style block with SwiGLU activation and GQA sharded for TPU."""
  config: common_types.Config
  
  @nn.compact
  def __call__(self, x, mask=None):
    # Pre-Norm design: RMSNorm -> Attention -> RMSNorm -> MLP
    res = x
    x = nn.RMSNorm(epsilon=1e-6)(x)
    
    # attention.Attention is co-designed with XLA for peak TFLOPS on v6e
    x = attention.Attention(self.config)(x, mask)
    x = x + res
    
    res = x
    x = nn.RMSNorm(epsilon=1e-6)(x)
    # MLP with intermediate_dim sharding (typically 4x hidden_dim)
    x = linears.MlpBlock(intermediate_dim=self.config.mlp_dim, config=self.config)(x)
    x = x + res
    return x

class MyQwen3Model(nn.Module):
  """Qwen3 architecture with a custom structural bias head."""
  config: common_types.Config
  
  @nn.compact
  def __call__(self, input_ids, positions, decoder_mask=None):
    x = nn.Embed(num_embeddings=self.config.vocab_size, features=self.config.emb_dim)(input_ids)
    
    for _ in range(self.config.num_decoder_layers):
      x = Qwen3StructuralBlock(self.config)(x, decoder_mask)
    
    # --- STRUCTURAL MODIFICATION ---
    # We insert a Tanh-activated Dense layer before the final projection.
    # By initializing this with specific weights, we bias the model to 
    # generate structured markers (e.g., '{', '}', ':') with higher probability.
    x = nn.Dense(self.config.emb_dim, name="structural_head")(x)
    x = nn.tanh(x) 
    
    logits = nn.Dense(self.config.vocab_size, use_bias=False)(x)
    return logits
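
A quick sanity check is to initialize the modified model on dummy tokens and confirm the logit shape. This is a minimal sketch: get_maxtext_config() is a placeholder for however you build a populated MaxText config (e.g., via pyconfig), and the batch and sequence sizes are arbitrary.

# Smoke test for MyQwen3Model on dummy inputs
import jax
import jax.numpy as jnp

cfg = get_maxtext_config()  # placeholder: a fully populated MaxText config object

batch, seq = 2, 128
input_ids = jnp.zeros((batch, seq), dtype=jnp.int32)
positions = jnp.broadcast_to(jnp.arange(seq), (batch, seq))

model = MyQwen3Model(cfg)
params = model.init(jax.random.PRNGKey(0), input_ids, positions)
logits = model.apply(params, input_ids, positions)
print(logits.shape)  # expected: (batch, seq, vocab_size)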

3. Data: The Grain Pipeline

Grain is the recommended data loader for this stack. It provides deterministic streaming and multi-host resilience; a minimal loader sketch follows the list below.

  • The Dataset: A web-scale corpus such as C4 (Qwen3 itself was pre-trained on roughly 36T tokens).
  • Structure: Data is tokenized into ArrayRecord files. Each record is a serialized JAX-friendly byte-stream.
  • Location: Sharded across GCS (e.g., gs://my-data/train/shard_*.arrayrecord).
  • Format: You can use the MaxText/data_processing scripts to convert raw .jsonl from Hugging Face into ArrayRecord.
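
Below is a minimal Grain pipeline sketch over ArrayRecord shards. The shard path, batch size, and sharding options are illustrative; the real MaxText input pipeline configures Grain through its own config rather than by hand.

# Minimal Grain pipeline over ArrayRecord shards (paths and sizes are illustrative)
import grain.python as grain

# In practice shards are often read via a gcsfuse mount or copied locally.
source = grain.ArrayRecordDataSource(
    ["gs://my-data/train/shard_00000.arrayrecord"]  # expand to all shards
)

sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
    shuffle=True,
    seed=0,
    num_epochs=1,
)

loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[grain.Batch(batch_size=8, drop_remainder=True)],
    worker_count=4,
)

for batch in loader:
    pass  # feed each tokenized batch into the training step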

4. Monitoring: Telemetry with XPK and GKE

Once the container is deployed, XPK provides a specialized observability suite.

Telemetry Channels

  • MFU (Model FLOPs Utilization): High MFU (55%+) is a good sign that your batch size and sharding are keeping the TPU cores saturated; a back-of-the-envelope estimate is sketched after this list.
  • Straggler Detection: XPK identifies if one chip is underperforming, allowing you to drain that node before it stalls your global sync.
  • TPU Health: Integrated through Cluster Director, providing thermal and network interconnect (ICI) health metrics.
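
As a back-of-the-envelope check, MFU is simply the model FLOPs you actually push per second divided by the hardware's published peak. The parameter count, batch size, and step time below are illustrative, and the 6*N*D approximation ignores attention FLOPs.

# Rough MFU estimate (illustrative numbers, 6*N*D approximation)
num_params = 8e9                # model parameters
tokens_per_step = 32 * 8192     # global batch size * sequence length
step_time_s = 3.1               # measured wall-clock seconds per step

model_flops_per_step = 6 * num_params * tokens_per_step
achieved_flops_per_s = model_flops_per_step / step_time_s

peak_flops_per_chip = 918e12    # published TPU v6e peak bf16 FLOPs/s per chip
num_chips = 8                   # v6e-8 slice
mfu = achieved_flops_per_s / (peak_flops_per_chip * num_chips)
print(f"MFU ~= {mfu:.1%}")      # ~55% with these example numbers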

5. Serving: Moving to vLLM

vLLM supports TPU v6e via the OpenXLA/Pallas backend.

  1. Unscan: Use generate_param_only_checkpoint.py with force_unroll=True to unroll the scanned decoder layers of the MaxText training checkpoint into the per-layer weights that inference expects.
  2. vLLM Launch:
# Deploying with vLLM on TPU
export VLLM_TARGET_DEVICE="tpu"
python3 -m vllm.entrypoints.openai.api_server \
    --model gs://your-bucket/qwen3_hf \
    --tensor-parallel-size 8 \
    --max-model-len 32768 # High context support
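
Once the server is up, the model is reachable through vLLM's OpenAI-compatible API. The port, prompt, and sampling parameters below are illustrative; the model field must match the name or path the server was launched with.

# Query the OpenAI-compatible endpoint (port and prompt are illustrative)
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "gs://your-bucket/qwen3_hf",
        "prompt": "Extract the fields as JSON: name=Ada, role=engineer.",
        "max_tokens": 128,
        "temperature": 0
      }'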

Conclusion: Scaling with Multi-Slice and Cluster Director

When scaling to a Multi-Slice setup (e.g., two v6e-256 slices), intra-slice traffic stays on the high-bandwidth ICI while inter-slice communication happens over the Data Center Network (DCN).

  • DCN Data Parallelism: Set dcn_data_parallelism: 2 in your config. This keeps weight-sharding (FSDP) within the fast local slice, while only gradient updates travel over the slower DCN (see the launch sketch after this list).
  • Cluster Director: Acts as the “Air Traffic Control.” It provides Topology-Aware Scheduling, ensuring your slices are physically co-located on the network fabric for minimum latency. It also handles Bill of Health checks, ensuring your job only starts once all 512 chips are verified healthy.
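
As a sketch, a two-slice launch only adds the DCN/ICI parallelism knobs to the training command; the run name, bucket, and exact parallelism values below are illustrative. Submitted through XPK, this corresponds to --tpu-type v6e-256 --num-slices 2 on the workload create command shown earlier.

# Two-slice MaxText launch: FSDP stays on ICI, data parallelism spans the DCN
python3 MaxText/train.py MaxText/configs/base.yml \
    run_name=qwen3-multislice \
    base_output_directory=gs://your-bucket/qwen3_train \
    dcn_data_parallelism=2 \
    ici_fsdp_parallelism=256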