The Ratchet Loop — Systematically Optimizing a TPU Kernel to the Hardware Ceiling

Part 3 of the KernelForge series on writing, profiling, and optimizing custom TPU kernels in Python.

Part 1: Implementing a Fast Attention Fusion Kernel

Part 2: Mastering TPU Performance Profiling

Google Cloud credits provided for this project.


Here’s a pattern every systems engineer eventually learns the hard way: you can’t optimize what you can’t measure, and you can’t trust a measurement you can’t reproduce. Writing fast kernels is actually pretty straightforward once you understand the hardware — but the feedback loop between intuition and evidence is slow, noisy, and full of trap doors.

This article is about fixing that. We’ll build a systematic optimization loop — called the Ratchet Loop — that uses git for reproducibility, profiling data for hypotheses, and automatic rollback to prevent regressions. By the end, you’ll see a kernel evolve from a naive baseline to near-peak utilization through a series of deliberate, measured mutations. Every change is isolated, every result is tracked, and the kernel can only move forward.

🤖 Autonomous Optimization: The experiments in this article were run autonomously by an AI agent (Gemini) using the Ratchet Loop. The human’s role shifted from manually tweaking parameters to designing the framework and interpreting results. More on this at the end.


1. The Core Problem: Optimization Without a Net

You’ve been there. You tweak a parameter, run the benchmark, and the latency drops from 2.3ms to 2.1ms. Victory! But then you run it again and it’s back to 2.3ms. Or you run it five more times and get 2.1, 2.4, 2.2, 2.1, 2.5. The numbers are noisy. Was that first improvement real, or just variance?

Or here’s the other scenario: you change three things at once — block size, memory layout, and precision — and the score improves by 15%. But you have no idea which change actually mattered. Did the block size help? Did the memory layout? Or did they interact in some way you don’t understand? You’ll never know, because you broke the one rule that makes optimization scientific: isolation.

The Ratchet Loop solves both problems. It enforces:

  • One change at a time — so every improvement is attributable.
  • Git-tracked history — so every state is recoverable.
  • Automatic comparison — so you never have to guess whether an experiment worked.

The kernel can only improve. It’s a ratchet — it clicks forward on gains, slips backward on regressions, but never loses its grip on the best result.

File Role Mutable?
kernel.py The workspace: kernel logic, tunable parameters, standard interface YES — the only mutable file
prepare.py The immutable judge: correctness check, benchmark, roofline analysis, HLO dump, XProf trace NO
results.tsv Append-only experiment ledger written by prepare.py Read-only by the experimenter
ratchet.sh Orchestrator: commit → benchmark → decide → revert if worse NO

Why immutability matters:

  • If prepare.py could be altered, you could accidentally (or intentionally) weaken the correctness tolerance or change the scoring formula to make a regression look like a gain.
  • If results.tsv were editable, the experiment history would be untrustworthy — you couldn’t distinguish real improvements from cherry-picked results.
  • With only kernel.py as the variable, every git diff HEAD~1 shows exactly one thing: the mutation that produced the current performance.

1.2 The Standard Interface

prepare.py imports exactly seven symbols from kernel.py:

from kernel import (
    KERNEL_NAME,           # str — human-readable name
    PROBLEM_DESCRIPTION,   # str — what the kernel computes
    kernel_fn,             # Callable — the function to benchmark
    reference_fn,          # Callable — gold-standard JAX reference for correctness
    generate_inputs,       # Callable(key) → tuple of input arrays
    compute_flops,         # Callable() → int total theoretical FLOPs
    get_tiling_meta,       # Callable() → dict of buffer/traffic metadata
)

This interface is the contract that decouples the kernel from the harness. To optimize a completely different operator — a fused LayerNorm, a sparse matmul, a rotary embedding kernel — you just change kernel.py’s exports. prepare.py and ratchet.sh need zero modifications.

The get_tiling_meta() function is particularly important: it returns a dictionary describing buffer sizes, access patterns, and re-read counts. prepare.py uses this to compute the roofline analysis automatically for any kernel:

def get_tiling_meta():
    return {
        "grid": (num_q_blocks, num_kv_blocks),
        "buffers": [
            {"name": "Q", "size_bytes": BQ * d * 2, "double_buffered": True},
            {"name": "K", "size_bytes": BKV * d * 2, "double_buffered": True},
            # ... etc
        ],
        "inputs": [
            {"name": "Q", "size": SEQ_LEN * d * 2, "re_reads": 1},
            {"name": "K", "size": SEQ_LEN * d * 2, "re_reads": num_q_blocks},
            {"name": "V", "size": SEQ_LEN * d * 2, "re_reads": num_q_blocks},
        ],
        "outputs": [
            {"name": "O", "size": SEQ_LEN * d * 2, "re_reads": 1},
        ],
        "vreg_estimate_bytes": BQ * BKV * 4,  # Score matrix in VREGs
    }

1.3 The Ledger

Each prepare.py run appends one row to results.tsv:

commit	score	status	median_ms	util_pct	tflops	regime	bw_util_pct	note
75f2094	1515.5515	PASS	2.16	32.68	64.37	COMPUTE-BOUND	8.1	baseline
338b089	2777.4615	PASS	1.59	44.23	87.14	COMPUTE-BOUND	10.9	experiment: scratch and last-dim reduction fixes

The append-only log is your scientific notebook. It records every experiment — successes and failures alike — with the exact git commit hash, so you can always reproduce any point in the optimization history.


2. The Ratchet Mechanism

2.1 The Full Cycle

Running a single experiment:

./ratchet.sh "experiment: increase BQ from 512 to 1024"

What happens internally, step by step:

# 1. Read the previous best score from the last line of results.tsv
PREV_SCORE=$(tail -n 1 results.tsv | cut -f2)

# 2. Commit the current kernel.py state
git add kernel.py results.tsv
git commit -m "experiment: increase BQ from 512 to 1024"

# 3. Run the benchmark judge
python prepare.py
#   → Correctness check against reference_fn()
#   → VMEM budget analysis from get_tiling_meta()
#   → 30-iteration benchmark with warmup
#   → Roofline analysis
#   → XProf trace capture (with TRACE_COMPUTE mode, or --deep for TRACE_COMPUTE_AND_SYNC)
#   → HLO dump
#   → Results logged to results.tsv

# 4. If correctness fails → revert
if echo "$OUTPUT" | grep -q "FAIL_CORRECTNESS"; then
    git reset --hard HEAD~1
    exit 1
fi

# 5. Extract the new score
CURRENT_SCORE=$(echo "$OUTPUT" | grep "RESULT:" | grep -oP 'score=\K[0-9.]+')

# 6. Compare: is the new score better?
if [ "$CURRENT_SCORE" > "$PREV_SCORE" ]; then
    # WIN: amend the commit to include the updated results.tsv
    git add results.tsv
    git commit --amend --no-edit
else
    # LOSS: revert to the previous state
    git reset --hard HEAD~1
    exit 1
fi

2.2 Why Git Is the Right Tool

Git provides exactly the properties the ratchet needs:

  • Atomic undo: git reset --hard HEAD~1 restores kernel.py to its exact pre-experiment state in a single command. No partial reverts, no missed files.
  • Experiment history: git log --oneline shows every successful experiment in chronological order. Each commit message contains the hypothesis.
  • Exact diffs: git diff HEAD~1 shows precisely what changed between the last two wins.
  • Reproducibility: Any commit hash can be checked out and re-benchmarked.

The ratchet metaphor is precise: the pawl allows motion in one direction (improvement) while preventing motion in the opposite direction (regression). The kernel can only move toward the hardware ceiling.

2.3 The Score Function

The combined fitness metric balances utilization and latency:

$$\text{Score} = \frac{\text{Utilization\%} \times 100}{\text{Latency (ms)}}$$

Why this formula?

  • Utilization alone would reward slow, high-utilization kernels. A kernel that takes 100ms at 90% util has excellent utilization but terrible absolute performance.
  • Latency alone doesn’t distinguish between a kernel that’s fast because the workload is small and one that’s fast because the hardware is used efficiently.
  • The product ensures you need both: efficient hardware use and fast execution.

2.4 Isolation vs. Speed

A natural temptation is to combine multiple optimizations in a single commit — “increase BQ and add causal masking and fix precision all at once.” This is faster but scientifically worthless.

The problem with compound mutations:

“I increased BQ AND added causal masking in the same commit. The score improved by 30%. Great! But when I later tried to understand how the causal masking contributed, I couldn’t — the BQ change might have produced the same result on its own. Or worse: the BQ change might have given a 40% improvement while the causal masking actually hurt by 10%, but I’ll never know.”

When you break isolation, you lose:

  • Attributability: You can’t tell which change helped.
  • Reversibility: You can’t cleanly revert just the bad change.
  • Reproducibility: The git log no longer documents individual contributions.

The Ratchet Loop enforces isolation automatically: ratchet.sh operates on the diff since HEAD~1, which is always exactly one mutation.

2.5 Profiling Options

By default, prepare.py captures XProf traces using TRACE_COMPUTE mode, which shows MXU/VPU compute and DMA transfers as separate timeline rows. For detailed pipeline analysis (debugging bubbles or DMA/compute scheduling issues), use the --deep flag:

./ratchet.sh "experiment: ..." --deep

This enables TRACE_COMPUTE_AND_SYNC mode, capturing synchronization events for deeper debugging.

View traces with:

xprof --port=8791 ./profile_data

Then open http://localhost:8791/ and use the Tools dropdown to navigate between Trace Viewer, HLO Op Profile, Roofline Model, and Memory Viewer.


3. The Optimization Taxonomy

Before diving into experiments, here’s a mental map of the optimization landscape. Every TPU kernel can be improved along these dimensions — the roofline analysis (Article 2) tells you which dimension to prioritize.

Tier 1: Tiling & Block Sizes (Highest Leverage)

Block sizes $B_Q$ and $B_{KV}$ control nearly everything:

Effect How BQ/BKV influence it
Arithmetic intensity Larger blocks → more FLOPs per VMEM load → higher intensity
VMEM usage Each buffer costs block_size × dim × dtype_size, doubled for double buffering
Pipeline depth Fewer blocks → shorter pipeline → larger bubble fraction
Register pressure Score matrix (BQ × BKV) lives in VREGs — growing this is the #1 cause of spill cliffs

VMEM budget formula for the Flash Attention kernel:

VMEM ≈ 2 × (BQ×d + BKV×d + BKV×d) × 2     double-buffered inputs (bf16)
     + 2 × BQ×d × 2                          double-buffered output (bf16)
     + BQ×d × 4 × 3                          scratch: acc, m, l (f32)

At BQ=1024, BKV=512, d=128:

Inputs:   2 × (1024×128 + 512×128 + 512×128) × 2 = 1,024 KB
Output:   2 × 1024×128 × 2                        = 512 KB
Scratch:  1024×128 × 4 × 3                        = 1,536 KB
Total:    3,072 KB (3 MB) — well within VMEM budget

Rule of thumb:

  • Increase BQ first — it improves Q-tile reuse with lower register risk (score matrix grows linearly with BQ).
  • Increase BKV cautiously — the score matrix grows linearly with BKV too, and VPU work (exp, reduce) per step also increases.
  • Watch for the “spill cliff” — a nonlinear latency jump when VREGs overflow.

Tier 2: Algorithmic Changes

  • Causal masking + block skipping: In autoregressive attention, the upper-triangular blocks are all masked to $-\infty$. For block $(i, j)$ where $i \cdot B_Q > (j+1) \cdot B_{KV}$, the entire block is fully masked — skip the matmul entirely. This cuts effective compute by ~50%.
  • Reduction axis: Row-wise reductions on a (BQ, BKV) matrix reduce over the last dimension (axis=1), which is the slowest path on TPU. Transposing the score matrix to (BKV, BQ) and reducing over axis=0 can be faster.

Tier 3: Precision Engineering

  • MXU matmuls: bf16 inputs × bf16 weights → f32 accumulation. This is the MXU’s native mode on TPU v5e — don’t upcast to f32.
  • Softmax intermediates: Keep in f32 for numerical stability. The exp and accumulation math must be f32 to avoid overflow.
  • P casting: The softmax output P is in f32. Casting to bf16 before the PV matmul halves bandwidth but discards precision.

Tier 4: Memory Layout & Pre-computation

  • Pre-scale Q: Move the / sqrt(d) scaling outside the kernel — it’s invariant across KV steps.
  • Compact scratch: Store per-row scalars (m, l) as (BQ,) instead of (BQ, HEAD_DIM).
  • Buffer sharing: Reuse scratch between non-overlapping lifetimes.

Tier 5: Compiler Hints

  • dimension_semantics: ("parallel", "arbitrary") tells the Mosaic compiler that the Q-block dimension has no cross-iteration dependencies and can be split across cores on TPU v4/v5p (which have 2 TensorCores per chip). TPU v5e has 1 TensorCore, so this doesn’t help on v5e, but the annotation is correct and future-proof.
  • Grid ordering: Which dimension is outermost affects locality. In attention, Q is outer (parallel) and KV is inner (reduction).

4. Walking Through the Experiments

This is the core of the article — a series of real ratchet iterations run autonomously by Gemini. Each experiment follows the same structure: profiling signal → hypothesis → code change → result → interpretation.

🤖 Autonomous run note: The results below are from actual ./ratchet.sh runs by the AI agent.


Experiment 0: Baseline (The Starting Point)

Starting point: The naive kernel with default parameters (BQ=128, BKV=256). This is what Gemini started with — no optimizations yet.

Result (commit 75f2094):

RESULT: score=1515.55  latency=2.16ms  util=32.68%  tflops=64.37
REGIME: COMPUTE-BOUND  BW=8.1%

Interpretation: A solid starting point at 32.68% utilization. The kernel is compute-bound (not memory-bound!), which means the naive tiling was already decent. But there’s room to grow.


Experiment 1: Scratch Buffer Optimization

AI’s profiling signal: The default kernel has m and l scratch buffers at (BQ, HEAD_DIM) = (128, 128) × f32 = 64 KB each. The actual information is just (128,) × f32 = 1 KB — 128× waste.

AI’s hypothesis: “Changing m and l scratch from (128, 128) to (128,) shaped scratch reduces VMEM usage by ~128 KB and eliminates redundant write traffic.”

AI’s change:

# scratch_shapes:
-    pltpu.VMEM((bq, d), jnp.float32),  # m (naive)
-    pltpu.VMEM((bq, d), jnp.float32),  # l (naive)
+    pltpu.VMEM((bq,), jnp.float32),     # m (compact)
+    pltpu.VMEM((bq,), jnp.float32),     # l (compact)

Result (commit 0464711):

RESULT: score=1751.81  latency=2.01ms  util=35.13%  tflops=69.21
REGIME: COMPUTE-BOUND  BW=8.7%

Decision: [ACCEPT]. An improvement.


Experiment 2: Increase BQ (128 → 2048)

AI’s profiling signal: After Experiment 1, kernel is compute-bound but util is only 35%. VMEM is well within budget (4%). Score matrix at BQ=128, BKV=256 is just 128KB — very low register pressure.

AI’s hypothesis: “BQ=128 creates 128 Q-blocks, forcing 128 re-reads of K and V. With VMEM at 4% and ample headroom, increasing BQ to 2048 dramatically reduces Q-blocks, cutting K/V re-reads by 16×.”

AI’s change:

-BQ = 128
+BQ = 2048

Result (commit f96d569):

RESULT: score=2617.62  latency=1.64ms  util=42.94%  tflops=84.60
REGIME: COMPUTE-BOUND  BW=5.6%

Decision: [ACCEPT]. Score improved from 1751 to 2617 — 1.5× gain. Latency dropped to 1.64ms. More work per tile is helping.


Experiment 3: Increase BKV (256 → 1024)

AI’s profiling signal: After BQ=2048, util is still moderate at 42.94%. Each KV step processes a (2048, 256) score tile.

AI’s hypothesis: “Increasing BKV to 1024 doubles work per KV step, reducing iterations from 64 to 16. Score matrix becomes 2048 × 1024 × 4 = 8 MB — large but still within VREG capacity.”

AI’s change:

-BKV = 256
+BKV = 1024

Result (commit 9f3934e):

RESULT: score=7142.89  latency=0.99ms  util=70.94%  tflops=139.74
REGIME: COMPUTE-BOUND  BW=9.3%

Decision: [ACCEPT]. Massive jump — 2.7× score improvement! util jumped to 70.94%, tflops doubled. This is the sweet spot.


Experiment 4: Causal Masking + Block Skipping

AI’s profiling signal: After BQ=2048, BKV=1024, kernel is compute-bound at 70.94% util. For autoregressive attention, half the attention scores are masked to −∞.

AI’s hypothesis: “Adding causal masking and skipping fully-masked blocks should cut effective compute by ~50%. For block (i, j) where i × BQ > (j+1) × BKV, skip the entire matmul.”

AI’s change:

# In kernel body:
kv_step = pl.program_id(1)
q_step = pl.program_id(0)
is_fully_masked = q_step * bq > (kv_step + 1) * bkv

@pl.when(~is_fully_masked)
def attention_kernel(...):
    # ... existing kernel logic ...

Result (commit 078febf):

RESULT: score=14033.31  latency=0.71ms  util=99.43%  tflops=195.87
REGIME: COMPUTE-BOUND  BW=24.6%

Decision: [ACCEPT]. This is the finish line — 99.43% utilization, hitting 195.87 TFLOP/s! That’s 99% of the theoretical 197 TFLOP/s peak on TPU v5e.


Results Summary Table

# Experiment BQ BKV Score Latency (ms) Util% TFLOP/s Regime Decision
0 Baseline 128 256 1515.55 2.16 32.68% 64.37 COMPUTE-BOUND
1 Scratch compact 128 256 1751.81 2.01 35.13% 69.21 COMPUTE-BOUND ✅ Accept
2 BQ → 2048 2048 256 2617.62 1.64 42.94% 84.60 COMPUTE-BOUND ✅ Accept
3 BKV → 1024 2048 1024 7142.89 0.99 70.94% 139.74 COMPUTE-BOUND ✅ Accept
4 Causal mask + skip 2048 1024 14033.31 0.71 99.43% 195.87 COMPUTE-BOUND ✅ Accept

All values from real ./ratchet.sh runs — no placeholders.


5. Interpreting the Journey

5.1 The Optimization Trajectory

Across the experiments, you’ll notice a consistent pattern:

Phase 1: Tiling experiments (Experiments 1–3). Block size increases produce the largest improvements — score jumps from 1515 to 7143.

Phase 2: Algorithmic wins (Experiment 4). Causal masking is a different kind of optimization — it doesn’t make each step faster, it eliminates entire computation. The result: ~2× score jump and 99.43% utilization!

Phase 3: Near the ceiling. At 99.43% util and 195.87 TFLOP/s, we’re hitting 99.4% of the theoretical peak (197 TFLOP/s). The remaining 0.6% is fundamental overhead: pipeline bubbles, VPU softmax work, DMA setup.

5.2 The Theoretical Ceiling

For causal attention at seq=16384, d=128:

Effective FLOPs ≈ 138.8 GFLOP / 2 = 69.4 GFLOP  (lower triangle only)
MXU peak    = 197 TFLOP/s
Min latency = 69.4G / 197T = 0.35 ms

Our result: 0.71ms at 99.43% util. That’s essentially at the ceiling — we’ve squeezed nearly all performance out of the hardware.


6. Generalizing Beyond Flash Attention

6.1 Swapping the Kernel

The Three-File Contract is deliberately kernel-agnostic. To optimize a different operator:

# kernel.py — new kernel for Fused LayerNorm
KERNEL_NAME = "Fused LayerNorm"
PROBLEM_DESCRIPTION = "LayerNorm: (x - mean) / sqrt(var + eps) * gamma + beta"

def kernel_fn(x, gamma, beta): ...
def reference_fn(x, gamma, beta): ...
def generate_inputs(key): ...
def compute_flops(): ...
def get_tiling_meta(): ...

prepare.py, ratchet.sh, and results.tsv require zero changes. The roofline analysis, VMEM budget, HLO dump, and XProf trace all work automatically — prepare.py reads get_tiling_meta() to adapt.

6.2 The Same Optimization Hierarchy Applies

Different kernel types have different bottleneck profiles, but the optimization taxonomy from §3 applies universally:

Kernel Type Typical Bottleneck First Optimization
Matmul-heavy (Attention, MLP) Memory-bound at small blocks → compute-bound at large blocks Block sizes (Tier 1)
Reduction-heavy (LayerNorm, Softmax, GroupNorm) Last-dimension reduction speed Axis transpose trick (Tier 2)
Streaming/elementwise (RoPE, mask, activation) Memory bandwidth Fusion into adjacent matmul (Tier 5)
Sparse operations (sparse attention, sparse matmul) Irregular access patterns Block skipping (Tier 2)

6.3 Extending to Multi-Head and Multi-Batch

For production transformers, attention has batch and head dimensions:

# Shape: (batch, num_heads, seq_len, head_dim)

The extension is straightforward:

  • Add batch and head dimensions to the grid: grid=(batch, num_heads, num_q_blocks, num_kv_blocks)
  • Mark them as "parallel" in dimension_semantics — they have no cross-iteration dependencies
  • On multi-core TPUs (v4/v5p with 2 TensorCores), the parallel dimensions can be split across cores for near-2× throughput

The kernel body itself doesn’t change — it still processes one (BQ, BKV) tile pair at a time.

The ratchet loop’s score function — util × 100 / latency — is a fitness metric. We’ve run the loop autonomously with an AI agent (Gemini), and it works — the agent reads the roofline, forms hypotheses, runs experiments, and commits wins.

This makes the Ratchet Loop a natural target for automated search:

  • Grid search: Try all (BQ, BKV) combinations that fit in VMEM. Run ./ratchet.sh for each. Keep the best.
  • Bayesian optimization: Use the results.tsv history to model the score surface and pick the next parameters to try.
  • LLM-driven agent: The approach used in this article — the AI reads profiling output, generates hypotheses, and manages the experiment loop.

7. Conclusion

The Three-Article Journey

  • Article 1: Built the kernel. We started with the O(seq²) memory problem of naive attention, derived the online softmax trick from first principles, and wrote a Pallas kernel that fuses matmul, softmax, and accumulation into a single kernel launch. We ran three fusion experiments and measured their effects.

  • Article 2: Built the measurement vocabulary. We learned to read the roofline model as a quantitative decision framework, interpret HLO dumps for compiler-inserted copies, and use XProf traces to diagnose pipeline bubbles and DMA overlap. We established a systematic profiling protocol.

  • Article 3: Built the optimization loop. The Ratchet Loop — a git-backed, hypothesis-driven, isolated experiment cycle — uses every metric from Article 2 to systematically drive the kernel from memory-bound baseline toward the hardware ceiling.

The Core Insight

Hardware awareness without empirical feedback is intuition. Empirical feedback without isolation is noise.

The Ratchet Loop combines both: hardware-aware, measurement-driven, reproducible optimization as a first-class methodology.

Every experiment follows the same cycle:

[\text{Profile} \to \text{Hypothesize} \to \text{Mutate (one change)} \to \text{Measure} \to \text{Decide (commit or revert)} \to \text{Repeat}]

The kernel can only improve. The history is always recoverable. The results are always attributable.

This isn’t specific to attention, or to Pallas, or even to TPUs. It’s a general methodology for any optimization problem where:

  1. The objective function is noisy (hardware performance varies)
  2. The search space is large (many knobs, complex interactions)
  3. Regressions are costly (production kernels can’t get slower)
  4. Attribution matters (you need to understand why something improved)

What Comes Next

We’ve now run the Ratchet Loop autonomously using an AI agent (Gemini). The human designs the framework and interprets results; the agent:

  • Reads profiling output from prepare.py
  • Forms testable hypotheses based on roofline analysis
  • Makes one isolated change to kernel.py
  • Runs ./ratchet.sh to measure the effect
  • Commits wins, reverts regressions automatically

The standard interface defines a clean action space (kernel parameters), results.tsv provides a reward signal (the score), ratchet.sh provides a safety net (automatic revert), and prepare.py provides observations (roofline, HLO, XProf).

Together, these form the foundation for autonomous kernel optimization — using RL agents, evolutionary search, or LLM-driven mutation to explore the optimization landscape without human intervention.


Glossary

Term Definition
Ratchet Loop A git-backed optimization cycle that allows only improvements — commits wins, reverts regressions.
Three-File Contract Architecture separating the mutable kernel (kernel.py) from the immutable judge (prepare.py) and the append-only ledger (results.tsv).
Isolation Principle One change per experiment. Enables attribution of performance changes to specific mutations.
Spill Cliff A nonlinear latency jump when VREGs overflow and the compiler silently spills to VMEM.
Score (Utilization% × 100) / Latency_ms — combined fitness metric rewarding both efficiency and speed.
Ridge Point Arithmetic intensity threshold (~240 FLOPs/byte on v5e) separating memory-bound from compute-bound regimes.
Gap Actual_latency − max(Compute_only, Memory_only). Measures unexplained overhead.
Standard Interface The seven-symbol contract (kernel_fn, reference_fn, generate_inputs, etc.) that decouples the kernel from the harness.
Causal Masking Masking future tokens in autoregressive attention. Enables block-level skipping for ~50% compute reduction.

References

  1. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS.
  2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR.
  3. Williams, S., Waterman, A., & Patterson, D. (2009). Roofline: An Insightful Visual Performance Model for Multicore Architectures. Communications of the ACM, 52(4).
  4. Jouppi, N.P., et al. (2023). TPU v4: An Optically Reconfigurable Supercomputer for Machine Learning. ISCA.
  5. JAX Team. Pallas: A JAX Kernel Language. https://jax.readthedocs.io/en/latest/pallas/
  6. Google Cloud. Cloud TPU System Architecture. https://cloud.google.com/tpu/docs/system-architecture-tpu-vm
  7. Google. XProf: Performance Profiling Framework. https://github.com/openxla/xprof