AxiomLogicaSearch

How FlashAttention works under the hood: IO-aware exact attention and tiling for long sequences

FlashAttention keeps attention exact while reducing HBM traffic by tiling Q/K/V into SRAM and recomputing rather than materializing the N×N attention matrix — yielding linear-memory behavior and major wall-clock gains, but only when the GPU memory hierarchy and tile sizes are exploited correctly.


What problem FlashAttention solves in transformer attention

Bottom Line: FlashAttention is an IO-aware exact attention algorithm. It produces the same outputs as standard attention, up to floating-point rounding, while dramatically cutting the number of memory reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. The speedup comes from reducing data movement, not from reducing arithmetic, and the memory savings come from never materializing the full N×N attention matrix.

Standard scaled dot-product attention computes $\text{softmax}(QK^T / \sqrt{d})V$. That formula is deceptively cheap in FLOPs relative to other transformer components, but its memory access pattern is brutal: writing the N×N score matrix to HBM, reading it back for the softmax, and reading the result again for the weighted sum creates memory traffic that scales quadratically with sequence length. For a single 4096-token sequence with 64 attention heads, one layer's score matrix holds 4096² × 64 ≈ 1.07 billion entries, which is 2 GiB in FP16, and the softmax output is a second tensor of the same size.
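The arithmetic behind that estimate is easy to check. This is a minimal sketch that assumes FP16 storage (2 bytes per element), batch size 1, a dense layout, and a single layer; `score_matrix_bytes` is a hypothetical helper, not part of any library.

```python
def score_matrix_bytes(seq_len: int, num_heads: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize one layer's N x N score matrix for every head.

    Assumes batch size 1 and a dense (unmasked, unpacked) layout.
    """
    return seq_len * seq_len * num_heads * bytes_per_elem


GIB = 1024 ** 3

scores = score_matrix_bytes(4096, 64)   # pre-softmax scores S
probs = score_matrix_bytes(4096, 64)    # softmax output P has the same shape
print(f"S: {scores / GIB:.1f} GiB, S + P: {(scores + probs) / GIB:.1f} GiB")
# prints "S: 2.0 GiB, S + P: 4.0 GiB"
```

Doubling the sequence length to 8192 multiplies both figures by four, which is the quadratic growth the paragraph describes.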

As Tri Dao et al. state in the original paper: "We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM." The reported wall-clock results follow directly from that design: up to 7.6× speedup on the GPT-2 attention computation, 3× faster end-to-end GPT-2 training (sequence length 1K), and 2.4× speedup on Long Range Arena (sequence lengths 1K to 4K), all with exact outputs.

How the GPU memory hierarchy shapes attention speed

FlashAttention is faster than standard attention because it keeps its working data resident in fast on-chip SRAM instead of repeatedly round-tripping intermediate results through HBM. The important distinction is not a specific bandwidth number but the fact that attention kernels become slow when intermediate tensors move back and forth between GPU memory levels. Reducing those transfers is the core of the algorithm.

The two access patterns are fundamentally different:

flowchart LR
    subgraph Naive["Standard Attention (HBM-bound)"]
        direction TB
        Q1[Q in HBM] -->|read| SM1[SM: compute QKᵀ]
        K1[K in HBM] -->|read| SM1
        SM1 -->|write N×N| S1[Score matrix in HBM]
        S1 -->|read| SM2[SM: softmax]
        SM2 -->|write N×N| P1[Prob matrix in HBM]
        P1 -->|read| SM3[SM: PV multiply]
        V1[V in HBM] -->|read| SM3
        SM3 -->|write| O1[Output in HBM]
    end

    subgraph Flash["FlashAttention (SRAM-resident tiles)"]
        direction TB
        Qb[Q block from HBM] -->|load once| SRAM[On-chip SRAM tile]
        Kb[K block from HBM] -->|load once per Q block| SRAM
        Vb[V block from HBM] -->|load once per Q block| SRAM
        SRAM -->|accumulate local output| Acc[Running accumulator in SRAM]
        Acc -->|write once per Q block| O2[Output in HBM]
    end

Standard attention writes the N×N score matrix to HBM, reads it back for the softmax, writes the N×N probability matrix, then reads that for the V multiplication: four full HBM traversals of the largest tensors in the computation. FlashAttention tiles the computation so that each Q block moves from HBM to on-chip memory once, each K and V block moves once per Q block, and the output for a Q block accumulates on-chip (in SRAM and registers) without touching HBM until that block is finished.

Why the N×N attention matrix becomes the bottleneck

Many FlashAttention explanations miss this point: engineers focus on FLOP counts because that is how we were taught to profile compute workloads. But for attention over sequences longer than a few hundred tokens, the naive computation is already bandwidth-bound, not compute-bound. The arithmetic intensity (FLOPs per byte of memory traffic) of naive attention, especially of its softmax and masking passes over the N×N matrix, is low enough that the GPU's arithmetic units spend much of their time idle waiting for data.

Pro Tip: Profile attention with torch.profiler to find which kernels dominate step time, then inspect those kernels with a tool such as NVIDIA Nsight Compute, which reports achieved memory throughput and SM (compute) throughput as percentages of peak. If memory throughput is near its ceiling while SM throughput is low, you are in the IO-bound regime and FlashAttention's tiling directly addresses your bottleneck. If compute throughput is already high, you may be compute-bound, and the relative gain from FlashAttention will be smaller.

FlashAttention's theoretical contribution includes a proof that its HBM access complexity is $O(N^2 d^2 / M)$, where $M$ is the SRAM size and $d$ is the head dimension, compared with $O(Nd + N^2)$ for standard attention. Both counts are quadratic in $N$; what separates them is the factor $d^2/M$. For typical values ($d$ of 64 to 128, $M$ around 100 KB), $d^2$ is many times smaller than $M$, so FlashAttention performs many times fewer HBM accesses. In engineering terms: doubling the sequence length roughly quadruples HBM traffic for both kernels, but FlashAttention's traffic starts from a much smaller base because each tile loaded into SRAM is reused across many score computations instead of being written out and read back.
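A minimal sketch that plugs numbers into the two asymptotic counts, using the bound as stated in the original paper, $N^2 d^2 / M$ for FlashAttention versus $Nd + N^2$ for standard attention. Constant factors are dropped, $M$ is counted in elements, and $d = 64$, $M = 100{,}000$ are illustrative assumptions rather than a specific GPU; the helper names are hypothetical.

```python
def standard_hbm_accesses(n: int, d: int) -> float:
    """Asymptotic HBM access count for standard attention: N*d + N^2 (constants dropped)."""
    return n * d + n * n


def flash_hbm_accesses(n: int, d: int, m: int) -> float:
    """Asymptotic HBM access count for FlashAttention: N^2 * d^2 / M (constants dropped).

    m is the on-chip SRAM size measured in elements (an assumption of this sketch).
    """
    return n * n * d * d / m


d, m = 64, 100_000  # illustrative head dimension and SRAM size, not a real GPU spec
for n in (1024, 2048, 4096, 8192):
    std = standard_hbm_accesses(n, d)
    fla = flash_hbm_accesses(n, d, m)
    print(f"N={n:5d}  standard={std:>12,.0f}  flash={fla:>12,.0f}  ratio={std / fla:5.1f}")
```

Because both counts are quadratic in $N$, the ratio settles toward $M/d^2$ (about 24 for these assumed values) instead of growing without bound; the benefit comes from the smaller constant, not a smaller exponent.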

What changes when Q, K, and V stay on-chip longer

When a tile of Q, K, and V fits entirely in SRAM, the SM can compute all attention scores for that Q block against all K blocks without writing intermediate results to HBM. The benefits compound: each K and V block, once loaded into SRAM, contributes to the running accumulator for every Q row in the tile before being evicted. The number of HBM reads for Q, K, and V each drops from once-per-operation-step to once-per-tile.

Watch Out: Tile sizes must fit within the available SRAM budget of the target GPU. The FlashAttention paper notes that on a T4 GPU, whose SRAM is smaller than the A100's, block sizes must be reduced, and the resulting speedup is accordingly lower. Choosing tiles larger than the on-chip budget defeats the mechanism: depending on the implementation, the kernel either fails to launch or spills data to slower off-chip memory, and the extra traffic can erase the benefit, leaving you no faster (or even slower) than standard attention once the streaming-softmax bookkeeping is counted.

The concrete implication for deployment: when you port a FlashAttention kernel to a new GPU family, the tile size parameters must be retuned against that GPU's SRAM capacity per SM.
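As a starting point for that retuning, the original paper's Algorithm 1 derives block sizes from the SRAM size $M$ and head dimension $d$: $B_c = \lceil M / 4d \rceil$ and $B_r = \min(\lceil M / 4d \rceil, d)$. The sketch below evaluates that rule. Treating $M$ as a count of 2-byte FP16 elements is an assumption, the 96 KiB budget is a hypothetical smaller GPU rather than a specific product, and production kernels use hand-tuned tile shapes instead of applying this formula directly.

```python
def paper_block_sizes(sram_elems: int, head_dim: int) -> tuple[int, int]:
    """Block sizes (B_r, B_c) following the FlashAttention paper's Algorithm 1.

    B_c = ceil(M / 4d) K/V rows per block; B_r = min(B_c, d) Q rows per block.
    """
    b_c = -(-sram_elems // (4 * head_dim))  # integer ceiling division
    b_r = min(b_c, head_dim)
    return b_r, b_c


# Assumed SRAM budgets, converted to FP16 element counts (2 bytes each).
for label, sram_bytes in (("192 KiB", 192 * 1024), ("96 KiB", 96 * 1024)):
    for d in (64, 128):
        print(label, f"d={d}", paper_block_sizes(sram_bytes // 2, d))
```

Halving the budget at $d = 128$ shrinks both dimensions from (128, 192) to (96, 96), which is the mechanism behind the smaller T4 speedup described above.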

Tiling Q, K, and V without materializing attention

FlashAttention changes the implementation, not the attention function. Given the same Q, K, V tensors, it produces the same output tensor as $\text{softmax}(QK^T/\sqrt{d})V$, up to floating-point rounding. What changes is the order in which the computation is executed and where intermediate results live.

flowchart TD
    HBM_Q[Q in HBM] -->|load tile Q_i| SRAM_Q[Q_i in SRAM]
    HBM_K[K in HBM] -->|load tile K_j| SRAM_K[K_j in SRAM]
    HBM_V[V in HBM] -->|load tile V_j| SRAM_V[V_j in SRAM]

    SRAM_Q --> Score[Compute S_ij = Q_i · K_jᵀ / √d in SRAM]
    SRAM_K --> Score
    Score --> Rescale[Update running max m_i, running sum ℓ_i]
    Rescale --> Accum[Rescale existing O_i, accumulate new P_ij · V_j]
    SRAM_V --> Accum

    Accum -->|next j block| HBM_K
    Accum -->|after all j blocks: write O_i| HBM_O[O_i in HBM]

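The outer loop iterates over blocks of Q (indexed by $i$); the inner loop iterates over all blocks of K and V (indexed by $j$). For each $i$, the SM loads $Q_i$ into SRAM once; for each $j$ it then loads $K_j$ and $V_j$, computes the local dot products, and updates a running output accumulator $O_i$ using the softmax rescaling described below. After the inner loop completes, the final $O_i$ is written to HBM exactly once. The $N \times N$ score matrix never exists as a contiguous tensor. This is the loop order used by FlashAttention-2; the original FlashAttention kernel nests the loops the other way (K and V blocks outside, Q blocks inside) and reads and writes partial $O_i$, $m_i$, and $\ell_i$ to HBM on each step, which changes constant factors and parallelism but not the asymptotic IO bound.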
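The loop structure can be sketched in plain Python. This is a readability sketch under stated assumptions, not a GPU kernel: Python lists stand in for SRAM tiles, `block_q` and `block_k` are hypothetical block sizes, and the accumulator uses the per-step normalized update given in the next subsection. It follows the Q-outer ordering described above and checks itself against a naive implementation that does build every score row.

```python
import math


def naive_attention(q, k, v):
    """Reference implementation: builds each full score row before the softmax."""
    d = len(q[0])
    out = []
    for qi in q:
        s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) / z for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_q=2, block_k=3):
    """Blockwise attention with a streaming softmax; no N x N matrix is ever built."""
    d, dv = len(q[0]), len(v[0])
    out = []
    for i0 in range(0, len(q), block_q):              # outer loop over Q blocks
        q_blk = q[i0:i0 + block_q]                    # stands in for loading Q_i into SRAM
        m = [-math.inf] * len(q_blk)                  # running max per query row
        l = [0.0] * len(q_blk)                        # running sum of exponentials per row
        o = [[0.0] * dv for _ in q_blk]               # normalized running output per row
        for j0 in range(0, len(k), block_k):          # inner loop over K/V blocks
            k_blk, v_blk = k[j0:j0 + block_k], v[j0:j0 + block_k]
            for r, qi in enumerate(q_blk):
                s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k_blk]
                m_new = max(m[r], max(s))
                p = [math.exp(x - m_new) for x in s]
                carried = math.exp(m[r] - m_new) * l[r]   # e^{m_old - m_new} * l_old
                l_new = carried + sum(p)
                o[r] = [
                    (carried * o[r][c] + sum(p[t] * v_blk[t][c] for t in range(len(p)))) / l_new
                    for c in range(dv)
                ]
                m[r], l[r] = m_new, l_new
        out.extend(o)                                 # stands in for writing O_i to HBM once
    return out


q = [[math.sin(i + c) for c in range(4)] for i in range(5)]
k = [[math.cos(i * c + 1) for c in range(4)] for i in range(7)]
v = [[float(i - c) for c in range(3)] for i in range(7)]
diff = max(
    abs(a - b)
    for ra, rb in zip(naive_attention(q, k, v), tiled_attention(q, k, v))
    for a, b in zip(ra, rb)
)
print(f"max abs difference vs. naive: {diff:.1e}")
```

Any choice of block sizes, including ones that do not divide the sequence length (5 queries and 7 keys here), gives the same output up to rounding. That is the exactness property; on a GPU only the memory traffic would differ.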

How blockwise computation preserves exact outputs

The mathematical challenge is that softmax is a global operation — to normalize row $i$ of the score matrix, you need the maximum and the sum of all scores in that row. In standard attention, you compute all scores for row $i$ first, then normalize. FlashAttention processes K blocks one at a time, so it cannot know the global row statistics until all K blocks are visited.

The solution is a streaming (online) softmax update. For a single query row, let $m$ be the running maximum score seen so far and $\ell$ the running sum of exponentials, starting from $m = -\infty$, $\ell = 0$, and $O = 0$. Before processing block $j$, the algorithm holds $(m_{\text{old}}, \ell_{\text{old}}, O_{\text{old}})$. After computing the local scores $s_j$ for the new K block:

$$ m_{\text{new}} = \max(m_{\text{old}},\; \max(s_j)) $$

$$ \ell_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}} \cdot \ell_{\text{old}} + \sum_k e^{s_{j,k} - m_{\text{new}}} $$

$$ O_{\text{new}} = \frac{\ell_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}}}{{\ell_{\text{new}}}} O_{\text{old}} + \frac{1}{\ell_{\text{new}}} \sum_k e^{s_{j,k} - m_{\text{new}}} v_{j,k} $$

When the inner loop over all K blocks is complete, $O_{\text{new}}$ matches the standard softmax-weighted sum up to floating-point rounding: the factor $e^{m_{\text{old}} - m_{\text{new}}}$ corrects the earlier partial statistics whenever a larger score appears. The result is also numerically stable, because every exponent is shifted by the running maximum and is therefore never positive, so no exponential can overflow.
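To make the stability claim concrete, the sketch below applies the three update equations to one query row, feeding scores in blocks. Scalar values $v_{j,k}$ stand in for the $d$-dimensional value vectors, an assumption made for brevity. The scores are deliberately around 1000: exponentiating them directly overflows a Python float, yet the streaming update only ever exponentiates non-positive numbers and agrees with a direct computation that shifts by the global maximum.

```python
import math


def streaming_row(score_blocks, value_blocks):
    """Apply the (m, l, O) update equations block by block for one query row."""
    m, l, o = -math.inf, 0.0, 0.0                   # state before any block is seen
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_new = max(m, max(s_blk))
        carried = math.exp(m - m_new) * l           # e^{m_old - m_new} * l_old
        e = [math.exp(s - m_new) for s in s_blk]    # every exponent here is <= 0
        l_new = carried + sum(e)
        o = (carried * o + sum(w * x for w, x in zip(e, v_blk))) / l_new
        m, l = m_new, l_new
    return o, m, l


scores = [[1000.0, 998.0], [1003.0, 1001.5, 999.0], [1002.0]]
values = [[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]
o, m, l = streaming_row(scores, values)

# Direct computation over the whole row, shifted by the global max to stay finite.
flat_s = [s for blk in scores for s in blk]
flat_v = [x for blk in values for x in blk]
g = max(flat_s)
w = [math.exp(s - g) for s in flat_s]
direct = sum(wi * xi for wi, xi in zip(w, flat_v)) / sum(w)

# Without any shift, the largest exponential is not representable as a float.
try:
    math.exp(g)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

print(f"streaming={o:.12f} direct={direct:.12f} unshifted exp overflows: {naive_overflowed}")
```

The final pair $(m, \ell)$ also yields the row's log-sum-exp, $m + \log \ell$, which is the statistic the backward pass reuses.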

Why softmax rescaling avoids the full score matrix

The running max $m$ and running sum $\ell$ are scalars per query row (stored as one entry per row for a block of query rows), not $N$-length arrays. Across all queries they cost $O(N)$ memory, whereas storing the full score matrix costs $O(N^2)$. The rescaling formula above updates the output accumulator $O_i$ in place, applying a correction factor each time a new maximum is discovered, so the partial output is always in a consistent normalized state.

The trade-off is explicit: this method performs somewhat more arithmetic than naive attention. Each time a query row visits a new K block, it computes one extra exponential, $e^{m_{\text{old}} - m_{\text{new}}}$, and rescales its $d$-dimensional accumulator. Counting roughly $d$ extra operations per row per block, for $N$ rows and $N/B_K$ blocks:

$$ \text{Extra FLOPs} \approx N \cdot \frac{N}{B_K} \cdot d = \frac{N^2 d}{B_K} \ll 4N^2 d \approx \text{FLOPs of the two attention matmuls} $$

where $B_K$ is the K-block size. That overhead is paid in cheap on-chip arithmetic, while the traffic it removes is $O(N^2)$ reads and writes to HBM. For sequences where $N$ is large enough that HBM bandwidth is the bottleneck, this trade-off pays off substantially. For very short sequences, where the whole computation fits on-chip trivially, standard and tiled attention are roughly equivalent.

Forward and backward passes in FlashAttention

Training introduces a second challenge beyond forward-pass speed: the backward pass through attention requires gradients with respect to Q, K, and V, which in standard backpropagation means storing the attention probability matrix $P = \text{softmax}(QK^T/\sqrt{d})$ from the forward pass. Storing $P$ costs $O(N^2)$ activation memory per head and layer. For a single 4096-token sequence with 32 layers and 16 heads, that is 4096² × 16 × 32 × 2 bytes = 16 GiB in FP16, which can easily become the dominant term in training memory consumption.

FlashAttention's backward pass avoids storing $P$ by recomputing it block by block with the same tiling scheme as the forward pass, using only the saved softmax statistics $(m, \ell)$, which are $O(N)$, together with $Q$, $K$, $V$, the output $O$, and the incoming gradient $dO$ in HBM.

Production Note: Recomputing activations in backward is a clear win when you are memory-bandwidth-bound during training — the common case for long-context workloads. But if your model is compute-bound (e.g., small sequence length, very wide layers, high batch size) and you are trying to maximize MFU, the recomputation adds arithmetic that costs throughput without a proportional memory benefit. Profile your training step with and without FlashAttention to confirm you are in the bandwidth-bound regime before assuming the backward recomputation helps net throughput.

The practical outcome: FlashAttention lets you train transformers with sequence lengths that would otherwise exhaust GPU memory entirely, not just run slowly. A model that required gradient checkpointing at every attention layer to fit in memory can often run without it when FlashAttention replaces the naive kernel, recovering the throughput cost of manual checkpointing.

What gets recomputed in backward and why

The backward pass for attention requires $\partial L / \partial Q$, $\partial L / \partial K$, and $\partial L / \partial V$. These gradient expressions involve $P$: specifically $dV = P^T dO$, $dP = dO\,V^T$, and, for each row $p$ of $P$, the softmax Jacobian $\operatorname{diag}(p) - pp^T$ that turns $dP$ into the gradient with respect to the scores. In standard backprop, $P$ is read from HBM where it was stored during the forward pass. In FlashAttention, the backward kernel loads the same Q, K, V tiles it used in the forward pass, recomputes the local attention scores and probabilities within SRAM, and immediately uses them to compute the local gradient contributions before evicting the tile.
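A minimal sketch of that per-row recomputation in plain Python, for one query row. It recomputes $p_j = e^{s_j - \text{lse}}$ from the saved log-sum-exp $\text{lse} = m + \log \ell$, forms $dp_j = dO \cdot v_j$, and applies the softmax Jacobian through the identity $\sum_j p_j\,dp_j = dO \cdot o$, which the FlashAttention papers use so the row sum comes from $O$ and $dO$ rather than from a stored $P$. Function and variable names are hypothetical, and the $1/\sqrt{d}$ scaling is assumed to be folded into the given scores.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def row_forward(s, v):
    """Forward for one query row: returns the output o and the log-sum-exp (saved stat)."""
    m = max(s)
    e = [math.exp(x - m) for x in s]
    l = sum(e)
    o = [sum(e[j] * v[j][c] for j in range(len(s))) / l for c in range(len(v[0]))]
    return o, m + math.log(l)


def row_backward(s, v, o, lse, d_o):
    """Recompute p from (s, lse) instead of loading a stored P, then form dS and dV terms."""
    p = [math.exp(x - lse) for x in s]                    # recomputed probabilities
    dp = [dot(d_o, vj) for vj in v]                       # this row of dP = dO V^T
    delta = dot(d_o, o)                                   # equals sum_j p_j * dp_j
    ds = [pj * (dpj - delta) for pj, dpj in zip(p, dp)]   # (diag(p) - p p^T) dp
    dv = [[pj * g for g in d_o] for pj in p]              # this row's contribution to dV
    return p, ds, dv, delta


s = [0.5, -1.0, 2.0, 0.3]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
d_o = [0.3, -0.7]
o, lse = row_forward(s, v)
p, ds, dv, delta = row_backward(s, v, o, lse, d_o)
print([round(x, 6) for x in ds])
```

From here, `ds[j]` times $K_j$ accumulates into this row's $dQ$, and `ds[j]` times the query row contributes to $dK_j$, again up to the omitted $1/\sqrt{d}$ factor.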

Pro Tip: This behavior is structurally similar to activation checkpointing (gradient checkpointing in PyTorch), but targeted specifically at the attention matrix. Standard gradient checkpointing recomputes entire layer forward passes at segment boundaries; FlashAttention's recomputation is tighter — it only recomputes attention scores within each tile, which is a smaller recomputation unit with better arithmetic density. In practice this means you often get most of the memory savings of aggressive checkpointing without the full throughput penalty, because the recomputed region fits entirely in SRAM and avoids HBM traffic.

The saved state from the forward pass is minimal: the output $O$ (written to HBM anyway) and the log-sum-exp statistic $m + \log \ell$ per query position, which is $O(N)$ in total.

Where the memory savings actually come from

The memory reduction has two sources. First, the forward pass never materializes the $N \times N$ attention matrix in HBM: its peak HBM allocation for attention is $O(N \cdot d)$ for Q, K, V, and O, plus $O(N)$ for the softmax statistics. Second, the backward pass eliminates the stored $P$ matrix that standard backprop requires.

The key point is the IO savings across the full forward-backward cycle: by avoiding the intermediate attention matrix and recomputing local probabilities as needed, FlashAttention keeps the working set small enough to fit in on-chip memory and reduces transfers between HBM and SRAM.

Where FlashAttention wins and where it can lose

The speedups reported in the original FlashAttention paper are concrete and hardware-specific:

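| Benchmark | Baseline | FlashAttention result |
|---|---|---|
| GPT-2 attention computation | standard attention | up to 7.6× faster |
| GPT-2 training, end-to-end (sequence length 1K) | standard attention | 3× faster |
| Long Range Arena (sequence lengths 1K to 4K) | standard attention | 2.4× faster |
| Tile efficiency, T4 vs. A100 | A100: larger SRAM allows larger tiles | T4: smaller tiles, smaller speedup |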

The 7.6× attention speedup and the 3× training speedup differ because attention is one component among many in a full training step: other costs (linear layers, the optimizer, communication) dilute the kernel-level gain. The 2.4× Long Range Arena figure is an end-to-end speedup at sequence lengths of 1K to 4K, where the quadratic score matrix already dominates memory traffic.
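Amdahl's law makes the dilution explicit under a simplified model in which only the attention share of a step gets faster. In the sketch below, `attention_fraction` is a hypothetical share of step time spent in attention, not a value measured in the paper, and the 7.6× kernel speedup is the paper's attention figure reused as an input.

```python
def end_to_end_speedup(attention_fraction: float, kernel_speedup: float) -> float:
    """Amdahl's law: only the attention share of the step is accelerated."""
    if not 0.0 <= attention_fraction <= 1.0:
        raise ValueError("attention_fraction must be in [0, 1]")
    rest = 1.0 - attention_fraction
    return 1.0 / (rest + attention_fraction / kernel_speedup)


for f in (0.25, 0.5, 0.75, 1.0):
    print(f"attention share {f:.0%}: end-to-end {end_to_end_speedup(f, 7.6):.2f}x")
```

Only when attention is essentially the whole step does the end-to-end figure approach the kernel figure; real training steps also differ in implementation details beyond attention, so this is an intuition aid rather than a model of the paper's measurements.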

Why longer sequences benefit the most

The IO savings from FlashAttention grow with sequence length because the ratio of eliminated HBM traffic to total HBM traffic increases as $N$ grows. For short sequences, the score matrix is small enough that the HBM round-trips are cheap in absolute terms. For long sequences, the score matrix becomes the dominant term and tiling delivers a proportionally larger benefit.

The Long Range Arena benchmark specifically targets long-sequence tasks and shows 2.4× speedup precisely because those tasks expose the quadratic memory bottleneck most severely.

Watch Out: Sequence length is not the only variable. Head dimension $d$ and block size interact: a large $d$ means each tile covers fewer Q positions for the same SRAM budget, reducing the reuse benefit per HBM load. Similarly, for very small sequences (under ~512 tokens), kernel launch overhead and CUDA synchronization can blunt or eliminate the gain. Always benchmark on your actual sequence distribution rather than assuming uniform gains across all input shapes.

When tiling and recomputation are not enough

FlashAttention-2 achieves only about 35% utilization on the H100 GPU. The reason is that the H100 introduced hardware capabilities, notably asynchronous execution and data movement, that FlashAttention-2's kernel design does not exploit. When attention is bound by HBM bandwidth, tiling is sufficient. But tiling already removes most of that traffic, and on the H100 Tensor Core throughput grew faster than both HBM bandwidth and the throughput of non-matmul operations such as the exponentials in softmax. A kernel that alternates synchronously between loading tiles, running matmuls, and computing softmax therefore leaves the Tensor Cores idle much of the time, and synchronization and warp scheduling become the limiting factor.

Pro Tip: Distinguish bandwidth-bound from compute-bound attention with a roofline model for your target GPU. Naive attention, whose softmax and masking passes stream the N×N matrix through HBM with only a few FLOPs per byte, typically sits well left of the ridge point (the arithmetic intensity at which peak compute and peak bandwidth balance), so it is bandwidth-bound and FlashAttention's IO reduction translates directly into speedup. The H100's higher compute-to-bandwidth ratio moves the ridge point to the right, yet tiled FlashAttention kernels already reuse each loaded tile across many FLOPs; on that hardware the remaining gap is execution efficiency (overlapping loads, softmax, and matmuls), which FlashAttention-3's kernel optimizations target more than pure tile sizing.
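A roofline estimate needs only two hardware numbers. The sketch below is a minimal model under stated assumptions: `PEAK` and `BW` are placeholders to be replaced with your GPU's datasheet values, and the softmax-pass intensity assumes roughly five floating-point operations per score element (max, subtract, exponentiate, sum, divide) with one 2-byte read and one 2-byte write per element.

```python
def ridge_point(peak_flops: float, bandwidth_bytes_per_s: float) -> float:
    """Arithmetic intensity (FLOP/byte) at which a kernel stops being bandwidth-bound."""
    return peak_flops / bandwidth_bytes_per_s


def attainable_flops(intensity: float, peak_flops: float, bandwidth_bytes_per_s: float) -> float:
    """Roofline: throughput is capped by peak compute or by bandwidth times intensity."""
    return min(peak_flops, intensity * bandwidth_bytes_per_s)


# Placeholder hardware numbers (assumptions, not a specific GPU's datasheet).
PEAK = 300e12   # FLOP/s
BW = 2e12       # bytes/s

softmax_intensity = 5 / (2 + 2)   # ~5 FLOPs per element over a 2-byte read + 2-byte write
print(f"ridge point: {ridge_point(PEAK, BW):.0f} FLOP/byte")
print(f"softmax pass: {softmax_intensity:.2f} FLOP/byte -> "
      f"{attainable_flops(softmax_intensity, PEAK, BW) / 1e12:.1f} TFLOP/s attainable")
```

With these placeholders the ridge point is 150 FLOP/byte, and a standalone softmax pass at about 1.25 FLOP/byte can use under 1% of peak compute, which is why fusing it into a tiled kernel matters more than speeding up its arithmetic.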

FlashAttention-3 addresses this by targeting H100 specifically and achieves up to 75% GPU utilization — more than double FlashAttention-2's 35% on the same hardware. The IO-aware tiling remains, but the kernel strategy changes substantially to exploit H100's execution model.

How FlashAttention changed later attention kernels

"FlashAttention (and FlashAttention-2) pioneered an approach to speed up attention on GPUs by minimizing memory reads/writes, and is now used by most libraries to accelerate Transformer training and inference." PyTorch's torch.nn.functional.scaled_dot_product_attention dispatches to FlashAttention kernels in supported configurations. The attention kernel went from a research artifact to a default in under two years.

The version lineage reflects a separation between the core algorithmic idea — IO-aware tiling with streaming softmax — and the CUDA kernel engineering required to maximize utilization on each GPU generation:

| Attribute | FlashAttention (v1) | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|
| Core idea | IO-aware tiling, streaming softmax | Same | Same |
| Target hardware | A100 (T4 also evaluated) | A100, H100 (partial) | H100 (primary) |
| H100 GPU utilization | not reported | ~35% | Up to 75% |
| Key kernel change | Tiling + online softmax | Improved parallelism, reduced non-matmul FLOPs, Q-block outer loop | Async pipelining and asynchronous data movement |
| PyTorch integration | flash-attn package; PyTorch 2.0 SDPA flash backend | scaled_dot_product_attention dispatch (PyTorch 2.2+) | In progress via upstream SDPA |
| Backward recomputation | Yes | Yes | Yes |

What stayed the same across versions

The IO-aware insight — minimize HBM reads and writes by keeping Q, K, V tiles in SRAM and streaming outputs — is the invariant across all three releases. The streaming softmax with running max and sum updates is identical in principle across versions. The backward recomputation strategy (avoid storing $P$, recompute from saved statistics) persists unchanged.

Pro Tip: When reading kernel source across FlashAttention versions, the block-size parameters and the logic that selects them are where the IO-aware insight lives concretely. If you are adapting a FlashAttention kernel to a new accelerator, start there: the streaming softmax logic is stable across versions, but block sizes and memory layout must be retuned for the SRAM capacity and memory access latency of the target device.

The stability of the core idea across versions also means that the IO complexity proof from the original paper applies to all three versions at the algorithmic level. The differences between versions are engineering optimizations within that complexity bound, not changes to the bound itself.

What changed in kernel strategy and practical deployment

FlashAttention-3's kernel strategy diverges meaningfully from its predecessors because the H100 changed what the bottleneck is. Where FlashAttention and FlashAttention-2 primarily targeted HBM traffic reduction, FlashAttention-3 targets Tensor Core utilization on the H100. It uses asynchronous Tensor Core instructions and the Tensor Memory Accelerator (TMA) to overlap data movement with computation through warp specialization, interleaves blockwise matmuls with softmax so the two run concurrently, and adds FP8 support through block quantization, keeping the hardware busy in ways earlier kernels could not.

Watch Out: Do not assume that FlashAttention-3 on an H100 behaves identically to FlashAttention-2 on an A100 in terms of performance characteristics. The relative speedup over a naive baseline differs by GPU family, by precision (FP16 versus BF16 versus FP8), by sequence length, and by head dimension. Benchmarking FlashAttention-2 on A100 does not predict FlashAttention-3 behavior on H100 — the kernel design changes affect which workloads hit the theoretical maximum and which hit synchronization or launch overhead first. Always profile on the target hardware with the target precision before committing to a deployment configuration.

FlashAttention-3 is not a universal drop-in replacement without re-verification; its performance characteristics depend on CUDA version compatibility, driver support for the required primitives, and whether your inference or training framework has integrated the dispatch path.

FAQ on exact attention, tiling, and memory traffic

Is FlashAttention exact or approximate? FlashAttention is exact. The streaming softmax with rescaling described above produces outputs numerically equivalent to standard $\text{softmax}(QK^T/\sqrt{d})V$, up to the same floating-point rounding that any GPU attention kernel exhibits. There is no approximation of the attention scores, no sparsity, and no low-rank decomposition.

Why is FlashAttention faster than standard attention? For sequences longer than a few hundred tokens, standard attention is IO-bound, not compute-bound. FlashAttention eliminates most of the HBM traffic by keeping intermediate results on-chip in SRAM. The achievable speedup is bounded by how much of the wall-clock time was spent on HBM reads and writes, which for long sequences is the dominant term.

Does FlashAttention change the attention algorithm or just the implementation? It changes only the implementation: the execution order and memory layout. The mathematical function computed is identical to standard attention, which is why existing model weights and trained checkpoints work with FlashAttention without modification and produce the same outputs up to floating-point rounding.

How does FlashAttention handle long sequences? By tiling Q, K, V into blocks that fit in SRAM and using streaming softmax to accumulate exact outputs tile by tile, FlashAttention avoids allocating the $O(N^2)$ score matrix entirely. Memory usage for the attention computation scales as $O(N \cdot d)$, making sequences that would OOM with standard attention tractable with FlashAttention.

Where does FlashAttention not help? For short sequences (under ~512 tokens), kernel launch overhead and CUDA synchronization can eliminate the gain. For compute-bound workloads on H100-class hardware, FlashAttention-2's tiling alone is insufficient — which is the direct motivation for FlashAttention-3's kernel redesign.

Bottom Line: FlashAttention's performance comes entirely from reducing memory traffic between HBM and SRAM — not from changing the mathematical definition of attention, not from reducing FLOPs, and not from approximating outputs. The IO-aware tiling with streaming softmax is the mechanism; sequence length is the main lever that determines how large the benefit is.

Sources & References

Production Note: The arXiv papers and official kernel documentation for FlashAttention, FlashAttention-2, and FlashAttention-3 are the primary sources for the claims in this article. Secondary summaries and blog posts often omit the IO complexity analysis that is the core technical contribution. When evaluating competing claims about FlashAttention behavior, consult the original paper's Section 3 (algorithm and IO complexity analysis) and the official kernel release notes directly.


