
How to install and benchmark FlashAttention in PyTorch: requirements, head-dim limits, and common failure modes

FlashAttention installation is constrained by CUDA, PyTorch, Ninja, and GPU architecture support — and benchmark results are only trustworthy when head-dim limits, dtype support, and backend compatibility are matched to the target GPU, otherwise users hit build failures or misleading speed numbers.


At a glance: install and benchmark prerequisites

- Time: 3–5 min build with Ninja on a 64-core machine; expect 20–40 min on smaller hosts
- Prereqs: Linux, CUDA toolkit, PyTorch 2.2+, packaging, psutil, ninja
- Hardware: Ampere/Ada/Hopper for FlashAttention-2; H100/H800 required for FlashAttention-3 beta
- RAM: 96 GB+ recommended for unconstrained builds; lower MAX_JOBS on hosts below that threshold

FlashAttention installation is gated by a specific intersection of GPU architecture, CUDA toolkit version, PyTorch version, and build toolchain. Getting any one of these wrong produces either a silent fallback to a different attention kernel — invalidating your benchmark — or a hard build failure. The official repository requires a CUDA toolkit or ROCm toolkit, PyTorch 2.2+, the packaging Python package, and Linux as the primary supported OS. FlashAttention-3 beta narrows the hardware requirement further: it targets Hopper-class GPUs (H100/H800) with CUDA 12.3 as the minimum and CUDA 12.8 recommended for peak performance.

The build toolchain matters more than most installation guides acknowledge. Without Ninja, building flash-attn on a 64-core machine takes approximately two hours. With Ninja, the same machine finishes in 3–5 minutes. On systems with less than 96 GB of RAM, an unconstrained parallel build can exhaust memory mid-compilation; MAX_JOBS must be lowered to prevent this. These are not edge cases — they are the two most common sources of avoidable pain in a flash-attn source build.

Prerequisites and compatibility matrix

The table below consolidates the version and architecture requirements before you run a single install command.

| Requirement | FlashAttention-2 (CUDA) | FlashAttention-3 beta |
|---|---|---|
| GPU architectures | Ampere, Ada, Hopper (A100, RTX 3090/4090, H100) | Hopper only (H100/H800) |
| CUDA minimum | A CUDA toolkit matched to your PyTorch build | 12.3 |
| CUDA recommended | Same version as the installed PyTorch CUDA runtime | 12.8 |
| PyTorch minimum | 2.2+ | 2.2+ |
| OS | Linux (primary); Windows limited | Linux |
| Dtype (forward) | FP16, BF16 | FP16, BF16, FP8 |
| Dtype (backward) | FP16, BF16 | FP16, BF16 |
| Head dim max | 256 | 256 |
| Head dim > 192 backward | A100/A800 or H100/H800 only | H100/H800 |
| Head dim 256 backward on consumer GPUs | Only with dropout disabled (≥ v2.5.5) | N/A |
| ROCm support | Yes (separate backends) | No |

As the official README states: "Requirements: CUDA toolkit or ROCm toolkit. PyTorch 2.2 and above. packaging Python package (pip install packaging)." The psutil and ninja packages are not listed as hard requirements but are effectively mandatory for a build that completes in a reasonable time.

Which GPUs and CUDA versions are actually supported

FlashAttention-2 with the CUDA backend supports Ampere, Ada, and Hopper GPUs. The README explicitly names A100, RTX 3090, RTX 4090, and H100 as representative examples. Turing GPUs (T4, RTX 2080) require the separate flash-attention-turing repository.

FlashAttention-3 beta is Hopper-exclusive. The FlashAttention-3 paper (arXiv:2407.08608) reports speedups of 1.5–2.0× over FlashAttention-2 on H100 GPUs, with FP16 reaching up to 740 TFLOPs/s (75% utilization) and FP8 reaching close to 1.2 PFLOPs/s. These numbers are architecture-specific to H100 and do not transfer to Ampere or Ada hardware under the FA-3 kernel path. CUDA 12.8 is the recommended toolkit version for realizing that peak throughput; running FA-3 on CUDA 12.3 is functional but may leave performance on the table.

Watch Out: FlashAttention-3 beta does not run on Ampere or Ada GPUs at all in the documented path; there is no degraded mode. If you attempt to benchmark FA-3 on an A100, you are either timing an FA-2 path that your wrapper selected instead, or hitting an error. Confirm your GPU's SM version before attributing any result to FA-3. Separately, head dim > 192 backward requires A100/A800 or H100/H800: on a consumer GPU (e.g., RTX 4090), backward at head dimension 256 works only with dropout disabled and flash-attn 2.5.5 or later, and is unsupported with dropout enabled.

Head-dim limits, dtype support, and backward-pass edge cases

FlashAttention-2 supports all head dimensions up to 256 for the forward pass, but the backward pass has GPU-conditional limits that are frequently misread.

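| Head dim | Forward (all supported GPUs) | Backward (consumer Ampere/Ada) | Backward (A100/A800/H100/H800) |
|---|---|---|---|
| ≤ 128 | ✅ FP16, BF16 | ✅ | ✅ |
| 160, 192 | ✅ FP16, BF16 | ✅ | ✅ |
| > 192, below 256 (e.g., 224) | ✅ FP16, BF16 | ❌ (fails) | ✅ |
| 256 | ✅ FP16, BF16 | ✅ only if dropout = 0 (≥ v2.5.5) | ✅ |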

The README documents this directly: "Head dim > 192 backward requires A100/A800 or H100/H800." For head dim 256: "Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5."

BF16 requires Ampere, Ada, or Hopper — it is unavailable on Volta or Turing. FP8 is a FlashAttention-3 forward-only feature; the FA-3 paper validates that FP8 FlashAttention-3 achieves lower numerical error than a baseline FP8 attention on Hopper hardware, but FP8 backward is not yet supported in the beta. Do not assume FP8 availability on any non-Hopper GPU or in any FA-2 path.

Install the required build tools and Python packages

Install Python dependencies and a compatible PyTorch build before touching flash-attn itself. The order matters: PyTorch must already be present with a matching CUDA runtime because flash-attn links against it at build time.

# Step 1: Install a CUDA 12.x-compatible PyTorch (adjust cu128 to match your toolkit)
$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

# Step 2: Install the required Python build dependencies
$ pip install packaging psutil ninja

With these four packages (torch, packaging, psutil, ninja) in place, you have the dependency set a practical source build needs. Do not proceed to the flash-attn install until the PyTorch import works and torch.cuda.is_available() returns True.
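The checks in Steps 1 and 2 can be scripted so a build never starts in a half-ready environment. This is a minimal sketch that uses only the standard library until PyTorch is confirmed present; the helper names are hypothetical, and it covers only the CUDA path (a ROCm build of PyTorch reports torch.version.cuda as None and would be flagged here).

```python
import importlib.util

# Packages a source build expects (torch from Step 1, the rest from Step 2)
REQUIRED_MODULES = ["torch", "packaging", "psutil", "ninja"]


def missing_modules(names):
    """Return the module names that cannot be located on the current sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def parse_major_minor(version):
    """Turn a version string such as '2.6.0+cu128' into a comparable (2, 6) tuple."""
    core = version.split("+")[0]  # drop the local tag, e.g. '+cu128'
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def preflight():
    """Print each problem found; return True only if the environment looks ready."""
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
        return False

    try:
        import torch  # imported lazily so the helpers above work without PyTorch
    except Exception as exc:  # a broken install can fail on import
        print("PyTorch is present but failed to import:", exc)
        return False

    ok = True
    if parse_major_minor(torch.__version__) < (2, 2):
        print(f"PyTorch {torch.__version__} is older than 2.2")
        ok = False
    if torch.version.cuda is None:
        print("This PyTorch build has no CUDA runtime (CPU-only or ROCm wheel)")
        ok = False
    elif not torch.cuda.is_available():
        print("torch.cuda.is_available() is False; check the driver and GPU visibility")
        ok = False
    if ok:
        print(f"OK: torch {torch.__version__}, CUDA runtime {torch.version.cuda}")
    return ok


if __name__ == "__main__":
    preflight()
```

parse_major_minor compares numeric tuples rather than strings, so a release such as 2.10 correctly sorts above 2.2.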

Verify the CUDA toolkit and compiler toolchain first

Toolkit/runtime mismatches are the root cause of most build failures. Before compiling, confirm that nvcc, the toolkit under CUDA_HOME, and the CUDA version embedded in the installed PyTorch wheel all report the same major version, and ideally the same minor version.

# Check the system CUDA toolkit version
$ nvcc --version

# Check the CUDA version PyTorch was compiled against
$ python -c "import torch; print(torch.version.cuda)"

# Check that PyTorch can see the GPU
$ python -c "import torch; print(torch.cuda.get_device_name(0))"

# Check CUDA toolkit path (should point to the same major version)
$ echo $CUDA_HOME
$ ls $CUDA_HOME/bin/nvcc

Treat a major-version match as the minimum bar. flash-attn builds through PyTorch's C++ extension machinery, which refuses to compile when the local toolkit and the wheel's CUDA runtime differ in major version and only warns when they differ in minor version.
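The same comparison can be done in one script. This sketch assumes the nvcc to check is $CUDA_HOME/bin/nvcc when CUDA_HOME is set and the first nvcc on PATH otherwise, mirroring the shell commands above; the function names are hypothetical. It reports the relationship between the two versions rather than a pass/fail verdict, because a minor-version gap and a major-version gap call for different responses.

```python
import os
import re
import shutil
import subprocess


def parse_nvcc_release(nvcc_output):
    """Extract (major, minor) from the 'release X.Y' text printed by nvcc --version."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def classify_cuda_match(toolkit, torch_cuda):
    """Compare two (major, minor) tuples: 'ok', 'minor-mismatch', or 'major-mismatch'."""
    if toolkit[0] != torch_cuda[0]:
        return "major-mismatch"
    if toolkit[1] != torch_cuda[1]:
        return "minor-mismatch"
    return "ok"


def resolve_nvcc():
    """Prefer $CUDA_HOME/bin/nvcc, fall back to the first nvcc on PATH."""
    cuda_home = os.environ.get("CUDA_HOME")
    if cuda_home:
        candidate = os.path.join(cuda_home, "bin", "nvcc")
        return candidate if os.path.exists(candidate) else None
    return shutil.which("nvcc")


def check_local_stack():
    """Run nvcc, read torch.version.cuda, and report how well they line up."""
    nvcc = resolve_nvcc()
    if nvcc is None:
        print("nvcc not found; set CUDA_HOME or add the toolkit's bin directory to PATH")
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True, check=True)
    toolkit = parse_nvcc_release(result.stdout)

    import torch  # lazy import so the parsing helpers work without PyTorch

    if torch.version.cuda is None:
        print("PyTorch was built without CUDA")
        return None
    torch_cuda = tuple(int(part) for part in torch.version.cuda.split(".")[:2])
    verdict = classify_cuda_match(toolkit, torch_cuda)
    print(f"nvcc {toolkit[0]}.{toolkit[1]} vs torch CUDA {torch.version.cuda}: {verdict}")
    return verdict
```

Call check_local_stack() from the same shell and environment you will build in, since both CUDA_HOME and PATH affect which toolkit it inspects.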

Pro Tip: The CUDA tag in a PyTorch wheel (e.g., cu128 in the wheel filename) is the CUDA runtime the wheel was compiled against. The safest choice is a local CUDA toolkit that matches that tag exactly. A minor-version gap, such as a system CUDA 12.3 toolkit with a cu128 wheel, can still leave the flash-attn C++ extension compiled against different headers than the runtime it loads, which surfaces as cryptic symbol errors.

Note that consumer Blackwell GPUs (SM 120, GeForce RTX 50 series) are not in the default SM architecture list as of the current repo state. Issue #2535 documents that flash-attn does not build or run on an RTX 5070 Ti out of the box, with the error message explicitly naming the missing SM_120 architecture.

Install Ninja and tune MAX_JOBS for the available RAM

Ninja is the single highest-leverage build optimization available. The repository README is unambiguous: "Without ninja, building flash-attn takes 2 hours on a 64-core machine. With ninja, building flash-attn takes 3-5 minutes on a 64-core machine."

# Install Ninja via pip (preferred — guarantees the Python build system finds it)
$ pip install ninja

# Verify Ninja is on PATH
$ ninja --version

# On memory-constrained hosts (< 96 GB RAM), cap parallel jobs
# A reasonable starting point: 1 job per ~4 GB of RAM
$ export MAX_JOBS=8   # adjust to your system's available RAM

Production Note: The 3–5 minute build time is specific to a 64-core machine with Ninja and adequate RAM. On an 8-core machine or a cloud instance with 32 GB RAM, expect 20–40 minutes even with Ninja. Setting MAX_JOBS too high on a low-RAM host causes the kernel to OOM-kill compiler processes mid-build, producing misleading errors that look like source code bugs rather than resource exhaustion. Start conservatively (8–16 jobs) and increase only if the build completes successfully.
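The "1 job per ~4 GB of RAM" rule of thumb from the comment above can be turned into a starting value. A minimal sketch, assuming psutil from the dependency step (with a Linux-only sysconf fallback that reads total rather than available RAM); the 4 GB figure is the heuristic quoted in this guide, not a measured constant, and the helper names are hypothetical.

```python
import os

GIB = 1024 ** 3


def suggest_max_jobs(available_bytes, cpu_count, gib_per_job=4.0):
    """Cap parallel compile jobs by RAM (about gib_per_job per compiler process) and by CPU count."""
    by_ram = int(available_bytes // (gib_per_job * GIB))
    return max(1, min(by_ram, cpu_count))


def suggest_for_this_host():
    """Apply suggest_max_jobs to the current machine."""
    try:
        import psutil  # installed in the dependency step

        available = psutil.virtual_memory().available
    except ImportError:
        # Linux-only fallback: total physical RAM, which overestimates what is free
        available = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
    return suggest_max_jobs(available, os.cpu_count() or 1)
```

print(f"export MAX_JOBS={suggest_for_this_host()}") gives a ready-to-paste line; if the build is still OOM-killed, halve the value and retry.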

Build and install FlashAttention from source on Linux

The documented install path from the official repository uses --no-build-isolation, which tells pip to use the already-installed PyTorch and CUDA headers rather than pulling a fresh build environment. This flag is required — without it, pip may resolve a different PyTorch version or miss the local CUDA toolkit.

# Primary install command — requires packaging, psutil, ninja, and PyTorch already installed
$ pip install flash-attn --no-build-isolation

If you need a specific version or the wheel for your environment does not exist on PyPI:

# Build directly from source after cloning
$ git clone https://github.com/Dao-AILab/flash-attention.git
$ cd flash-attention
# Environment variables relevant to the source build; export them before building
$ export MAX_JOBS=8                          # reduce on < 96 GB RAM hosts
$ export CUDA_HOME=/usr/local/cuda-12.8      # point to the correct toolkit version
$ export TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0"  # explicitly list target SM architectures
$ python setup.py install

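The TORCH_CUDA_ARCH_LIST variable is worth setting explicitly if you are building for a specific GPU fleet, because compiling for SM versions you will never run adds significant build time. For an H100-only cluster, 9.0 alone is sufficient. Check the setup.py of the flash-attn version you are building to confirm which variable it honors: some versions pass their own -gencode flags to nvcc, in which case PyTorch's extension builder does not apply TORCH_CUDA_ARCH_LIST.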
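To build only for the GPUs you actually run, the architecture list can be derived from the devices PyTorch sees. A minimal sketch with hypothetical function names, assuming a CUDA build of PyTorch on each node type in the fleet; whether the resulting string is consumed through TORCH_CUDA_ARCH_LIST or another variable depends on the setup.py of the flash-attn version you build.

```python
def arch_list_string(capabilities):
    """Turn (major, minor) compute capabilities into a 'major.minor;...' list, deduplicated and sorted."""
    unique = sorted(set(capabilities))
    return ";".join(f"{major}.{minor}" for major, minor in unique)


def detected_arch_list():
    """Build the list from every GPU visible to this process (needs a CUDA build of PyTorch)."""
    import torch

    caps = [torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())]
    return arch_list_string(caps)
```

Run detected_arch_list() on one node of each GPU type and union the results; a node with eight H100s yields just "9.0".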

Use the official wheel or source build for your exact CUDA stack

Prebuilt wheels exist for common (PyTorch, CUDA, Python, ABI) combinations, but coverage is incomplete. As documented in issue #2299, there is no official prebuilt wheel for every PyTorch release — community wheels fill some gaps (e.g., flash_attn-2.8.3+cu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64), but these are not guaranteed to match your environment's ABI settings.

Choose prebuilt wheels when:
- The wheel's CUDA tag (e.g., cu12), PyTorch tag (e.g., torch2.6), Python version, and cxx11abi flag all match your environment exactly.
- You cannot afford a 20–40 minute source build in a CI or deployment pipeline.
- You are not targeting a non-standard SM architecture.

Choose source compilation when:
- No matching official wheel exists for your PyTorch/CUDA/Python combination.
- You need a specific SM architecture list (e.g., SM 90 only for an H100 fleet).
- You are building FlashAttention-3 beta features not yet in a stable wheel release.
- You need to control compiler flags for reproducibility.
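Checking those tags by hand is error-prone, so here is a minimal sketch that builds the local-version fragment for the installed PyTorch and tests a wheel filename against it. The naming pattern (cu12torch2.10cxx11abiTRUE-cp312) is inferred from the example filename in this section and is an assumption, not a documented contract; all function names are hypothetical. torch.compiled_with_cxx11_abi() reports the C++ ABI flag of the installed PyTorch build.

```python
import sys


def expected_local_tag(cuda_major, torch_major_minor, cxx11abi):
    """Build a 'cu12torch2.10cxx11abiTRUE'-style tag like the one in the example wheel name."""
    major, minor = torch_major_minor
    abi = "TRUE" if cxx11abi else "FALSE"
    return f"cu{cuda_major}torch{major}.{minor}cxx11abi{abi}"


def python_tag(version_info=sys.version_info):
    """CPython tag such as 'cp312' for the given (or running) interpreter."""
    return f"cp{version_info[0]}{version_info[1]}"


def wheel_matches(filename, local_tag, py_tag):
    """True if the wheel filename carries exactly this local tag followed by this Python tag."""
    return f"+{local_tag}-{py_tag}-" in filename


def local_tag_for_installed_torch():
    """Derive the expected tag from the installed PyTorch (CUDA builds only)."""
    import torch

    if torch.version.cuda is None:
        raise RuntimeError("installed PyTorch has no CUDA runtime")
    cuda_major = int(torch.version.cuda.split(".")[0])
    major, minor = (int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
    return expected_local_tag(cuda_major, (major, minor), torch.compiled_with_cxx11_abi())
```

The trailing separators in wheel_matches matter: they stop a torch2.1 tag from matching a torch2.10 wheel.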

Do not benchmark a wheel build against a source build without confirming that both activated the same kernel paths — differences in compiler optimization flags can produce materially different throughput numbers.

What to do on Windows, ROCm, or non-Hopper systems

Watch Out: Windows is not a first-class target for the official source build; the README says only that Windows builds might work and still need more testing. Issue #2535 explicitly documents DLL PATH conflicts between the Windows CUDA Toolkit installation and flash-attn's build process. If you must run on Windows, expect manual environment-variable surgery and no guarantee of a working build. WSL2 with a Linux image is the usual workaround.

ROCm: The repository documents ROCm support with two separate backends. ROCm builds follow a different dependency chain than CUDA builds, and CUDA benchmarking guidance does not transfer. If your infrastructure runs AMD GPUs, treat the ROCm path as a separate installation track and benchmark it independently.

Non-Hopper systems and FA-3: FlashAttention-3 beta is Hopper-only. An A100 (Ampere, SM 80) cannot run the FA-3 kernel path. Do not benchmark FlashAttention-3 on Ampere and report the result as a FlashAttention-3 number — you will be timing FlashAttention-2 or a fallback kernel. Turing GPUs (T4, RTX 2080) require the separate flash-attention-turing repository, which covers only a core feature subset.

Run a correctness check before you benchmark

Timing a kernel that is silently falling back to PyTorch's standard attention implementation produces numbers that look plausible but are meaningless for evaluating FlashAttention. Run a functional check first.

import torch
from flash_attn import flash_attn_func

# Minimal smoke test: batch=2, heads=8, seq_len=512, head_dim=64, dtype=fp16
# Requires CUDA device — must run on Ampere/Ada/Hopper
batch, heads, seqlen, head_dim = 2, 8, 512, 64
dtype = torch.float16
device = "cuda"

q = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)

out = flash_attn_func(q, k, v, causal=True)

# Verify output shape and that no NaN/Inf leaked through
assert out.shape == (batch, seqlen, heads, head_dim), f"Shape mismatch: {out.shape}"
assert not out.isnan().any(), "NaN detected in output — check dtype and GPU compatibility"
assert not out.isinf().any(), "Inf detected in output — check head_dim and GPU constraints"
print("Smoke test passed. Shape:", out.shape, "| dtype:", out.dtype)

Expected output: Smoke test passed. Shape: torch.Size([2, 512, 8, 64]) | dtype: torch.float16. An import error here points to a build failure, and a shape assertion error points to an API version mismatch. A RuntimeError from flash_attn_func on Turing-class or older hardware indicates an unsupported architecture: FlashAttention-2 needs Ampere or newer regardless of dtype.
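The smoke test only proves the kernel runs; to check that it computes attention, compare it against a naive reference. This sketch assumes NumPy and uses the same (batch, seqlen, heads, head_dim) layout and default softmax scale of 1/sqrt(head_dim) as flash_attn_func; for simplicity its causal mask assumes equal query and key lengths. The function names are hypothetical.

```python
import numpy as np


def reference_attention(q, k, v, causal=False, softmax_scale=None):
    """Naive softmax(Q K^T * scale) V in float64, using the (batch, seqlen, heads, head_dim) layout."""
    if causal and q.shape[1] != k.shape[1]:
        raise ValueError("this sketch only handles causal masks with equal query/key lengths")
    head_dim = q.shape[-1]
    scale = softmax_scale if softmax_scale is not None else 1.0 / np.sqrt(head_dim)
    # Move heads in front of the sequence axis: (batch, heads, seqlen, head_dim)
    qh, kh, vh = (x.astype(np.float64).transpose(0, 2, 1, 3) for x in (q, k, v))
    scores = (qh @ kh.transpose(0, 1, 3, 2)) * scale  # (batch, heads, seqlen_q, seqlen_k)
    if causal:
        seqlen = scores.shape[-1]
        future = np.triu(np.ones((seqlen, seqlen), dtype=bool), k=1)  # key index > query index
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return (probs @ vh).transpose(0, 2, 1, 3)  # back to (batch, seqlen, heads, head_dim)


def max_abs_error(actual, expected):
    """Largest elementwise absolute difference, computed in float64."""
    return float(np.max(np.abs(np.asarray(actual, dtype=np.float64) - expected)))
```

For the smoke-test tensors, convert each with x.float().cpu().numpy() and call max_abs_error(out.float().cpu().numpy(), reference_attention(qn, kn, vn, causal=True)). An FP16 kernel will not match an FP64 reference exactly, so judge the error against what a plain PyTorch FP16 attention produces on the same inputs rather than against zero.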

Confirm backend selection and tensor shapes

Before timing, verify that your GPU, dtype, and head dimension fall inside the envelope the FlashAttention kernels support. This step is what most benchmark guides omit, and it is why published numbers are frequently not reproducible.

import torch
import flash_attn

# Print the installed versions and the GPU's compute capability
major, minor = torch.cuda.get_device_capability(0)
print(f"flash-attn version : {flash_attn.__version__}")
print(f"PyTorch version    : {torch.__version__}")
print(f"CUDA version       : {torch.version.cuda}")
print(f"GPU                : {torch.cuda.get_device_name(0)}")
print(f"GPU SM arch        : sm_{major}{minor}")

# FlashAttention-2 kernels need Ampere (SM 80) or newer
assert (major, minor) >= (8, 0), "FlashAttention-2 requires an Ampere, Ada, or Hopper GPU"

# Confirm the tensor configuration you intend to benchmark
head_dim = 128
dtype = torch.bfloat16  # BF16 requires Ampere/Ada/Hopper
causal = True
dropout = 0.0  # dropout=0 required for head_dim=256 backward on consumer GPUs

print("\nBenchmark config:")
print(f"  head_dim : {head_dim}")
print(f"  dtype    : {dtype}")
print(f"  causal   : {causal}")
print(f"  dropout  : {dropout}")

The SM arch printout tells you whether you are on Hopper (SM 90), Ada (SM 89), or Ampere (SM 80/86). Cross-reference this against the head-dim compatibility table before running backward-pass benchmarks.
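That cross-referencing can be automated as a gate in front of the timed loop. This is a hypothetical helper, not part of the flash-attn API, transcribed from the README rules quoted in this guide ("Head dim > 192 backward requires A100/A800 or H100/H800", with the head dim 256, no-dropout exception from 2.5.5). It treats SM 80 and SM 90 as the A100/A800 and H100/H800 class, which is an approximation since other GPUs share those compute capabilities.

```python
# Rules for the FlashAttention-2 CUDA path, as quoted from the README in this guide
DATACENTER_SM = {(8, 0), (9, 0)}  # A100/A800 are SM 80, H100/H800 are SM 90
SUPPORTED_DTYPES = {"float16", "bfloat16"}


def parse_version(text):
    """'2.5.5' or '2.8.3+cu12...' -> (2, 5, 5) / (2, 8, 3)."""
    return tuple(int(part) for part in text.split("+")[0].split(".")[:3])


def check_fa2_config(capability, head_dim, dtype, dropout_p, backward, fa_version):
    """Return a list of problems; an empty list means the config is inside the documented envelope."""
    problems = []
    sm = f"SM {capability[0]}{capability[1]}"
    if capability < (8, 0):
        problems.append(f"{sm}: FlashAttention-2 needs Ampere, Ada, or Hopper")
    if dtype not in SUPPORTED_DTYPES:
        problems.append(f"dtype {dtype}: only float16 and bfloat16 are accepted")
    if head_dim > 256:
        problems.append(f"head_dim {head_dim}: maximum is 256")
    elif backward and head_dim > 192 and capability not in DATACENTER_SM:
        # Consumer-GPU exception: head_dim 256 backward without dropout, flash-attn >= 2.5.5
        consumer_256_ok = (
            head_dim == 256 and dropout_p == 0.0 and parse_version(fa_version) >= (2, 5, 5)
        )
        if not consumer_256_ok:
            problems.append(
                f"head_dim {head_dim} backward on {sm}: needs A100/A800 or H100/H800 "
                "(head_dim 256 works on consumer GPUs only with dropout_p=0.0 and flash-attn >= 2.5.5)"
            )
    return problems
```

On a live system, pass torch.cuda.get_device_capability(0), str(dtype).removeprefix("torch."), and flash_attn.__version__, and refuse to start the timed loop if the returned list is non-empty.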

Benchmark FlashAttention without misleading results

Benchmark discipline for FlashAttention requires holding GPU architecture, dtype, head_dim, causal mode, dropout, and sequence length constant across all compared runs. The FlashAttention-3 paper reports 1.5–2.0× speedups and up to 740 TFLOPs/s in FP16 on H100 hardware specifically — these numbers should not be quoted for A100 runs or compared against FA-3 results on different hardware without explicit labeling.

The table below shows the fields every benchmark record must capture to be reproducible.

| Field | Why it matters | Example value |
|---|---|---|
| GPU model | Kernel path and throughput ceiling differ by architecture | H100 SXM5 |
| flash-attn version | Kernel code changes between versions | 2.8.3 |
| dtype | FP16 and BF16 can differ in throughput and numerics | BF16 |
| head_dim | Determines kernel variant; backward has GPU constraints above 192 | 128 |
| seq_len | Standard attention memory grows quadratically with seq_len and FlashAttention's linearly, so the gap widens with length | 4096 |
| batch size | Affects GPU occupancy | 4 |
| causal | Causal masking roughly halves the attention FLOPs | True |
| dropout | Changes kernel path for head_dim 256 on consumer GPUs | 0.0 |
| warmup iters | First iterations include one-time setup costs; exclude them from timing | 10 |

Always use at least 10 warmup iterations before measuring. The first invocations pay one-time costs such as CUDA context setup, lazy module loading, and caching-allocator growth, which can dwarf a steady-state kernel that completes in under 1 ms.
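The warmup rule and the TFLOPs/s metric can be combined into a small harness. This is a sketch, not the repository's benchmark script: it counts 4·batch·seqlen²·heads·head_dim FLOPs for the forward pass (two matmuls, QKᵀ and PV, at 2 FLOPs per multiply-add), halves that for causal masking as an approximation, and uses 2.5× the forward count for the backward pass, which performs five matmuls of the same shape. Timing uses CUDA events because kernel launches are asynchronous. The function names are hypothetical.

```python
def attention_flops(batch, seqlen, heads, head_dim, causal, mode="fwd"):
    """FLOPs for one attention call; mode is 'fwd', 'bwd', or 'fwd_bwd'."""
    flops = 4 * batch * seqlen * seqlen * heads * head_dim  # QK^T and PV, 2 FLOPs per MAC
    if causal:
        flops //= 2  # roughly half of the score matrix is masked out
    multiplier = {"fwd": 1.0, "bwd": 2.5, "fwd_bwd": 3.5}[mode]
    return flops * multiplier


def tflops_per_second(flops, milliseconds):
    """Convert a FLOP count and a duration in ms into TFLOPs/s."""
    return flops / (milliseconds * 1e-3) / 1e12


def time_cuda_ms(fn, warmup=10, iters=50):
    """Middle sample of fn() durations in ms, measured with CUDA events after warmup."""
    import torch

    for _ in range(warmup):
        fn()  # absorbs one-time setup costs; not timed
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # wait for the kernel before reading the events
        times.append(start.elapsed_time(end))  # milliseconds
    times.sort()
    return times[len(times) // 2]
```

For the smoke-test configuration: ms = time_cuda_ms(lambda: flash_attn_func(q, k, v, causal=True)), then tflops_per_second(attention_flops(batch, seqlen, heads, head_dim, causal=True), ms). Reporting the middle sample rather than the mean keeps a single slow outlier from skewing the result.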

Measure the right metrics for your workload

Training workloads care about backward throughput (tokens/sec) and peak memory; inference workloads care about time-to-first-token (TTFT) and sustained decode throughput. Report them separately — a single "speed" number conflates two different performance regimes.

| Config | seq_len | dtype | Metric | FlashAttention-2 (A100) | FlashAttention-3 (H100) |
|---|---|---|---|---|---|
| Forward only | 2048 | FP16 | TFLOPs/s | Measure locally | Measure locally |
| Forward only | 8192 | FP16 | TFLOPs/s | Measure locally | Measure locally |
| Forward+Backward | 2048 | BF16 | TFLOPs/s | Measure locally | Measure locally |
| Forward only | 2048 | FP8 | TFLOPs/s | N/A (no FP8 path) | Measure locally |

The only paper-backed figures here are the FA-3 headline numbers from arXiv:2407.08608: FP16 up to 740 TFLOPs/s and FP8 close to 1.2 PFLOPs/s on H100. These are peak results, not per-configuration values for the rows above. No FA-2 A100 figures are given; measure your own for authoritative comparisons.

Memory use should always be measured with the same dtype and head_dim as the timing run. FlashAttention's core advantage, O(N) memory in sequence length versus O(N²) for standard attention, grows with sequence length. Benchmarking only at short lengths such as seq_len=512 will show little memory difference and may even make FlashAttention look slower than standard attention once fixed per-call costs like kernel launch overhead dominate.
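A quick calculation shows why short sequences hide the memory gap. Standard attention materializes a seqlen × seqlen score matrix per head, so its size is batch · heads · seqlen² · bytes per element. With a hypothetical batch of 4, 32 heads, seq_len 4096, and FP16 (2 bytes), one such matrix is 4 GiB, and doubling seq_len quadruples it. The sketch below computes that figure and measures the extra peak allocation of a single call; it assumes the allocations of interest go through PyTorch's caching allocator, and the function names are hypothetical.

```python
GIB = 1024 ** 3


def score_matrix_bytes(batch, heads, seqlen, bytes_per_element):
    """Size of one materialized (seqlen x seqlen) attention matrix across all batches and heads."""
    return batch * heads * seqlen * seqlen * bytes_per_element


def peak_extra_memory_bytes(fn):
    """Peak bytes allocated by PyTorch during one fn() call, above what was already allocated."""
    import torch

    torch.cuda.synchronize()
    baseline = torch.cuda.memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() - baseline
```

Subtracting the baseline keeps the already-allocated q, k, and v tensors out of the figure, so the FlashAttention and standard-attention numbers reflect only what each call adds.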

Avoid apples-to-oranges comparisons with PyTorch SDPA

PyTorch's scaled_dot_product_attention (SDPA) can dispatch to several backends, including a FlashAttention kernel, a memory-efficient attention kernel, cuDNN attention, and a reference math implementation, depending on dtype, head_dim, and device capability. Benchmarking your flash_attn_func call against torch.nn.functional.scaled_dot_product_attention without controlling that dispatch is not a fair comparison: SDPA's flash backend is itself a FlashAttention-2 kernel built into PyTorch, so you may be timing FlashAttention against FlashAttention.

Choose direct flash_attn_func vs. PyTorch SDPA comparison when:
- You have explicitly disabled all SDPA backends except the one you are testing (torch.backends.cuda.enable_flash_sdp(False) etc.) and want to measure the raw kernel difference.
- You are quantifying the overhead of PyTorch's dispatch layer for latency-sensitive inference.

Choose FlashAttention vs. cuDNN attention when:
- Your inference stack already uses cuDNN and you want to determine whether switching kernel implementations is worth the engineering cost.
- Both paths support the same dtype and head_dim on the same GPU.

Choose FlashAttention vs. Triton attention kernel when:
- You are building a custom kernel and need to understand where cycles are spent relative to a reference implementation.
- Both kernels are compiled for the same SM architecture with equivalent optimization flags.
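When comparing against SDPA, pin its backend explicitly and remember that the two APIs use different tensor layouts. A minimal sketch assuming PyTorch 2.3 or later, where the torch.nn.attention.sdpa_kernel context manager restricts dispatch to the listed backends (older releases expose the torch.backends.cuda toggles instead); the helper names are hypothetical. With a single backend enabled, SDPA raises an error when that backend cannot handle the inputs instead of quietly choosing another. SDPA's is_causal=True aligns the mask to the top-left corner, while flash-attn (since v2.1) aligns it bottom-right, so only compare causal results with equal query and key lengths.

```python
def flash_to_sdpa_shape(shape):
    """flash_attn_func uses (batch, seqlen, heads, head_dim); SDPA expects (batch, heads, seqlen, head_dim)."""
    batch, seqlen, heads, head_dim = shape
    return (batch, heads, seqlen, head_dim)


def sdpa_with_backend(q, k, v, backend_name, causal=True):
    """Run SDPA on flash-layout tensors with exactly one backend enabled.

    backend_name is an SDPBackend member name such as "FLASH_ATTENTION",
    "EFFICIENT_ATTENTION", or "MATH" (plus "CUDNN_ATTENTION" on recent releases).
    """
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    backend = getattr(SDPBackend, backend_name)
    q_s, k_s, v_s = (x.transpose(1, 2) for x in (q, k, v))  # to (batch, heads, seqlen, head_dim)
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(q_s, k_s, v_s, is_causal=causal)
    return out.transpose(1, 2)  # back to flash layout for a like-for-like comparison
```

Timing sdpa_with_backend(q, k, v, "EFFICIENT_ATTENTION") against flash_attn_func(q, k, v, causal=True) on the same tensors gives a comparison whose backend is known rather than inferred.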

The controlling variables — dtype, causal flag, dropout rate, head_dim, and whether the backward pass is included — must be held constant across all compared runs. Compiler/build settings (source vs. wheel, TORCH_CUDA_ARCH_LIST) should also be reported when publishing benchmark results, since available kernel variants can differ between builds.

Common failure modes and exact fixes

Build and runtime failures almost always fall into one of five categories: missing Ninja (compile-time), CUDA version mismatch (compile-time or import-time), unsupported GPU architecture (runtime or build-time), head-dim/backward constraint violations (runtime), and dtype/backend mismatch (runtime, and silent when a wrapper picks the backend for you).

Watch Out: The most dangerous failure mode is silent fallback: your code runs and your benchmark produces numbers, but the kernel actually executing is not FlashAttention. flash_attn_func itself raises on unsupported dtypes or GPUs, but wrappers that choose an attention backend for you, such as PyTorch SDPA or a model framework's attention setting, can route to a different path when dtype, head_dim, or GPU constraints are not satisfied. Always run the backend confirmation code in the "Confirm backend selection" section before treating any timing number as a FlashAttention result.

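| Error symptom | Root cause | Version / constraint | Fix |
|---|---|---|---|
| Build takes 2+ hours | Ninja not installed | All versions | pip install ninja before building |
| nvcc fatal: Unsupported gpu architecture, or "no kernel image is available" at runtime | Toolkit too old for the GPU's SM, or the SM missing from the build's arch list | Blackwell (SM 120), others | Use a toolkit that supports the SM (CUDA 12.8+ for Blackwell) and include the SM in the build's arch list |
| undefined symbol at import | CUDA toolkit / wheel version mismatch, or a flash-attn wheel built for a different PyTorch version or C++ ABI | All versions | Match nvcc --version with torch.version.cuda; use a wheel whose tags match your PyTorch, or rebuild from source |
| OOM during compile | Too many parallel jobs | < 96 GB RAM hosts | export MAX_JOBS=8 (or lower) |
| RuntimeError on backward at head_dim > 192 (e.g., 224) | Consumer GPU class | FA-2, all versions | Use A100/A800 or H100/H800 for backward at head_dim > 192 |
| RuntimeError on backward at head_dim=256 on a consumer GPU | Dropout enabled, or flash-attn older than 2.5.5 | Consumer GPUs support only dropout=0, from v2.5.5 | Upgrade to ≥ 2.5.5 and set dropout=0.0, or use A100/A800 or H100/H800 |
| DLL load error or CUDA Toolkit PATH conflict on Windows | PATH/DLL conflicts in the Windows source-build environment | All versions | Use Linux or WSL2; align PATH manually if testing locally |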

Ninja is missing or build times explode

If ninja --version returns command not found and you start a flash-attn source build, you are committing to an hours-scale compile. The official README is explicit: "Install ninja if you don't have it. It's much faster using ninja than without." Its own figures are 2 hours without Ninja versus 3–5 minutes with it on a 64-core machine.

Pro Tip: Install Ninja with pip install ninja inside the environment you build from, rather than relying on the system package manager. PyTorch's extension builder detects Ninja by running ninja --version, so the binary must be on the PATH the build process sees; the pip package places it in the active environment's bin directory. If Ninja is installed but not on that PATH, the build falls back to the slow non-Ninja backend and the speed benefit disappears. After installing, confirm with python -c "from torch.utils.cpp_extension import is_ninja_available; print(is_ninja_available())". On machines with fewer than 96 GB RAM, set MAX_JOBS to a value where MAX_JOBS × ~4 GB per compiler process stays within available RAM; an OOM mid-build produces misleading errors that waste diagnostic time.

Head-dim and backward-pass errors on the wrong GPU

A training run with head_dim above 192 (for example 224 or 256) will fail on consumer Ampere/Ada GPUs unless specific conditions are met. The constraints are:

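| head_dim | Operation | RTX 4090 / RTX 3090 | A100 / H100 |
|---|---|---|---|
| ≤ 128 | Forward + Backward | ✅ | ✅ |
| 192 | Forward | ✅ | ✅ |
| 192 | Backward | ✅ | ✅ |
| 224 | Backward | ❌ (errors) | ✅ |
| 256 | Forward | ✅ | ✅ |
| 256 | Backward (dropout=0) | ✅ (≥ v2.5.5) | ✅ |
| 256 | Backward (dropout>0) | ❌ | ✅ |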

If you are training a model with a head_dim above 192 other than 256 (for example 224) on a consumer GPU and see a runtime error on the backward pass, the root cause is the GPU architecture, not the flash-attn version. Upgrading flash-attn will not fix this; you need A100/A800 or H100/H800 hardware. head_dim=192 itself is within the limit. For head_dim=256 on consumer GPUs, upgrade to flash-attn ≥ 2.5.5 and set dropout to 0.0.

Dtype, dropout, and backend mismatches

FlashAttention-2's CUDA path accepts FP16 and BF16. BF16 requires Ampere, Ada, or Hopper — attempting BF16 on Turing hardware will trigger an error or fallback, not a silent success. The official README documents this requirement directly.

Watch Out: If you pass FP32 tensors to flash_attn_func, the library will not silently cast — it will raise a dtype error. If your benchmark wrapper autocasts to FP32 and you do not notice, the call fails, your benchmark loop catches the exception and skips the iteration, and you end up with throughput numbers that reflect no actual FlashAttention work. Always assert tensor.dtype in (torch.float16, torch.bfloat16) before the timed loop. For FlashAttention-3 FP8 benchmarks, confirm you are on Hopper with CUDA 12.3+ and explicitly using the FA-3 API — the FA-2 API does not expose an FP8 path.

Dropout interacts with backend dispatch in two places: it disables head_dim=256 backward on consumer GPUs (detailed above), and it can affect which kernel variant is selected internally. For benchmark purity, set dropout_p=0.0 unless your production workload requires dropout and you are specifically benchmarking that configuration.

FAQ on installation, support, and benchmarking

What are the minimum requirements for FlashAttention?

CUDA toolkit or ROCm toolkit, PyTorch 2.2+, packaging Python package, and Linux. For FlashAttention-3 beta: CUDA 12.3 minimum, CUDA 12.8 recommended, and H100/H800 hardware.

What GPUs support FlashAttention-2?

Ampere (A100, A800, RTX 3090), Ada (RTX 4090, RTX 4080), and Hopper (H100, H800). Turing GPUs (T4, RTX 2080) require the separate flash-attention-turing repository. Consumer Blackwell (GeForce RTX 50 series, SM 120) is not yet in the default architecture list.

Does FlashAttention run on Windows?

Not reliably via the official source-build path. Issue #2535 documents DLL PATH conflicts with the CUDA Toolkit on Windows. Use Linux or WSL2.

Why does my FlashAttention build fail with out-of-memory errors?

You are running too many parallel compile jobs. Set MAX_JOBS=8 (or lower) before building if your system has under 96 GB RAM.

How do I know if my benchmark is actually measuring FlashAttention?

Run the backend confirmation code before timing. Verify that flash_attn.__version__ loads without error, that torch.cuda.get_device_capability() returns a supported SM version, and that your tensor dtype is FP16 or BF16. Any mismatch risks timing a fallback kernel.

Can I use FP8 with FlashAttention-2?

No. FP8 is available only in the FlashAttention-3 beta forward pass, on Hopper hardware.

Is head_dim=192 safe for training on an RTX 4090?

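Yes. The A100/A800 or H100/H800 requirement applies to backward passes at head_dim > 192, so 192 is within the limit for both forward and backward. Head dims above 192 need those datacenter GPUs for backward, except 256 without dropout on flash-attn ≥ 2.5.5.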

