At a glance: install and benchmark prerequisites
At a Glance: Time: 15–30 min with Ninja on a 64-core machine · Prereqs: Linux, CUDA toolkit, PyTorch 2.2+, packaging, psutil, ninja · Hardware: Ampere/Ada/Hopper for FlashAttention-2; H100/H800 required for FlashAttention-3 beta · RAM: 96 GB+ recommended for unconstrained builds; tune MAX_JOBS below that threshold
FlashAttention installation is gated by a specific intersection of GPU architecture, CUDA toolkit version, PyTorch version, and build toolchain. Getting any one of these wrong produces either a silent fallback to a different attention kernel — invalidating your benchmark — or a hard build failure. The official repository requires a CUDA toolkit or ROCm toolkit, PyTorch 2.2+, the packaging Python package, and Linux as the primary supported OS. FlashAttention-3 beta narrows the hardware requirement further: it targets Hopper-class GPUs (H100/H800) with CUDA 12.3 as the minimum and CUDA 12.8 recommended for peak performance.
The build toolchain matters more than most installation guides acknowledge. Without Ninja, building flash-attn on a 64-core machine takes approximately two hours. With Ninja, the same machine finishes in 3–5 minutes. On systems with less than 96 GB of RAM, an unconstrained parallel build can exhaust memory mid-compilation; MAX_JOBS must be lowered to prevent this. These are not edge cases — they are the two most common sources of avoidable pain in a flash-attn source build.
Prerequisites and compatibility matrix
The table below consolidates the version and architecture requirements before you run a single install command.
| Requirement | FlashAttention-2 (CUDA) | FlashAttention-3 beta |
|---|---|---|
| GPU architectures | Ampere, Ada, Hopper (A100, RTX 3090/4090, H100) | Hopper only (H100/H800) |
| CUDA minimum | Official repo requires a CUDA toolkit; match it to your PyTorch build | 12.3 minimum |
| CUDA recommended | Toolkit version should match the installed PyTorch CUDA runtime | 12.8 |
| PyTorch minimum | 2.2+ | 2.2+ |
| OS | Linux (primary); Windows limited | Linux |
| Dtype (forward) | FP16, BF16 | FP16, BF16, FP8 |
| Dtype (backward) | FP16, BF16 | FP16, BF16 |
| Head dim max | 256 | 256 |
| Head dim >192 backward | A100/A800 or H100/H800 only | H100/H800 |
| Head dim 256 backward on consumer GPUs | Only with dropout disabled (≥ v2.5.5) | N/A |
| ROCm support | Yes (separate backends) | No |
As the official README states: "Requirements: CUDA toolkit or ROCm toolkit. PyTorch 2.2 and above. packaging Python package (pip install packaging)." The psutil and ninja packages are not listed as hard requirements but are effectively mandatory for a build that completes in a reasonable time.
Which GPUs and CUDA versions are actually supported
FlashAttention-2 with the CUDA backend supports Ampere, Ada, and Hopper GPUs. The README explicitly names A100, RTX 3090, RTX 4090, and H100 as representative examples. Turing GPUs (T4, RTX 2080) require the separate flash-attention-turing repository.
FlashAttention-3 beta is Hopper-exclusive. The FlashAttention-3 paper (arXiv:2407.08608) reports speedups of 1.5–2.0× over FlashAttention-2 on H100 GPUs, with FP16 reaching up to 740 TFLOPs/s (75% utilization) and FP8 reaching close to 1.2 PFLOPs/s. These numbers are architecture-specific to H100 and do not transfer to Ampere or Ada hardware under the FA-3 kernel path. CUDA 12.8 is the recommended toolkit version for realizing that peak throughput; running FA-3 on CUDA 12.3 is functional but may leave performance on the table.
Watch Out: FlashAttention-3 beta does not run on Ampere or Ada GPUs at all in the documented path, not even in a degraded mode. If you attempt to benchmark FA-3 on an A100, you are either running FA-2 silently or hitting an error. Confirm your GPU SM version before attributing any FA-3 performance claims. Additionally, head dim > 192 backward requires A100/A800 or H100/H800. The one documented exception is head dim 256 on consumer GPUs (e.g., RTX 4090): from flash-attn 2.5.5 onward the backward pass works there, but only with dropout disabled. With dropout enabled, that configuration remains unsupported.
Head-dim limits, dtype support, and backward-pass edge cases
FlashAttention-2 supports all head dimensions up to 256 for the forward pass, but the backward pass has GPU-conditional limits that are frequently misread.
| Head dim | Forward (all supported GPUs) | Backward (Ampere/Ada consumer) | Backward (A100/A800/H100/H800) |
|---|---|---|---|
| ≤ 128 | ✅ FP16, BF16 | ✅ | ✅ |
| 160, 192 | ✅ FP16, BF16 | ✅ | ✅ |
| 256 | ✅ FP16, BF16 | ✅ only if dropout=False (≥ v2.5.5) | ✅ |
The README documents this directly: "Head dim > 192 backward requires A100/A800 or H100/H800." For head dim 256: "Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5."
BF16 requires Ampere, Ada, or Hopper — it is unavailable on Volta or Turing. FP8 is a FlashAttention-3 forward-only feature; the FA-3 paper validates that FP8 FlashAttention-3 achieves lower numerical error than a baseline FP8 attention on Hopper hardware, but FP8 backward is not yet supported in the beta. Do not assume FP8 availability on any non-Hopper GPU or in any FA-2 path.
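The constraints in this section can be collected into a small preflight check. What follows is a minimal sketch, not part of flash-attn: the function name `check_fa2_config` and its parameters are hypothetical, and the rules are transcribed from the README statements quoted above rather than from the library's own dispatch code.

```python
from typing import List, Tuple

# SM versions of the datacenter parts the README names as lifting the
# head-dim backward restriction (A100/A800 = SM 80, H100/H800 = SM 90).
DATACENTER_SMS = {(8, 0), (9, 0)}


def check_fa2_config(
    sm: Tuple[int, int],
    head_dim: int,
    dtype: str,
    backward: bool = False,
    dropout_p: float = 0.0,
    version: Tuple[int, int, int] = (2, 5, 5),
) -> List[str]:
    """Return README-rule violations; an empty list means no known issue."""
    problems = []
    if sm < (8, 0):
        problems.append("FA-2 CUDA path needs Ampere or newer (SM 80+)")
    if dtype not in ("fp16", "bf16"):
        problems.append("FA-2 supports only fp16 and bf16")
    if head_dim > 256:
        problems.append("head_dim above 256 is unsupported")
    if backward and head_dim > 192 and sm not in DATACENTER_SMS:
        # Documented exception: head dim 256, no dropout, flash-attn >= 2.5.5.
        allowed = head_dim == 256 and dropout_p == 0.0 and version >= (2, 5, 5)
        if not allowed:
            problems.append("backward with head_dim > 192 needs A100/A800 or H100/H800")
    return problems
```

Calling `check_fa2_config((8, 9), 256, "fp16", backward=True, dropout_p=0.1)` returns one violation for an RTX 4090-class GPU, while the same call with `dropout_p=0.0` returns an empty list.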
Install the required build tools and Python packages
Install Python dependencies and a compatible PyTorch build before touching flash-attn itself. The order matters: PyTorch must already be present with a matching CUDA runtime because flash-attn links against it at build time.
```bash
# Step 1: Install a CUDA 12.x-compatible PyTorch (adjust cu128 to match your toolkit)
$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

# Step 2: Install the required Python build dependencies
$ pip install packaging psutil ninja
```
With these four packages in place, you have the minimum dependency set the official repository requires. Do not proceed to the flash-attn install until the PyTorch import works and torch.cuda.is_available() returns True.
Verify the CUDA toolkit and compiler toolchain first
Toolkit/runtime mismatches are the root cause of the majority of build failures. Before compiling, confirm that the nvcc version, the system CUDA toolkit, and the CUDA version embedded in the installed PyTorch wheel all agree enough to link cleanly.
```bash
# Check the system CUDA toolkit version
$ nvcc --version

# Check the CUDA version PyTorch was compiled against
$ python -c "import torch; print(torch.version.cuda)"

# Check that PyTorch can see the GPU
$ python -c "import torch; print(torch.cuda.get_device_name(0))"

# Check CUDA toolkit path (should point to the same major version)
$ echo $CUDA_HOME
$ ls $CUDA_HOME/bin/nvcc
```
The local toolkit and the PyTorch wheel/runtime CUDA version must be compatible before compilation. A mismatch between the local toolkit and the wheel's embedded runtime is the most common build-time failure mode.
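The comparison between the toolkit and the wheel's runtime can be automated. The sketch below is illustrative only: `parse_cuda_release` and `compare_cuda` are hypothetical helper names, and the parser assumes `nvcc --version` prints a line containing `release <major>.<minor>`.

```python
import re
from typing import Optional, Tuple


def parse_cuda_release(text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from nvcc output or from a bare string like '12.8'."""
    match = re.search(r"release\s+(\d+)\.(\d+)", text) or re.search(r"(\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def compare_cuda(nvcc_output: str, torch_cuda: str) -> str:
    """Classify the pairing as 'match', 'minor-mismatch', 'major-mismatch', or 'unknown'."""
    toolkit = parse_cuda_release(nvcc_output)
    runtime = parse_cuda_release(torch_cuda)
    if toolkit is None or runtime is None:
        return "unknown"
    if toolkit == runtime:
        return "match"
    return "minor-mismatch" if toolkit[0] == runtime[0] else "major-mismatch"
```

In practice you would pass the captured output of `nvcc --version` as the first argument and `torch.version.cuda` as the second, and stop the build on anything other than `"match"`.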
Pro Tip: The CUDA version in a PyTorch wheel (e.g., cu128 in the wheel filename) reflects the runtime the wheel was compiled against, not the maximum CUDA toolkit version you can use. Install the CUDA toolkit matching that tag exactly. If your system has CUDA 12.3 but your PyTorch wheel is cu128, the flash-attn C++ extension will link against the wrong headers and fail with cryptic symbol errors.
Note that Blackwell-class GPUs (SM 120, RTX 5000 series) are not in the default SM architecture list as of the current repo state. Issue #2535 documents that flash-attn does not build or run on RTX 5070 Ti out of the box, with the error message explicitly naming the missing SM_120 architecture.
Install Ninja and tune MAX_JOBS for the available RAM
Ninja is the single highest-leverage build optimization available. The repository README is unambiguous: "Without ninja, building flash-attn takes 2 hours on a 64-core machine. With ninja, building flash-attn takes 3-5 minutes on a 64-core machine."
```bash
# Install Ninja via pip (preferred; guarantees the Python build system finds it)
$ pip install ninja

# Verify Ninja is on PATH
$ ninja --version

# On memory-constrained hosts (< 96 GB RAM), cap parallel jobs
# A reasonable starting point: 1 job per ~4 GB of RAM
$ export MAX_JOBS=8 # adjust to your system's available RAM
```
Production Note: The 3–5 minute build time is specific to a 64-core machine with Ninja and adequate RAM. On an 8-core machine or a cloud instance with 32 GB RAM, expect 20–40 minutes even with Ninja. Setting MAX_JOBS too high on a low-RAM host causes the kernel to OOM-kill compiler processes mid-build, producing misleading errors that look like source code bugs rather than resource exhaustion. Start conservatively (8–16 jobs) and increase only if the build completes successfully.
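The "1 job per ~4 GB of RAM" starting point can be written as a small calculation. This is a sketch under that stated heuristic; `suggest_max_jobs` is a hypothetical helper, and the 4 GB-per-job figure is a rule of thumb rather than a measured value.

```python
def suggest_max_jobs(ram_gb: float, cpu_cores: int, gb_per_job: float = 4.0) -> int:
    """Cap parallel compile jobs by both RAM and core count; never go below 1."""
    by_ram = int(ram_gb // gb_per_job)
    return max(1, min(cpu_cores, by_ram))
```

For a 32 GB, 8-core instance this yields 8. On a live host you could feed it `psutil.virtual_memory().available / 2**30` and `os.cpu_count()`, then export the result as MAX_JOBS.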
Build and install FlashAttention from source on Linux
The documented install path from the official repository uses --no-build-isolation, which tells pip to use the already-installed PyTorch and CUDA headers rather than pulling a fresh build environment. This flag is required — without it, pip may resolve a different PyTorch version or miss the local CUDA toolkit.
```bash
# Primary install command (requires packaging, psutil, ninja, and PyTorch already installed)
$ pip install flash-attn --no-build-isolation
```
If you need a specific version or the wheel for your environment does not exist on PyPI:
```bash
# Build directly from source after cloning
$ git clone https://github.com/Dao-AILab/flash-attention.git
$ cd flash-attention
$ python setup.py install
```

```yaml
# Environment variables relevant to the source build
MAX_JOBS: "8"                          # Reduce on < 96 GB RAM hosts
CUDA_HOME: "/usr/local/cuda-12.8"      # Point to the correct toolkit version
TORCH_CUDA_ARCH_LIST: "8.0;8.6;9.0"    # Explicitly list target SM architectures
```
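The TORCH_CUDA_ARCH_LIST variable is worth setting explicitly if you are building for a specific GPU fleet. Leaving it unset causes the build system to compile for every supported SM version, which adds significant time. For an H100-only cluster, 9.0 alone is sufficient. Which variable controls the architecture list can differ between flash-attn releases, so confirm it against the setup.py of the version you are building.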
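For a mixed fleet, the arch-list string can be derived from the compute capabilities you actually need. A minimal sketch, assuming the semicolon-separated `major.minor` format shown above; `arch_list` is a hypothetical helper.

```python
from typing import Iterable, Tuple


def arch_list(capabilities: Iterable[Tuple[int, int]]) -> str:
    """Format (major, minor) compute capabilities as a ';'-separated arch list."""
    unique = sorted(set(capabilities))
    return ";".join(f"{major}.{minor}" for major, minor in unique)
```

The input tuples have the same shape as the return value of `torch.cuda.get_device_capability()`, so capabilities collected from each node can be passed in directly.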
Use the official wheel or source build for your exact CUDA stack
Prebuilt wheels exist for common (PyTorch, CUDA, Python, ABI) combinations, but coverage is incomplete. As documented in issue #2299, there is no official prebuilt wheel for every PyTorch release — community wheels fill some gaps (e.g., flash_attn-2.8.3+cu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64), but these are not guaranteed to match your environment's ABI settings.
Choose prebuilt wheels when:
- The wheel's CUDA tag (e.g., cu128), PyTorch tag (e.g., torch2.6), Python version, and cxx11abi flag all match your environment exactly.
- You cannot afford a 20–40 minute source build in a CI or deployment pipeline.
- You are not targeting a non-standard SM architecture.
Choose source compilation when:
- No matching official wheel exists for your PyTorch/CUDA/Python combination.
- You need a specific SM architecture list (e.g., SM 90 only for an H100 fleet).
- You are building FlashAttention-3 beta features not yet in a stable wheel release.
- You need to control compiler flags for reproducibility.
Do not benchmark a wheel build against a source build without confirming that both activated the same kernel paths — differences in compiler optimization flags can produce materially different throughput numbers.
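The "all tags must match exactly" rule for prebuilt wheels can be checked mechanically. The sketch below assumes the filename layout of the community wheel quoted earlier in this section; `parse_wheel_tags` and `wheel_matches` are hypothetical helpers, and other wheel naming schemes would need a different pattern.

```python
import re
from typing import Dict, Optional

# Pattern modeled on: flash_attn-2.8.3+cu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64
WHEEL_RE = re.compile(
    r"flash_attn-(?P<version>[^+]+)\+cu(?P<cuda>\d+)torch(?P<torch>\d+\.\d+)"
    r"cxx11abi(?P<abi>TRUE|FALSE)-(?P<python>cp\d+)-cp\d+-(?P<platform>\w+)"
)


def parse_wheel_tags(filename: str) -> Optional[Dict[str, str]]:
    """Split a wheel filename into its version, CUDA, PyTorch, ABI, and Python tags."""
    match = WHEEL_RE.match(filename)
    return match.groupdict() if match else None


def wheel_matches(filename: str, env: Dict[str, str]) -> bool:
    """True only if every tag named in `env` equals the wheel's tag."""
    tags = parse_wheel_tags(filename)
    return tags is not None and all(tags.get(k) == v for k, v in env.items())
```

A caller would build `env` from the running interpreter and the installed PyTorch, then fall back to a source build whenever `wheel_matches` returns False.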
What to do on Windows, ROCm, or non-Hopper systems
Watch Out: Windows is not a supported environment in the official source-build documentation. Issue #2535 explicitly documents DLL PATH conflicts between the Windows CUDA Toolkit installation and flash-attn's build process. If you must run on Windows, expect manual environment variable surgery and no guarantee of a working build. WSL2 with a Linux image is the supported workaround.
ROCm: The repository documents ROCm support with two separate backends. ROCm builds follow a different dependency chain than CUDA builds, and CUDA benchmarking guidance does not transfer. If your infrastructure runs AMD GPUs, treat the ROCm path as a separate installation track and benchmark it independently.
Non-Hopper systems and FA-3: FlashAttention-3 beta is Hopper-only. An A100 (Ampere, SM 80) cannot run the FA-3 kernel path. Do not benchmark FlashAttention-3 on Ampere and report the result as a FlashAttention-3 number — you will be timing FlashAttention-2 or a fallback kernel. Turing GPUs (T4, RTX 2080) require the separate flash-attention-turing repository, which covers only a core feature subset.
Run a correctness check before you benchmark
Timing a kernel that is silently falling back to PyTorch's standard attention implementation produces numbers that look plausible but are meaningless for evaluating FlashAttention. Run a functional check first.
```python
import torch
from flash_attn import flash_attn_func

# Minimal smoke test: batch=2, heads=8, seq_len=512, head_dim=64, dtype=fp16
# Requires a CUDA device; must run on Ampere/Ada/Hopper
batch, heads, seqlen, head_dim = 2, 8, 512, 64
dtype = torch.float16
device = "cuda"

q = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)

out = flash_attn_func(q, k, v, causal=True)

# Verify output shape and that no NaN/Inf leaked through
assert out.shape == (batch, seqlen, heads, head_dim), f"Shape mismatch: {out.shape}"
assert not out.isnan().any(), "NaN detected in output; check dtype and GPU compatibility"
assert not out.isinf().any(), "Inf detected in output; check head_dim and GPU constraints"

print("Smoke test passed. Shape:", out.shape, "| dtype:", out.dtype)
```
Expected output: Smoke test passed. Shape: torch.Size([2, 512, 8, 64]) | dtype: torch.float16. Any import error here points to a build failure. A shape assertion error points to an API version mismatch. NaN output with BF16 on Turing-class hardware indicates an unsupported dtype/architecture combination.
Confirm backend selection and tensor shapes
Before timing, verify that your call dispatches to the intended FlashAttention backend for your GPU, dtype, and head dimension. This step is what most benchmark guides omit, and it is why published numbers are frequently not reproducible.
```python
import torch
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_func  # import check only

# Print the installed version and CUDA backend info
major, minor = torch.cuda.get_device_capability(0)
print(f"flash-attn version : {flash_attn.__version__}")
print(f"PyTorch version : {torch.__version__}")
print(f"CUDA version : {torch.version.cuda}")
print(f"GPU : {torch.cuda.get_device_name(0)}")
print(f"GPU SM arch : sm_{major}{minor}")

# Confirm the tensor configuration you intend to benchmark
head_dim = 128
dtype = torch.bfloat16  # BF16 requires Ampere/Ada/Hopper
causal = True
dropout = 0.0  # Dropout=0 required for head_dim=256 backward on consumer GPUs

print("\nBenchmark config:")
print(f" head_dim : {head_dim}")
print(f" dtype : {dtype}")
print(f" causal : {causal}")
print(f" dropout : {dropout}")
```
The SM arch printout tells you whether you are on Hopper (SM 90), Ada (SM 89), or Ampere (SM 80/86). Cross-reference this against the head-dim compatibility table before running backward-pass benchmarks.
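The cross-reference from SM version to architecture family can be made explicit. A minimal sketch covering only the architectures named in this guide; `describe_gpu` is a hypothetical helper, and any capability outside the table is reported as unknown rather than guessed.

```python
from typing import Tuple

# Mapping restricted to the architectures discussed in this guide.
SM_FAMILIES = {
    (7, 5): "Turing",
    (8, 0): "Ampere",
    (8, 6): "Ampere",
    (8, 9): "Ada",
    (9, 0): "Hopper",
}


def describe_gpu(capability: Tuple[int, int]) -> dict:
    """Summarize which FlashAttention paths this guide documents for a capability."""
    family = SM_FAMILIES.get(capability, "unknown")
    return {
        "sm": f"sm_{capability[0]}{capability[1]}",
        "family": family,
        "fa2_cuda": family in ("Ampere", "Ada", "Hopper"),
        "fa3_beta": family == "Hopper",
    }
```

Passing the tuple from `torch.cuda.get_device_capability(0)` gives a record you can log alongside each benchmark run.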
Benchmark FlashAttention without misleading results
Benchmark discipline for FlashAttention requires holding GPU architecture, dtype, head_dim, causal mode, dropout, and sequence length constant across all compared runs. The FlashAttention-3 paper reports 1.5–2.0× speedups and up to 740 TFLOPs/s in FP16 on H100 hardware specifically — these numbers should not be quoted for A100 runs or compared against FA-3 results on different hardware without explicit labeling.
The table below shows the fields every benchmark record must capture to be reproducible.
| Field | Why it matters | Example value |
|---|---|---|
| GPU model | Kernel path and throughput ceiling differ by architecture | H100 SXM5 |
| flash-attn version | Kernel code changes between versions | 2.8.3 |
| dtype | FP16 and BF16 achieve different FLOP rates | BF16 |
| head_dim | Determines kernel variant; backward has GPU constraints above 192 | 128 |
| seq_len | FlashAttention's memory advantage grows quadratically with seq_len | 4096 |
| batch size | Affects GPU occupancy | 4 |
| causal | Causal masking halves effective compute load | True |
| dropout | Changes kernel path for head_dim 256 on consumer GPUs | 0.0 |
| warmup iters | First iterations include JIT compilation; exclude from timing | 10 |
Always use at least 10 warmup iterations before measuring. The first kernel invocation pays one-time costs such as CUDA context initialization and lazy kernel loading, which can add 100 ms+ to a measurement that the steady-state kernel completes in under 1 ms.
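The warmup discipline can be sketched as a small timing harness. This is an illustrative helper, not part of flash-attn: `time_callable` and its parameters are hypothetical, and it uses wall-clock timing with an optional synchronization callback so the same code can be exercised without a GPU.

```python
import statistics
import time
from typing import Callable, Optional


def time_callable(
    fn: Callable[[], object],
    warmup: int = 10,
    iters: int = 50,
    sync: Optional[Callable[[], None]] = None,
) -> dict:
    """Time `fn` after discarding warmup calls; `sync` should block until GPU work finishes."""
    for _ in range(warmup):
        fn()
    if sync:
        sync()  # drain warmup work before the first timed sample
    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync:
            sync()  # without this, only the asynchronous launch would be timed
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "median_ms": statistics.median(samples_ms),
        "min_ms": min(samples_ms),
        "iters": iters,
        "warmup": warmup,
    }
```

For a CUDA workload you would pass `sync=torch.cuda.synchronize` and wrap the attention call in a lambda, for example `time_callable(lambda: flash_attn_func(q, k, v, causal=True), sync=torch.cuda.synchronize)`.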
Measure the right metrics for your workload
Training workloads care about backward throughput (tokens/sec) and peak memory; inference workloads care about time-to-first-token (TTFT) and sustained decode throughput. Report them separately — a single "speed" number conflates two different performance regimes.
| Config | seq_len | dtype | Metric | FlashAttention-2 (A100) | FlashAttention-3 (H100) |
|---|---|---|---|---|---|
| Forward only | 2048 | FP16 | TFLOPs/s | Measure locally | See arXiv:2407.08608; verify locally |
| Forward only | 8192 | FP16 | TFLOPs/s | Measure locally | See arXiv:2407.08608; verify locally |
| Forward+Backward | 2048 | BF16 | TFLOPs/s | Measure locally | See arXiv:2407.08608; verify locally |
| Forward only | 2048 | FP8 | TFLOPs/s | N/A | See arXiv:2407.08608; verify locally |
FA-3 headline numbers from arXiv:2407.08608: FP16 up to 740 TFLOPs/s, FP8 close to 1.2 PFLOPs/s on H100. The table is a reporting template: this guide does not supply FA-2 A100 figures, so measure your own for authoritative comparisons.
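To turn a measured latency into a TFLOPs/s figure comparable with published numbers, you need a FLOP count for the attention call. The sketch below uses the counting convention commonly seen in FlashAttention benchmarks (two matmuls in the forward pass, backward counted as 2.5× forward, causal counted as half); treat that convention as an assumption and state it whenever you publish results. The function names are hypothetical.

```python
def attention_flops(batch, seqlen, heads, head_dim, causal=False, mode="fwd"):
    """Approximate matmul FLOPs for one attention call."""
    assert mode in ("fwd", "bwd", "fwd_bwd")
    # Two matmuls (QK^T and PV), each 2 * seqlen^2 * head_dim FLOPs per head.
    fwd = 4 * batch * seqlen**2 * heads * head_dim
    if causal:
        fwd //= 2  # roughly half of the score matrix is masked out
    factor = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return fwd * factor


def tflops_per_s(flops, elapsed_ms):
    """Convert a FLOP count and a latency in milliseconds to TFLOPs/s."""
    return flops / (elapsed_ms * 1e-3) / 1e12
```

Because the causal and backward factors change the result substantially, two TFLOPs/s numbers are only comparable when both were computed with the same convention.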
Memory use should always be measured with the same dtype and head_dim as the timing run. FlashAttention's core advantage — O(N) memory in sequence length versus O(N²) for standard attention — becomes measurable above seq_len ≈ 2048; benchmarking at seq_len=512 will show minimal memory difference and may make FlashAttention look slower than standard attention due to kernel launch overhead.
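The quadratic-versus-linear scaling is easy to quantify with back-of-the-envelope arithmetic. This sketch only counts the bytes of the materialized score matrix against the bytes of Q, K, and V; it ignores activations, workspace, and allocator overhead, and the function names are hypothetical.

```python
def score_matrix_bytes(batch, heads, seqlen, bytes_per_elem=2):
    """Bytes needed to materialize the full (seqlen x seqlen) attention scores."""
    return batch * heads * seqlen * seqlen * bytes_per_elem


def qkv_bytes(batch, heads, seqlen, head_dim, bytes_per_elem=2):
    """Bytes for Q, K, and V together, which grow linearly in seqlen."""
    return 3 * batch * heads * seqlen * head_dim * bytes_per_elem
```

With batch 4, 32 heads, and 2-byte elements, the score matrix alone is 4 GiB at seq_len 4096 but only 64 MiB at seq_len 512, which is why short-sequence benchmarks show little memory difference.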
Avoid apples-to-oranges comparisons with PyTorch SDPA
PyTorch's scaled_dot_product_attention (SDPA) can dispatch to multiple backends, including a FlashAttention kernel, a memory-efficient kernel, cuDNN attention, and a math fallback, depending on dtype, head_dim, and device capability. Benchmarking your custom flash_attn_func call against torch.nn.functional.scaled_dot_product_attention without controlling the SDPA dispatch is not a fair comparison: SDPA may itself route to its own FlashAttention backend internally.
Choose direct flash_attn_func vs. PyTorch SDPA comparison when:
- You have explicitly disabled all SDPA backends except the one you are testing (torch.backends.cuda.enable_flash_sdp(False) etc.) and want to measure the raw kernel difference.
- You are quantifying the overhead of PyTorch's dispatch layer for latency-sensitive inference.
Choose FlashAttention vs. cuDNN attention when:
- Your inference stack already uses cuDNN and you want to determine whether switching kernel implementations is worth the engineering cost.
- Both paths support the same dtype and head_dim on the same GPU.
Choose FlashAttention vs. Triton attention kernel when:
- You are building a custom kernel and need to understand where cycles are spent relative to a reference implementation.
- Both kernels are compiled for the same SM architecture with equivalent optimization flags.
The controlling variables — dtype, causal flag, dropout rate, head_dim, and whether the backward pass is included — must be held constant across all compared runs. Compiler/build settings (source vs. wheel, TORCH_CUDA_ARCH_LIST) should also be reported when publishing benchmark results, since available kernel variants can differ between builds.
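One way to enforce this is to attach a structured record to every measurement and refuse to compare records whose controlled variables differ. The dataclass below is a sketch: the class name, field names, and the `build` field are hypothetical, chosen to mirror the reproducibility fields listed earlier in this guide.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class BenchmarkRecord:
    gpu_model: str
    flash_attn_version: str
    dtype: str
    head_dim: int
    seq_len: int
    batch_size: int
    causal: bool
    dropout: float
    warmup_iters: int
    build: str = "unknown"  # e.g. "wheel" or "source"; hypothetical field

    def comparable_to(self, other: "BenchmarkRecord") -> bool:
        """Two runs are comparable only if every controlled variable matches."""
        keys = ("gpu_model", "dtype", "head_dim", "seq_len", "batch_size", "causal", "dropout")
        return all(getattr(self, k) == getattr(other, k) for k in keys)

    def to_json(self) -> str:
        """Serialize with sorted keys so records diff cleanly."""
        return json.dumps(asdict(self), sort_keys=True)
```

Storing the JSON line next to each timing result makes it possible to audit a published comparison after the fact.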
Common failure modes and exact fixes
Build and runtime failures almost always fall into one of five categories: missing Ninja (compile-time), CUDA version mismatch (compile-time or import-time), unsupported GPU architecture (runtime or build-time), head-dim/backward constraint violations (runtime), and dtype/backend mismatch (runtime, often silent).
Watch Out: The most dangerous failure mode is silent fallback: flash-attn imports successfully, your code runs, and your benchmark produces numbers, but the actual kernel being executed is not FlashAttention. This typically happens when dtype, head_dim, or GPU constraints are not satisfied and a wrapper layer (for example, a framework attention module or the SDPA dispatcher) routes to a standard attention path instead of surfacing the error. Always run the backend confirmation code in the "Confirm backend selection" section before treating any timing number as a FlashAttention result.
| Error symptom | Root cause | Version / constraint | Fix |
|---|---|---|---|
| Build takes 2+ hours | Ninja not installed | All versions | pip install ninja before building |
| nvcc fatal: unsupported arch | SM not in arch list | Blackwell (SM 120), others | Set TORCH_CUDA_ARCH_LIST explicitly |
| undefined symbol at import | CUDA toolkit / wheel version mismatch | All versions | Match nvcc --version with torch.version.cuda |
| OOM during compile | Too many parallel jobs | < 96 GB RAM hosts | export MAX_JOBS=8 (or lower) |
| RuntimeError on backward at head_dim > 192 | Wrong GPU class | FA-2, all versions | Use A100/A800 or H100/H800 for backward at head_dim > 192 |
| NaN output with dropout + head_dim=256 | Consumer GPU + dropout constraint | FA-2 < 2.5.5 | Upgrade to ≥ 2.5.5 and set dropout=0.0 |
| DLL load error on Windows | PATH/DLL conflicts | All versions | Use Linux or WSL2 |
| Windows CUDA Toolkit PATH conflict | Windows source-build environment | Official docs | Use Linux or WSL2 and align PATH manually if testing locally |
Ninja is missing or build times explode
If ninja --version returns command not found and you start a flash-attn source build, you are committing to an hours-scale compile. The official README is explicit: "Install ninja if you don't have it. It's much faster using ninja than without. Without ninja, building flash-attn takes 2 hours on a 64-core machine. With ninja, building flash-attn takes 3-5 minutes on a 64-core machine."
Pro Tip: Install Ninja via pip install ninja rather than the system package manager. The pip version is found by Python's build system automatically without any PATH configuration. If you install Ninja via apt or yum but Python's build system cannot find it (because the NINJA env var is unset or the binary is not on the build PATH), the speed benefit disappears. After installing, confirm with python -c "import ninja; print(ninja.BIN_DIR)". On machines with fewer than 96 GB RAM, set MAX_JOBS to a value where (MAX_JOBS × ~4 GB per compiler process) stays within available RAM; OOM mid-build produces misleading errors that waste diagnostic time.
Head-dim and backward-pass errors on the wrong GPU
A training run with head_dim above 192 will fail on consumer Ampere/Ada GPUs unless specific conditions are met. The constraints are:
| head_dim | Operation | RTX 4090 / RTX 3090 | A100 / H100 |
|---|---|---|---|
| ≤ 128 | Forward + Backward | ✅ | ✅ |
| 192 | Forward | ✅ | ✅ |
| 192 | Backward | ✅ | ✅ |
| 256 | Forward | ✅ | ✅ |
| 256 | Backward (dropout=0) | ✅ (≥ v2.5.5) | ✅ |
| 256 | Backward (dropout>0) | ❌ | ✅ |
If you see a runtime error on the backward pass with head_dim above 192 on a consumer GPU, the root cause is the GPU class, not the flash-attn version, and upgrading flash-attn alone will not fix it: you need A100/A800 or H100/H800 hardware. The one exception is head_dim=256 with dropout disabled, which works on consumer GPUs from flash-attn 2.5.5. If you hit that case, upgrade to flash-attn ≥ 2.5.5 and set dropout to 0.0 during benchmarking.
Dtype, dropout, and backend mismatches
FlashAttention-2's CUDA path accepts FP16 and BF16. BF16 requires Ampere, Ada, or Hopper — attempting BF16 on Turing hardware will trigger an error or fallback, not a silent success. The official README documents this requirement directly.
Watch Out: If you pass FP32 tensors to flash_attn_func, the library will not silently cast; it will raise a dtype error. If your benchmark wrapper autocasts to FP32 and you do not notice, the call fails, your benchmark loop catches the exception and skips the iteration, and you end up with throughput numbers that reflect no actual FlashAttention work. Always assert tensor.dtype in (torch.float16, torch.bfloat16) before the timed loop. For FlashAttention-3 FP8 benchmarks, confirm you are on Hopper with CUDA 12.3+ and explicitly using the FA-3 API, because the FA-2 API does not expose an FP8 path.
Dropout interacts with backend dispatch in two places: it disables head_dim=256 backward on consumer GPUs (detailed above), and it can affect which kernel variant is selected internally. For benchmark purity, set dropout_p=0.0 unless your production workload requires dropout and you are specifically benchmarking that configuration.
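The dtype assertion recommended above can be packaged as a guard that runs once before the timed loop. This is a sketch: `assert_fa2_dtypes` is a hypothetical helper, and it compares the string form of each tensor's dtype (for example `"torch.float16"`) so that it can be exercised without importing PyTorch.

```python
ALLOWED_DTYPES = ("torch.float16", "torch.bfloat16")


def assert_fa2_dtypes(*tensors) -> None:
    """Raise before the timed loop if any input is not fp16/bf16."""
    for i, t in enumerate(tensors):
        name = str(t.dtype)
        if name not in ALLOWED_DTYPES:
            raise TypeError(f"tensor {i} has dtype {name}; expected one of {ALLOWED_DTYPES}")
```

Calling `assert_fa2_dtypes(q, k, v)` outside any try/except block ensures a dtype problem stops the benchmark instead of being swallowed inside the loop.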
FAQ on installation, support, and benchmarking
What are the minimum requirements for FlashAttention?
CUDA toolkit or ROCm toolkit, PyTorch 2.2+, packaging Python package, and Linux. For FlashAttention-3 beta: CUDA 12.3 minimum, CUDA 12.8 recommended, and H100/H800 hardware.
What GPUs support FlashAttention-2?
Ampere (A100, A800, RTX 3090), Ada (RTX 4090, RTX 4080), and Hopper (H100, H800). Turing GPUs (T4, RTX 2080) require the separate flash-attention-turing repository. Blackwell (RTX 5000 series) is not yet in the default architecture list.
Does FlashAttention run on Windows?
Not reliably via the official source-build path. Issue #2535 documents DLL PATH conflicts with the CUDA Toolkit on Windows. Use Linux or WSL2.
Why does my FlashAttention build fail with out-of-memory errors?
You are running too many parallel compile jobs. Set MAX_JOBS=8 (or lower) before building if your system has under 96 GB RAM.
How do I know if my benchmark is actually measuring FlashAttention?
Run the backend confirmation code before timing. Verify that flash_attn.__version__ loads without error, that torch.cuda.get_device_capability() returns a supported SM version, and that your tensor dtype is FP16 or BF16. Any mismatch risks timing a fallback kernel.
Can I use FP8 with FlashAttention-2?
No. FP8 is available only in the FlashAttention-3 beta forward pass, on Hopper hardware.
Is head_dim=192 safe for training on an RTX 4090?
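Yes, by the README's stated limits. The backward-pass restriction applies to head_dim > 192, which requires A100/A800 or H100/H800. Head dim 192 itself falls within the limit, so both forward and backward are supported on an RTX 4090.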
Sources and references
- Dao-AILab/flash-attention GitHub Repository. Official source for installation requirements, GPU architecture support, head-dim limits, Ninja build guidance, and MAX_JOBS tuning
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (arXiv:2407.08608). Source for FA-3 performance claims: 1.5–2.0× speedup on H100, 740 TFLOPs/s FP16, 1.2 PFLOPs/s FP8, and lower FP8 numerical error on Hopper
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (arXiv:2307.08691). Core algorithmic reference for FlashAttention-2 architecture and tiling approach
- PyTorch official installation. Canonical source for PyTorch 2.2+ CUDA wheel index URLs
- NVIDIA CUDA Toolkit. Official CUDA 12.3 and 12.8 download and documentation
- flash-attention issue #2535. Documents Blackwell SM_120 arch list omission and Windows DLL PATH conflicts
- flash-attention issue #2299. Documents version-specific wheel naming and missing official wheels for newer PyTorch releases
- flash-attention issue #277. Confirms head dim support up to 256 as of FlashAttention-2