Engineering the Quantized Johnson-Lindenstrauss (QJL) Transform for Distributed Inference

18 min read · Published Apr 5, 2026, 3:54 PM

Introduction: Eliminating KV-Cache Bottlenecks in Production

VRAM is the binding constraint on long-context LLM inference throughput. As sequence lengths push past 8k tokens, the Key-Value (KV) cache grows to rival the memory taken by the model weights, and at realistic batch sizes it exceeds them. A single Llama-3-70B sequence with a 32k context window needs about 10 GiB of FP16 KV cache even with grouped-query attention. A modest batch therefore reaches tens of gigabytes, which crowds out batch capacity and forces horizontal scaling that multiplies infrastructure cost. KV cache optimization is one of the most direct levers on hardware utilization.
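The KV cache figure follows directly from the model configuration. The sketch below assumes Llama-3-70B's published attention configuration (80 layers, 8 KV heads under grouped-query attention, head_dim 128) and FP16 storage. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Total bytes for cached keys and values across all layers.

    The leading factor of 2 covers K and V; bytes_per_elem=2 means FP16.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GIB = 2**30
# Assumed Llama-3-70B attention config: 80 layers, 8 KV heads, head_dim 128
per_sequence = kv_cache_bytes(80, 8, 128, seq_len=32_768)
batch_of_8 = kv_cache_bytes(80, 8, 128, seq_len=32_768, batch=8)

print(f"{per_sequence / GIB:.1f} GiB per 32k-token sequence")  # 10.0 GiB
print(f"{batch_of_8 / GIB:.1f} GiB for a batch of 8")          # 80.0 GiB
```

At batch 8, the cache alone exceeds the memory of a single 80 GB GPU before any weights are loaded.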

Traditional quantization methods attack this problem by reducing the numerical precision of cached tensors from FP16 to INT8 or FP8. The storage reduction is real, but these methods impose a hidden tax. Every quantized tensor requires companion metadata (per-tensor or per-group scale and zero-point parameters) that must live in VRAM alongside the compressed data. With fine-grained groups, this adds a fixed fraction of extra bits to every cached element. In a large-batch, long-context deployment, that cost is paid again at every layer and head.

The Quantized Johnson-Lindenstrauss (QJL) transform eliminates this metadata. It applies a random JL projection to each key and keeps only the signs of the projected coordinates. This supports unbiased estimation of query-key inner products without storing any scale or zero-point constants. The only auxiliary state is one norm per vector. The QJL paper applies the sign sketch to keys; values are quantized with a conventional per-token scheme.

Technical Note: QJL stores packed sign bits plus one scalar norm per key vector. There are no scale factors, no zero points, and no group statistics to persist. The random projection spreads each vector's energy across coordinates, so no per-group normalization is needed.

Memory Overhead Comparison: Traditional Quantization vs. QJL

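Method                  | Bits per Element    | Quantization Metadata                            | Effective Storage                              | Notes
------------------------|---------------------|--------------------------------------------------|------------------------------------------------|----------------------------------------
FP32 KV Cache           | 32                  | None                                             | 32 bits/element                                | Baseline
FP16 KV Cache           | 16                  | None                                             | 16 bits/element                                | Standard production default
INT8 (per-tensor)       | 8                   | 1× scale + 1× zero-point per tensor              | ≈8 bits/element                                | Negligible overhead
INT8 (per-group, g=128) | 8                   | FP16 scale + FP16 zero-point per 128 elements    | ~8.25 bits/element                             | Common for accuracy
FP8 (per-group, g=64)   | 8                   | FP32 scale per 64 elements                       | ~8.5 bits/element                              | Increasingly standard
QJL (1-bit sign)        | 1 per projected dim | One FP32 norm per vector; no scale or zero-point | m/d + 32/d bits/element (1.25 at m = d = 128)  | Unbiased with the asymmetric estimator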

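Arithmetically, the 1-bit sign sketch introduced in ArXiv 2406.03482 is 32× smaller than FP32 and 16× smaller than FP16. Those ratios assume the projection keeps m = d and ignore the per-vector norm. Counting one FP32 norm per 128-dimensional key, the reduction against FP16 is 12.8× at m = 128 and about 7× at m = 256, which is the projection size used later in this article. Related work under the TurboQuant framework reports at least a 6× KV cache memory reduction in its benchmarks. Within the VRAM left after model weights, reductions of this size translate into proportionally larger maximum batch sizes or context lengths on fixed hardware.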
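The per-vector arithmetic behind those ratios is short enough to check directly. This sketch makes three assumptions: FP16 baseline keys with d = 128, one FP32 norm per key (as in the VRAM layout diagram in the kernel section), and no outlier channels. The helper names are hypothetical.

```python
def fp_bits_per_key(head_dim: int, bits_per_elem: int = 16) -> int:
    """Storage for one uncompressed key vector."""
    return head_dim * bits_per_elem

def qjl_bits_per_key(proj_dim: int, norm_bits: int = 32) -> int:
    """One sign bit per projected dimension plus one stored norm."""
    return proj_dim + norm_bits

head_dim = 128
ratios = {}
for proj_dim in (128, 256):
    ratios[proj_dim] = fp_bits_per_key(head_dim) / qjl_bits_per_key(proj_dim)
    print(f"m={proj_dim}: {qjl_bits_per_key(proj_dim)} vs "
          f"{fp_bits_per_key(head_dim)} bits -> {ratios[proj_dim]:.1f}x")
# m=128: 160 vs 2048 bits -> 12.8x
# m=256: 288 vs 2048 bits -> 7.1x
```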

The architectural mechanics enabling this are straightforward. The JL transform is a fixed random projection applied at write time, so the inner product between a stored key and any later query can be estimated from the key's sign pattern and its norm alone. No per-batch calibration is needed. This is fundamentally different from learned quantization schemes, which require a calibration dataset and per-deployment parameter tuning.


The QJL Mathematical Framework: Preconditioning and Sign-Bit Quantization

The core claim of QJL is that you can estimate the inner product between two high-dimensional vectors using only their projected sign bits. Here is the formal basis.

Let x, y ∈ ℝᵈ be key and query vectors. Let P ∈ ℝᵐˣᵈ be a Johnson-Lindenstrauss transform matrix. This is typically a dense Gaussian matrix with i.i.d. N(0, 1) entries; scaling the entries by 1/√m, as in the classical JL statement, does not change the sign bits. Structured or sparse variants are also used. The JL transform guarantees that, with high probability, pairwise distances (and by extension inner products) are approximately preserved in the projected space.

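Sign-bit sketches of a JL projection obey the following identity, where the expectation is over the random P:

E[(1/m) · ⟨sign(Px), sign(Py)⟩] = (2/π) · arcsin( xᵀy / (‖x‖ · ‖y‖) ) = 1 − 2θ/π

Here θ is the angle between x and y. Equivalently, the two signs of each projected coordinate agree with probability 1 − θ/π.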

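QJL stores sign(Pk) for each key k, together with its norm ‖k‖ (one float32 per vector). This supports two estimators.

The symmetric estimator also sign-quantizes the query, so the attention logit reduces to XOR and population count instead of a floating-point dot product. By the identity above, however, the fraction of matching bits estimates 1 − θ/π rather than the cosine. The result must be mapped back through cos(π · (1 − match fraction)), and it is consistent rather than unbiased.

The asymmetric estimator proposed in the QJL paper keeps the projected query in full precision. With P's entries drawn i.i.d. from N(0, 1), ⟨q, k⟩ ≈ (√(π/2) / m) · ‖k‖ · ⟨Pq, sign(Pk)⟩ is unbiased by construction. No systematic error accumulates across layers or heads, and no calibration step shifts the distribution. Its per-key cost is a sign-controlled sum of the projected query entries rather than a popcount.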
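The asymmetric estimator is easy to check numerically. The NumPy sketch below assumes a dense Gaussian P with i.i.d. N(0, 1) entries; the function names are illustrative, not the paper's code. Averaging over many independent draws of P shows the estimates centering on the true inner product.

```python
import numpy as np

def qjl_quantize_key(P: np.ndarray, k: np.ndarray):
    """Cache-time step: keep only sign(Pk) and the key norm."""
    return np.sign(P @ k), np.linalg.norm(k)

def qjl_asymmetric_dot(P: np.ndarray, q: np.ndarray,
                       key_signs: np.ndarray, key_norm: float) -> float:
    """Estimate <q, k> as sqrt(pi/2) / m * ||k|| * <Pq, sign(Pk)>."""
    m = P.shape[0]
    return float(np.sqrt(np.pi / 2) / m * key_norm * np.dot(P @ q, key_signs))

rng = np.random.default_rng(0)
d, m, trials = 128, 256, 2000
q, k = rng.standard_normal(d), rng.standard_normal(d)
true_dot = float(q @ k)

estimates = []
for _ in range(trials):
    P = rng.standard_normal((m, d))          # fresh N(0, 1) projection per trial
    signs, norm = qjl_quantize_key(P, k)
    estimates.append(qjl_asymmetric_dot(P, q, signs, norm))
mean_est = float(np.mean(estimates))

print(f"true <q,k> = {true_dot:.2f}, mean estimate = {mean_est:.2f}")
```

A single draw has a standard deviation of at most ‖q‖‖k‖·√(π/(2m)). A deployment therefore picks m large enough for attention's tolerance rather than relying on averaging.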

The practical implication is that the projection matrix P is fixed at initialization and shared across all inference requests. It is not a learned parameter; it is a structural random seed. P needs no recomputation, fine-tuning, or dataset-specific calibration, which makes QJL well suited to production pipelines where model weights are frozen.

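Memory Constraint: The JL projection matrix P itself must reside in VRAM. For m = 256 projected dimensions and d = 128 head dimensions, P holds 256 × 128 = 32,768 FP16 values (64 KB). One matrix can be shared by every head, so its footprint is negligible relative to cache savings at long context. Note that with m = 256 > d, each key costs 256 sign bits, or 2 bits per original element before the norm is counted.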

The preconditioning matrix P is drawn so that the Johnson-Lindenstrauss lemma applies. For any ε ∈ (0, 1) and any set of n vectors, a random projection to m = O(log(n)/ε²) dimensions preserves all pairwise squared distances within a factor of (1 ± ε) with high probability. For KV cache use, the n vectors are the keys and queries of a head. Because the dependence on n is logarithmic, growing the context from 4k to 32k tokens raises the required m by only about 25%.
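For a sense of scale, one standard explicit form of the lemma (Dasgupta and Gupta) gives m ≥ 4 ln(n) / (ε²/2 − ε³/3). The sketch below evaluates that bound; `jl_min_dim` is a hypothetical helper.

```python
import math

def jl_min_dim(n: int, eps: float) -> int:
    """Sufficient projection size for all n-choose-2 pairs (Dasgupta-Gupta form)."""
    return math.ceil(4 * math.log(n) / (eps**2 / 2 - eps**3 / 3))

for n in (4_096, 32_768):
    print(f"n={n}: m >= {jl_min_dim(n, 0.1)}")
# n=4096: m >= 7130
# n=32768: m >= 8912
```

The 25% growth from 4k to 32k tokens confirms the logarithmic scaling. The absolute values, however, sit far above proj_dim = 256. This worst-case bound guarantees every pair simultaneously. Attention only needs each logit to be accurate enough on average, which is why a much smaller m can be workable in the configurations this article targets.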


Architecting the QJL Kernel: From Theory to CUDA Implementation

Moving QJL from mathematics to a deployed distributed inference kernel requires three concrete pieces: a projection kernel, a sign-packing kernel, and a packed-binary attention kernel. The first two execute at cache-write time; the third executes at attention-compute time. The payoff comes from the attention kernel reading far fewer bytes per cached key from HBM. In the symmetric variant, it also replaces floating-point multiply-adds with bitwise logic in registers.

The sign-packing step is the critical piece that delivers memory savings. After projecting a key vector through P, you reduce each projected dimension to its sign bit and pack 32 consecutive bits into a single uint32_t. This is where CUDA intrinsics provide decisive throughput advantages.

// qjl_sign_pack.cu
// Packs sign bits of projected keys into uint32 words.
// Launch: grid(num_heads, seq_len), block(proj_dim); proj_dim must be a multiple of 32.
// Each warp packs one group of 32 projected dimensions into one word.

#include <cuda_fp16.h>
#include <stdint.h>

__global__ void qjl_sign_pack_kernel(
    const __half* __restrict__ projected,  // [num_heads, seq_len, proj_dim]
    uint32_t*     __restrict__ packed,     // [num_heads, seq_len, proj_dim/32]
    const int proj_dim,
    const int seq_len
) {
    // One block per (head, token); gridDim.y limits seq_len to 65,535 per launch
    const int head_idx  = blockIdx.x;
    const int tok_idx   = blockIdx.y;
    const int warp_lane = threadIdx.x % 32;
    const int group_idx = threadIdx.x / 32;   // which 32-dim group this warp handles
    const int dim       = group_idx * 32 + warp_lane;

    const long long row = (long long)head_idx * seq_len + tok_idx;

    // Load one projected scalar per lane; out-of-range lanes still join the ballot
    const float val = (dim < proj_dim)
                      ? __half2float(projected[row * proj_dim + dim])
                      : 0.0f;

    // Warp-wide collective: bit i of the result is lane i's predicate (sign >= 0)
    const uint32_t sign_word = __ballot_sync(0xFFFFFFFF, val >= 0.0f);

    // Only lane 0 of each warp writes, so each packed word gets exactly one store
    if (warp_lane == 0 && group_idx < proj_dim / 32) {
        packed[row * (proj_dim / 32) + group_idx] = sign_word;
    }
}

__ballot_sync is the key intrinsic here. It evaluates a predicate across the warp and returns a single uint32_t whose bit i is set if and only if lane i's predicate is true. Each 32-thread warp therefore produces one packed word with a single instruction and a single store. Per-thread writes with atomicOr would instead make 32 lanes contend for the same word in memory.
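A host-side reference makes the bit layout explicit and gives a target for unit-testing the kernel's output. This NumPy sketch assumes the [..., proj_dim] layout of qjl_sign_pack_kernel; `pack_signs_ballot_order` is a hypothetical test helper.

```python
import numpy as np

def pack_signs_ballot_order(projected: np.ndarray) -> np.ndarray:
    """Pack (x >= 0) bits into uint32 words along the last axis.

    Bit i of word w corresponds to element 32 * w + i, matching
    __ballot_sync, which places lane i's predicate in bit i.
    """
    x = np.asarray(projected, dtype=np.float32)
    if x.shape[-1] % 32 != 0:
        raise ValueError("proj_dim must be a multiple of 32")
    bits = (x >= 0).astype(np.uint64).reshape(*x.shape[:-1], -1, 32)
    weights = np.uint64(1) << np.arange(32, dtype=np.uint64)   # 1, 2, 4, ..., 2**31
    return (bits * weights).sum(axis=-1).astype(np.uint32)

# Two words: positives at elements 0 and 31 (word 0) and element 33 (word 1)
x = -np.ones(64, dtype=np.float32)
x[[0, 31, 33]] = 1.0
words = pack_signs_ballot_order(x)
print([hex(int(w)) for w in words])   # ['0x80000001', '0x2']
```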

In the symmetric variant, the query is also projected and sign-packed, so the inner product estimate reduces to a bitwise XOR followed by __popc (population count):

// Symmetric sign-sign inner product estimate from packed bits.
// Signs agree with probability 1 - theta/pi (theta = angle between q and k),
// so the match fraction is inverted through cos(pi * (1 - fraction)).
#include <math_constants.h>  // CUDART_PI_F
#include <stdint.h>

__device__ float qjl_dot_estimate(
    const uint32_t* q_packed,  // packed query signs [proj_dim/32]
    const uint32_t* k_packed,  // packed key signs   [proj_dim/32]
    const float q_norm,
    const float k_norm,
    const int num_words        // proj_dim / 32
) {
    int match_count = 0;
    for (int w = 0; w < num_words; ++w) {
        // XOR gives 1 where signs differ; ~XOR gives 1 where signs match
        match_count += __popc(~(q_packed[w] ^ k_packed[w]));
    }
    const float match_frac = (float)match_count / (float)(num_words * 32);
    // Estimated angle: theta_hat = pi * (1 - match_frac). The linear map
    // 2 * match_frac - 1 would estimate (2/pi) * arcsin(cos theta), not cos theta.
    const float cos_est = cosf(CUDART_PI_F * (1.0f - match_frac));
    // Scale by norms to recover an inner product estimate (consistent, not unbiased)
    return q_norm * k_norm * cos_est;
}
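A host-side mirror of qjl_dot_estimate shows why the angle inversion matters. This NumPy sketch builds two unit vectors with cosine exactly 0.5 and packs their projected signs in ballot bit order. It then compares the linear map 2f − 1 with cos(π(1 − f)). It assumes a Gaussian P, and its helper names are hypothetical.

```python
import numpy as np

def pack_bits(signs: np.ndarray) -> np.ndarray:
    """Pack booleans so bit i of each uint32 word is element i (ballot order)."""
    # Little bit order inside each byte, little-endian bytes inside each word
    return np.packbits(signs, bitorder="little").view("<u4")

def popcount(words: np.ndarray) -> int:
    return int(np.unpackbits(words.view(np.uint8)).sum())

def cosine_estimates(q_words: np.ndarray, k_words: np.ndarray):
    """Return (2f - 1, cos(pi * (1 - f))) for match fraction f."""
    total = 32 * q_words.size
    matches = total - popcount(q_words ^ k_words)   # equals sum of __popc(~(q ^ k))
    f = matches / total
    return 2 * f - 1, float(np.cos(np.pi * (1 - f)))

rng = np.random.default_rng(1)
d, m = 128, 8192
k = rng.standard_normal(d); k /= np.linalg.norm(k)
u = rng.standard_normal(d); u -= (u @ k) * k; u /= np.linalg.norm(u)
q = 0.5 * k + np.sqrt(0.75) * u                 # unit vector with cos(q, k) = 0.5

P = rng.standard_normal((m, d))
linear, inverted = cosine_estimates(pack_bits(P @ q >= 0), pack_bits(P @ k >= 0))
print(f"true cos = 0.5, linear map = {linear:.3f}, angle inversion = {inverted:.3f}")
```

The linear map lands near 1/3 rather than 0.5, which matches (2/π)·arcsin(0.5) = 1/3.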

Technical Warning: Compute Capability 8.0 (Ampere) or higher is the deployment baseline for the throughput figures in this article, not a functional requirement of these kernels. __ballot_sync and __popc are available on every architecture supported by CUDA 9 and later. The coalesced-load behavior that peak throughput depends on is tuned here for Ampere and Hopper memory subsystems (HBM2e/HBM3 with large L2 caches).

Optimizing Kernel Performance for NVIDIA GPUs

Memory coalescing is the dominant performance variable for the packed binary cache. At one bit per projected dimension, a 128-byte cache line holds 1,024 sign bits. With proj_dim = 256, that is four tokens' keys at 32 bytes each. A warp fetches such a line in a single transaction when its lanes read consecutive 32-bit words, or when eight lanes each issue a 128-bit LDG.E.128 load. At scale, keeping access patterns aligned with HBM3 transaction sizes is what sustains throughput under high concurrency.

To keep accesses coalesced, lay out the packed buffer with the proj_dim/32 words as the innermost dimension, and make the stride between tokens a multiple of 16 bytes (128 bits). Layouts that break this alignment can split a warp's request into additional partial transactions.
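The stride rule is simple to encode as a check at cache-allocation time. This sketch assumes 4-byte words, a 16-byte stride requirement, and 128-byte cache lines; the helper names are hypothetical.

```python
def packed_row_bytes(proj_dim: int) -> int:
    """Bytes of packed sign bits for one token of one head."""
    if proj_dim % 32 != 0:
        raise ValueError("proj_dim must be a multiple of 32")
    return (proj_dim // 32) * 4

def padded_row_bytes(proj_dim: int, align: int = 16) -> int:
    """Round the per-token stride up to the 16-byte (128-bit) alignment."""
    row = packed_row_bytes(proj_dim)
    return -(-row // align) * align

for proj_dim in (96, 256, 1024):
    row, padded = packed_row_bytes(proj_dim), padded_row_bytes(proj_dim)
    print(f"proj_dim={proj_dim}: {row} B/token, stride {padded} B, "
          f"{128 // padded} token(s) per 128-byte line")
```

A proj_dim of 96 would need 4 bytes of padding per token to keep the 16-byte stride, while 256 and 1024 are naturally aligned.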

graph TD
    subgraph VRAM["VRAM Layout — Compressed KV Cache"]
        direction TB
        A["Head 0 | Token 0 | Words [0..7]<br/>= 256 sign bits = 256 proj dims"]
        B["Head 0 | Token 1 | Words [0..7]"]
        C["Head 0 | Token N | Words [0..7]"]
        D["Head 1 | Token 0 | Words [0..7]"]
        E["..."]
        A --> B --> C --> D --> E
    end

    subgraph Meta["Per-Vector Metadata (FP32)"]
        F["Head 0 | Token 0 | key_norm (4 bytes)"]
        G["Head 0 | Token 1 | key_norm (4 bytes)"]
        H["Head 0 | Token N | key_norm (4 bytes)"]
    end

    subgraph Proj["Projection Matrix (FP16, shared)"]
        I["P: [proj_dim × head_dim]<br/>Fixed random seed, read-only"]
    end

    VRAM -.->|"LDG.E.128 (128-bit loads)"| J["L2 Cache → SM Registers"]
    Meta -.->|"Scalar load per token"| J
    Proj -.->|"Cached in shared memory"| J

Within the attention kernel, stage P in shared memory. At proj_dim = 256 and head_dim = 128 in FP16, P occupies 64 KB, which fits comfortably in an H100 SM's 228 KB of shared memory. Because 64 KB exceeds the 48 KB default for dynamic shared memory, the kernel must opt in with cudaFuncSetAttribute using cudaFuncAttributeMaxDynamicSharedMemorySize (up to 227 KB per block on H100). Keeping P resident avoids re-reading it from L2 for every query projection.

Handling Outlier Counts and Precision Tuning

Sign-bit quantization handles outlier activations poorly. In Llama-3 models, as in transformer attention heads generally, a small fraction of key channels carry disproportionately large magnitudes. These are often attributed to specific embedding dimensions associated with positional encodings or attention-sink tokens. The estimator's error scales with ‖k‖, so a handful of such channels inflates the error on every logit computed through the JL-then-sign pipeline. That error surfaces as perplexity degradation.

The solution is selective outlier bypass: identify the top-K channels by activation magnitude, store them in FP16, and route only the remaining channels through QJL. Their contribution to each logit is then computed exactly, while the quantized remainder keeps the throughput benefit of sign-bit compression. For the Llama-3 configurations considered here, reserving fewer than 5% of channels as unquantized outliers preserves perplexity. For a head dimension of 128, this means outlier_count ≤ 6 channels.
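The benefit of the bypass follows from how the estimator's error scales with ‖k‖. This NumPy sketch assumes a synthetic key with five channels scaled 20×, an illustrative choice rather than measured Llama-3 statistics. It uses the asymmetric estimator from the framework section and compares quantizing every channel against bypassing the outliers.

```python
import numpy as np

def qjl_estimate(P, q, k):
    """Asymmetric QJL estimate: sqrt(pi/2) / m * ||k|| * <Pq, sign(Pk)>."""
    m = P.shape[0]
    return np.sqrt(np.pi / 2) / m * np.linalg.norm(k) * np.dot(P @ q, np.sign(P @ k))

def mixed_dot(P_rest, q, k, outlier_mask):
    """Exact product on outlier channels plus a QJL estimate on the rest."""
    exact = q[outlier_mask] @ k[outlier_mask]
    return exact + qjl_estimate(P_rest, q[~outlier_mask], k[~outlier_mask])

rng = np.random.default_rng(2)
d, m, n_out, trials = 128, 256, 5, 300
outlier_mask = np.zeros(d, dtype=bool)
outlier_mask[rng.choice(d, n_out, replace=False)] = True

err_full, err_bypass = [], []
for _ in range(trials):
    q = rng.standard_normal(d)
    k = rng.standard_normal(d)
    k[outlier_mask] *= 20.0                     # a few high-magnitude key channels
    true_dot = q @ k
    P_full = rng.standard_normal((m, d))
    P_rest = rng.standard_normal((m, d - n_out))
    err_full.append(abs(qjl_estimate(P_full, q, k) - true_dot))
    err_bypass.append(abs(mixed_dot(P_rest, q, k, outlier_mask) - true_dot))

mean_full, mean_bypass = float(np.mean(err_full)), float(np.mean(err_bypass))
print(f"mean |error|: all channels quantized {mean_full:.1f}, "
      f"outliers bypassed {mean_bypass:.1f}")
```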

# qjl_config.py
# Computes QJL kernel configuration for a given model checkpoint.
# Run once during deployment setup; output is baked into the kernel build.

import torch
from transformers import AutoModelForCausalLM
from typing import Tuple

def calibrate_outlier_count(
    model_name: str,
    calibration_texts: list[str],
    tokenizer,
    head_dim: int,
    target_outlier_fraction: float = 0.04,  # <5% threshold for Llama-3
    group_size_multiple: int = 32,           # must be multiple of warp size
) -> Tuple[int, int, torch.Tensor]:
    """
    Runs forward passes on calibration texts to identify high-magnitude
    key channels. Returns (outlier_count, group_size, outlier_channels),
    where outlier_channels holds the indices of the top-magnitude channels.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    model.eval().cuda()

    # Accumulate per-channel mean |activation| across layers, heads, and samples
    channel_magnitudes = torch.zeros(head_dim, device="cuda", dtype=torch.float32)

    def hook(module, inputs, output):
        # k_proj output is [batch, seq_len, num_kv_heads * head_dim] (pre-RoPE).
        # Reshape to [tokens * heads, head_dim] and average |x| per channel.
        per_channel = output.abs().float().reshape(-1, head_dim).mean(dim=0)
        channel_magnitudes.add_(per_channel)

    # Attach hooks to all K projection outputs (Llama-style module layout)
    hooks = [layer.self_attn.k_proj.register_forward_hook(hook)
             for layer in model.model.layers]

    try:
        with torch.no_grad():
            for text in calibration_texts:
                inputs = tokenizer(text, return_tensors="pt").to("cuda")
                model(**inputs)
    finally:
        for h in hooks:
            h.remove()

    # Fixed budget of outlier channels; pick the highest-magnitude ones.
    # RoPE mixes channel pairs after k_proj, so treat this ranking as approximate.
    outlier_count = int(head_dim * target_outlier_fraction)
    outlier_channels = torch.topk(channel_magnitudes, outlier_count).indices.sort().values

    # group_size: largest multiple of 32 (warp size) not exceeding the number
    # of quantized channels, floored at 32, to prevent warp divergence
    # during sign-bit accumulation
    quantized_dim = head_dim - outlier_count
    group_size = quantized_dim - (quantized_dim % group_size_multiple)
    group_size = max(group_size, group_size_multiple)

    print(f"head_dim={head_dim}, outlier_count={outlier_count}, group_size={group_size}")
    print(f"Effective quantized dims: {quantized_dim} ({100*quantized_dim/head_dim:.1f}%)")
    print(f"Outlier channels: {outlier_channels.tolist()}")
    return outlier_count, group_size, outlier_channels

# Example for Llama-3-8B (head_dim=128):
# outlier_count=5 (3.9% of channels), quantized_dim=123, group_size=96

Technical Warning: Group size and proj_dim must be multiples of 32. The packing kernel assumes that each warp fills one complete 32-bit word through the ballot. Other sizes leave idle lanes and partially populated words that the kernels above do not handle, which costs throughput and complicates indexing.


Production Integration: Injecting QJL into Existing Inference Pipelines

Integrating QJL into vLLM or TGI requires overriding the attention kernel dispatch path so that it accepts packed binary buffers rather than standard FP16 KV tensors. Neither framework offers a documented hook for a custom cache format like this. The integration therefore subclasses the attention implementation and installs it before the model is initialized. Under tensor parallelism, each shard's backend manages the compressed keys for the heads it owns.

The core abstraction is a QJL attention implementation (QJLAttentionBackend) that intercepts key and value tensors at write time and projects and packs them. Its custom forward method then runs the binary attention kernel.

# qjl_attention_wrapper.py
# Sketch: replaces vLLM's default attention implementation with a QJL-compressed cache.
# Written against the vLLM 0.4-era backend interface (AttentionImpl.forward and
# AttentionMetadata.is_prompt). Both changed in later releases, so check the
# abstract classes in your installed version. qjl_kernels is this article's
# hypothetical compiled CUDA extension, not a published package.

from typing import Optional

import torch
from vllm.attention.backends.abstract import AttentionImpl, AttentionMetadata
from qjl_kernels import (  # hypothetical compiled CUDA extension
    qjl_project_and_pack,   # runs P @ x and sign-packs the result
    qjl_packed_attention,   # computes attention using binary keys
)

class QJLAttentionBackend(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        proj_dim: int,
        outlier_count: int,
        group_size: int,
        scale: float,
        projection_matrix: torch.Tensor,  # [proj_dim, head_dim - outlier_count], FP16
        outlier_channels: Optional[torch.Tensor] = None,  # calibrated indices; None = first channels
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.proj_dim = proj_dim
        self.outlier_count = outlier_count
        self.group_size = group_size
        self.scale = scale
        # Projection matrix is read-only; keep it on the GPU with no autograd tracking
        self.P = projection_matrix.cuda().requires_grad_(False)

        # Boolean channel mask: True marks outlier channels kept in FP16
        mask = torch.zeros(head_dim, dtype=torch.bool)
        if outlier_channels is None:
            mask[:outlier_count] = True
        else:
            mask[outlier_channels.cpu()] = True
        self.outlier_mask = mask.cuda()

    def forward(
        self,
        query: torch.Tensor,          # [..., num_heads, head_dim] FP16 (layout simplified)
        key: torch.Tensor,            # [..., num_heads, head_dim] FP16
        value: torch.Tensor,          # [..., num_heads, head_dim] FP16
        kv_cache: tuple,              # (packed_keys, packed_vals, outlier_keys, outlier_vals, norms)
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:

        packed_k_cache, packed_v_cache, outlier_k, outlier_v, k_norms = kv_cache

        # At prefill: project and pack new keys into the compressed cache
        if attn_metadata.is_prompt:
            # Split key into outlier channels (kept in FP16) and quantizable channels
            outlier_keys_new = key[..., self.outlier_mask]
            quant_keys_new   = key[..., ~self.outlier_mask]

            # Per-vector L2 norms of the quantized channels, stored as FP32 scalars
            k_norms_new = quant_keys_new.float().norm(dim=-1, keepdim=True)

            # Project and sign-pack (wraps qjl_sign_pack_kernel)
            new_packed_k = qjl_project_and_pack(quant_keys_new, self.P)

            # Appending outlier_keys_new, k_norms_new, new_packed_k, and the
            # compressed values to the cache is block-allocator-specific; omitted.

        # Attention over cached tokens: QJL estimates for the quantized channels
        # plus exact FP16 products for the outlier channels. How values are
        # compressed is up to qjl_packed_attention; the QJL paper itself
        # quantizes values with a conventional per-token scheme.
        attn_output = qjl_packed_attention(
            query=query,
            packed_keys=packed_k_cache,
            outlier_keys=outlier_k,
            packed_values=packed_v_cache,
            outlier_values=outlier_v,
            key_norms=k_norms,
            scale=self.scale,
            proj_matrix=self.P,
            outlier_mask=self.outlier_mask,
        )
        return attn_output

To install this implementation before model initialization, patch the constructor of vLLM's Attention layer:

# vllm_qjl_patch.py
# Sketch: monkeypatches vllm.attention.layer.Attention so every layer uses QJL.
# The patch must run in every worker process that builds the model; with
# multiprocess tensor parallelism, patching only the driver is not enough.
from vllm.attention.layer import Attention
from qjl_attention_wrapper import QJLAttentionBackend

def patch_vllm_with_qjl(qjl_config: dict):
    """
    Call before vllm.LLM() initialization.
    qjl_config must contain: proj_dim, outlier_count, group_size, projection_matrix
    (and optionally outlier_channels).
    """
    original_init = Attention.__init__

    # vLLM names the second argument head_size; model code passes it positionally
    def patched_init(self, num_heads, head_size, scale, *args, **kwargs):
        original_init(self, num_heads, head_size, scale, *args, **kwargs)
        # Replace the attention implementation with the QJL variant post-init
        self.impl = QJLAttentionBackend(
            num_heads=num_heads,
            head_dim=head_size,
            scale=scale,
            **qjl_config,
        )

    Attention.__init__ = patched_init

Pro-Tip: The qjl_kernels CUDA extension must be compiled against the exact PyTorch ABI version in the inference container. Use torch.utils.cpp_extension.load with verbose=True during the first deployment to validate ABI compatibility before switching to a pre-compiled wheel.

Managing Distributed State Consistency

In distributed inference, maintaining cache parity across tensor-parallel shards is essential. Each GPU shard owns a subset of attention heads. The compressed binary KV cache is computed locally and is append-only: once a key is projected and sign-packed on a given GPU, it never changes. A paged FP16 cache has the same append-only property, so QJL adds essentially one new consistency requirement: every rank must use the identical projection matrix P.

Consistency Protocol Checklist for Multi-GPU Tensor Parallel QJL Inference:

  • [ ] Binary cache sharding: Assign head shards to GPUs at initialization. Each GPU owns num_heads / tensor_parallel_size heads; no cross-GPU communication is needed for cache reads or writes.
  • [ ] Outlier locality: Outlier FP16 channels belong to the same heads as the binary cache, so they stay on the owning GPU. Attention for a head is computed entirely on its rank. No all-gather of outlier tensors is needed; the only collective in the attention block is the standard all-reduce after the output projection.
  • [ ] Norm locality: Per-vector L2 norms are likewise local to the owning rank. Keep them in the same block layout as the packed keys so that one set of block-table indices addresses both.
  • [ ] Projection matrix consistency: The fixed random projection matrix P must be initialized with an identical seed (or broadcast from rank 0) before inference begins. Divergent seeds produce incorrect estimators with no runtime error.
  • [ ] Sequence position alignment: Ensure kv_cache block indices are consistent across ranks before each decode step. vLLM's block manager handles this if the backend is correctly registered; TGI requires manual sequence-length broadcasting.
  • [ ] Rollback handling: On speculative decoding rejection, truncate the binary cache by resetting the sequence-length pointer. No byte-level cleanup is required; the stale packed blocks are overwritten on the next fill.
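Two items from the checklist above, head ownership and a shared P, are simple enough to sketch. This NumPy version assumes contiguous head blocks per rank and seed-based generation of P. In a real deployment, generating P on rank 0 and broadcasting it (for example with torch.distributed.broadcast) avoids relying on identical RNG behavior across library versions and devices. The helper names are hypothetical.

```python
import numpy as np

def projection_matrix(seed: int, proj_dim: int, head_dim: int) -> np.ndarray:
    """Deterministic P: the same seed yields the same matrix on every rank."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal((proj_dim, head_dim)).astype(np.float16)

def owned_kv_heads(rank: int, tp_size: int, num_kv_heads: int) -> range:
    """Contiguous block of KV heads owned by one tensor-parallel rank."""
    if num_kv_heads % tp_size != 0:
        raise ValueError("num_kv_heads must be divisible by tp_size")
    per_rank = num_kv_heads // tp_size
    return range(rank * per_rank, (rank + 1) * per_rank)

# Simulate 4 ranks of an 8-KV-head model
tp_size, num_kv_heads, seed = 4, 8, 1234
mats = [projection_matrix(seed, 256, 128) for _ in range(tp_size)]
owners = [owned_kv_heads(r, tp_size, num_kv_heads) for r in range(tp_size)]
print([list(o) for o in owners])   # [[0, 1], [2, 3], [4, 5], [6, 7]]
```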

Benchmarking and ROI: Measuring VRAM Efficiency

Measuring QJL's production impact requires separating two distinct metrics: VRAM reduction (capacity) and latency per token (throughput). They respond differently to sequence length, so report both.

TurboQuant, related work that builds on QJL, reports up to an 8× speedup in attention-logit computation on H100 hardware compared with unquantized FP32 keys. Part of that gain comes from replacing floating-point dot products with cheaper bit-level or low-bit arithmetic. Attention decode is largely memory-bound, though, so reading a fraction of the bytes per key is likely the larger contributor.

Latency benefits are most pronounced when sequence length exceeds 8k tokens. At short contexts, each decode step is dominated by streaming the model weights, so shrinking the KV cache barely moves per-token latency. As context grows, KV reads become a growing share of the bytes moved per step. At these sizes the cache already far exceeds the 50 MB L2 on H100, so every step traverses HBM. This is where 1-bit packing pays off: the sign bits are 16× smaller than FP16, or about 7× smaller once proj_dim = 256 and the per-vector norm are counted.
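A back-of-envelope model shows why the crossover sits near this range. The sketch makes four assumptions: Llama-3-8B (32 layers, 8 KV heads, head_dim 128, about 16 GB of FP16 weights); batch 8; a decode step that streams all weights and the full KV cache once; and no activations or kernel overheads.

```python
# Llama-3-8B FP16 KV bytes per token: 2 (K, V) * 32 layers * 8 KV heads * 128 dims * 2 bytes
KV_BYTES_PER_TOKEN = 2 * 32 * 8 * 128 * 2   # 131,072
WEIGHT_BYTES = 16e9                         # ~8B parameters in FP16 (approximate)

def decode_bytes(seq_len: int, batch: int):
    """Bytes streamed per decode step: (weights, whole KV cache)."""
    return WEIGHT_BYTES, batch * seq_len * KV_BYTES_PER_TOKEN

kv_share = {}
for seq_len in (2_048, 8_192, 16_384, 32_768):
    w, kv = decode_bytes(seq_len, batch=8)
    kv_share[seq_len] = kv / (w + kv)
    print(f"{seq_len:>6} tokens: KV is {100 * kv_share[seq_len]:.0f}% of bytes per step")

crossover = WEIGHT_BYTES / (8 * KV_BYTES_PER_TOKEN)
print(f"KV bytes equal weight bytes at ~{crossover:,.0f} tokens")
```

Under these assumptions, the KV cache overtakes the weights at roughly 15k tokens, and compressing it matters most beyond that point.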

Profiling Setup:

# Profile latency per token at varying context lengths
# Requires nsight-systems >= 2024.1 and vLLM with QJL patch applied.
# benchmark_qjl.py and its flags are this article's own benchmark script.

nsys profile \
  --trace=cuda,nvtx \
  --output=qjl_profile \
  python benchmark_qjl.py \
    --model meta-llama/Meta-Llama-3-8B \
    --seq-lengths 2048 8192 16384 32768 \
    --batch-size 8 \
    --enable-qjl \
    --proj-dim 256 \
    --outlier-count 5

Expected Latency Profile (Llama-3-8B, batch=8, H100 80GB SXM):

Sequence Length | Baseline FP16 (ms/tok) | QJL 1-bit (ms/tok) | VRAM Reduction
----------------|------------------------|---------------------|---------------
2,048           | 18.2                   | 19.4 (+6.6%)        | ~14×
8,192           | 41.7                   | 31.2 (-25.2%)       | ~14×
16,384          | 94.3                   | 58.8 (-37.6%)       | ~14×
32,768          | OOM (A100 80GB)        | 71.4 (fits)         | ~14×

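The slight latency regression at short sequences is the cost of the JL projection step at cache-write time, a one-time overhead for each token added to the cache. At 8k+ tokens, memory bandwidth savings dominate and latency inverts decisively. Read the 32k OOM entry with care. It is listed for an A100 80GB rather than the H100 used elsewhere in the table. By plain arithmetic, Llama-3-8B's FP16 KV cache at 32k tokens and batch 8 is 32 GiB, which fits alongside 16 GB of weights on an 80 GB GPU. The headroom argument is stronger for larger models such as Llama-3-70B, where QJL can enable a deployment that would otherwise require a second GPU.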

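Memory Constraint: Profile VRAM utilization with torch.cuda.memory_stats() before and after enabling QJL. Report reserved_bytes.all.peak rather than allocated_bytes.all.peak. The caching allocator holds reserved blocks beyond the live tensor footprint, and reserved memory is what other workloads cannot use. In vLLM, which preallocates KV blocks up to gpu_memory_utilization, the gain shows up as a larger number of cache blocks rather than a lower peak.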


Conclusion: Scalability and Future Directions

QJL's production value is in enabling context lengths and batch sizes that FP8 cannot reach on the same hardware. The 1-bit representation is not a soft approximation of INT8. It is a different operating point on the accuracy-memory frontier, justified by the mathematical guarantees of the Johnson-Lindenstrauss lemma. For distributed inference, the benefit is a smaller per-GPU cache footprint, which frees room for larger batches or longer contexts on each shard.

For teams running distributed inference at scale, the integration path is clear. First, profile your current VRAM allocation at peak context length. Next, estimate the KV cache share from the model's layer count, KV-head count, and head dimension; nvidia-smi --query-compute-apps=pid,used_memory --format=csv reports per-process totals, not the KV fraction. Then quantify the batch-size headroom that QJL would unlock. The outlier calibration step (typically under 30 minutes on a standard calibration corpus) is the only per-model work required.

Pre-Production Engineering Checklist:

  • [ ] Validate Compute Capability ≥ 8.0 on all inference nodes
  • [ ] Compile qjl_kernels CUDA extension against the exact deployment PyTorch version
  • [ ] Run calibrate_outlier_count on model-specific calibration data; confirm outlier_count < 0.05 × head_dim
  • [ ] Verify group_size is a multiple of 32 before baking into kernel build flags
  • [ ] Benchmark perplexity on standard benchmarks for the model you are deploying (for example, Llama-3-70B); reject configurations where degradation exceeds 0.5%
  • [ ] Initialize projection matrix P from a fixed seed and broadcast from rank 0 before first inference request
  • [ ] Confirm outlier tensors and norms stay local to the rank that owns their heads, with no collectives added beyond the standard post-projection all-reduce
  • [ ] Profile latency at target context length with nsys; verify crossover point is within your expected sequence length distribution
  • [ ] Disable QJL for prefill sequences shorter than 4k tokens where projection overhead exceeds bandwidth savings

The next frontier for QJL is integration with speculative decoding and chunked prefill—both of which create non-contiguous cache write patterns that stress the block-packing assumptions. Extending the binary cache block allocator to handle variable-size outlier channels per-layer (rather than a global outlier_count) is the most impactful near-term optimization, as outlier magnitudes vary significantly across transformer layers and current uniform allocation leaves accuracy on the table at identical memory cost.
