GDPR Article 17's "Right to be Forgotten" creates a direct collision with how large transformer models encode information: parameter weights absorb training data through gradient descent, distributing influence across millions or billions of parameters without a clean lookup table. When a user demands erasure, you cannot simply delete a row from a database. The data is baked into the weights. This article provides a concrete architectural blueprint for implementing machine unlearning in transformer systems, covering exact versus approximate methods, SISA partitioning, catastrophic forgetting mitigation, and the governance pipeline required to make erasure legally defensible.
Architectural Hurdles: Moving Beyond Naive Deletion
Full retraining is the only theoretically complete solution to data erasure in neural networks—and it is operationally untenable at scale. For production LLMs, full retraining often requires thousands of GPU-hours per cycle, making real-time GDPR compliance via retraining non-viable. A single erasure request triggering a multi-week GPU job on a 70B parameter model is not a compliance strategy; it is a liability.
The core architectural friction is that GDPR mandates erasure with defined response windows (typically 30 days under Article 17), while the cost curve for transformer retraining scales with parameter count, dataset size, and infrastructure availability. Organizations must treat machine unlearning as a first-class engineering concern, not a post-hoc patch.
The structural difference between a centralized training pipeline and a SISA (Sharded, Isolated, Sliced, Aggregated) sharded pipeline clarifies why sharding is the correct default architecture for any system ingesting PII at training time:
graph TB
subgraph Centralized["Centralized Training Pipeline"]
D1[Full Dataset] --> T1[Single Training Run]
T1 --> M1[Monolithic Model Weights]
M1 --> E1{Erasure Request}
E1 --> R1[Full Retrain Required]
R1 --> M1
end
subgraph SISA["SISA Sharded Pipeline"]
D2[Full Dataset] --> S1[Shard 1\nSlices A-C]
D2 --> S2[Shard 2\nSlices D-F]
D2 --> S3[Shard N\nSlices G-Z]
S1 --> C1[Checkpoint Store\nShard 1]
S2 --> C2[Checkpoint Store\nShard 2]
S3 --> C3[Checkpoint Store\nShard N]
C1 --> AGG[Aggregated Model]
C2 --> AGG
C3 --> AGG
AGG --> E2{Erasure Request}
E2 --> ID[Locate Target Shard\nvia User-ID Index]
ID --> RT[Retrain Affected\nShard Only]
RT --> C2
C2 --> AGG
end
In the centralized model, any erasure request propagates back to a full training run. In the SISA architecture, the erasure request resolves to a targeted shard re-train—a fraction of the total compute cost. Centralized checkpoint storage for every training shard is a hard infrastructure requirement; without it, even SISA degrades back to full retraining.
As noted in practical deployments, machine unlearning offers a cost-effective alternative, allowing organizations to meet legal and ethical obligations with minimal disruption to operations (LinkedIn, 2026). The architecture that makes this possible starts at data ingestion, not at the erasure request.
Exact vs. Approximate Unlearning: Defining the Compliance Threshold
Machine unlearning splits into two fundamentally different compliance postures: exact unlearning and approximate unlearning. Understanding the distinction is critical before selecting an implementation path.
Exact unlearning guarantees that the unlearned model is statistically indistinguishable from a model retrained from scratch on the dataset minus the target point. This is the cleanest guarantee for legal purposes but computationally intractable for models beyond roughly 1B parameters; for influence-based methods, the inverse-Hessian computation alone is prohibitive at transformer scale.
Approximate unlearning targets distributional equivalence: the unlearned model's output distribution should approximate what a retrained model would produce. As current research frames it, approximate machine unlearning strives to align the output distribution of unlearned models with that of retrained models (arXiv, 2026). This is the industry standard for production compliance in large-scale systems.
The mathematical basis for influence-based unlearning is the influence function of a data point z on the final model weights θ:
I(z) = -Hθ⁻¹ · ∇θL(z, θ)
Where:
- ∇θL(z, θ) is the gradient of the loss for data point z at parameters θ
- Hθ⁻¹ is the inverse Hessian of the total training loss
This formula quantifies how much a single data point z shifted the final weights. To "undo" that influence, you apply the negative of this quantity as a weight update. Merely storing the Hessian scales as O(p²) in parameter count p, and inverting it as O(p³): intractable at GPT scale. Practitioners combine Pearlmutter's exact Hessian-vector product trick with the LiSSA (Linear time Stochastic Second-order Algorithm) estimator to approximate inverse-Hessian-vector products at moderate scales.
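As a minimal numerical sketch (all variable names illustrative, not from any library), ridge regression is small enough to form the Hessian of the total loss exactly, so the influence-function update can be checked against an actual retrain without the target point:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, lam = 500, 5, 1.0

X = rng.normal(size=(n, d))
theta_true = rng.normal(size=d)
y = X @ theta_true + 0.1 * rng.normal(size=n)

def fit(X, y, lam):
    """Minimize 0.5*||X @ theta - y||^2 + 0.5*lam*||theta||^2 in closed form."""
    return np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)

theta = fit(X, y, lam)

# "Unlearn" the point z = (X[0], y[0]) via its influence on the weights:
g = X[0] * (X[0] @ theta - y[0])        # gradient of z's loss term at theta
H = X.T @ X + lam * np.eye(d)           # Hessian of the TOTAL training loss
theta_unlearned = theta + np.linalg.solve(H, g)  # apply -I(z) as a weight update

# Ground truth: exact retrain on the dataset minus the target point
theta_retrain = fit(X[1:], y[1:], lam)

err_before = np.linalg.norm(theta - theta_retrain)
err_after = np.linalg.norm(theta_unlearned - theta_retrain)
```

On a quadratic loss this update is a single Newton-style step, so err_after lands well below err_before; for transformers the Hessian is neither small nor constant, which is exactly why the stochastic estimators exist.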
The practical compliance threshold: if your model has fewer than ~500M parameters and you have access to the full training set with loss records, influence function unlearning is viable. Above that threshold, SISA or gradient-ascent-based approximate methods are your only operationally realistic options.
Technical Warning: Gradient-ascent unlearning (directly ascending the loss on target data to increase the model's error on those examples) is fast but produces poorly calibrated weight updates. Without constraints on the update magnitude, it destabilizes adjacent weight regions, particularly in attention layers. Always pair gradient ascent with a retain loss term that penalizes deviation from the original model on non-target data.
The Mechanics of SISA for Transformer Shards
SISA (AIModels.fyi, 2026) partitions the training corpus into S shards, each independently trained to produce a sub-model. These sub-models are aggregated (typically by averaging logits or using a voting ensemble) to produce the final model. Each shard is further subdivided into slices, with checkpoints saved at each slice boundary.
When an erasure request arrives, the system locates the shard containing the target user's data, rolls back to the checkpoint immediately before the slice containing that data, and retrains forward from that checkpoint—excluding the target records. Only that shard's weights update; all others remain frozen.
The compliance engineering challenge is state persistence: you must maintain a mapping from user-ID to (shard_id, slice_id, checkpoint_path) and guarantee checkpoint integrity across the model's operational lifetime. SISA frameworks achieve near-identical accuracy to full retraining while reducing retraining time by an order of magnitude when shard granularity is correctly calibrated.
import os
import hashlib
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from torch.utils.data import Dataset, Subset
@dataclass
class ShardIndex:
"""Maps user_id to shard/slice location and checkpoint path."""
shard_id: int
slice_id: int
checkpoint_path: str
record_indices: List[int] = field(default_factory=list)
class SISAPartitioner:
"""
Manages sharded training data partitioning for SISA-compliant
transformer fine-tuning. Maintains persistent user-to-shard mapping.
"""
def __init__(
self,
dataset: Dataset,
num_shards: int,
num_slices: int,
checkpoint_dir: str,
user_id_field: str = "user_id",
):
self.dataset = dataset
self.num_shards = num_shards
self.num_slices = num_slices
self.checkpoint_dir = checkpoint_dir
self.user_id_field = user_id_field
# user_id -> ShardIndex mapping; must be persisted to durable storage
self.index: Dict[str, ShardIndex] = {}
os.makedirs(checkpoint_dir, exist_ok=True)
def _deterministic_shard(self, user_id: str) -> int:
"""Assign shard deterministically via hashing to avoid rebalancing on re-index."""
digest = hashlib.sha256(user_id.encode()).hexdigest()
return int(digest, 16) % self.num_shards
def build_index(self) -> Dict[int, Dict[int, List[int]]]:
"""
Build shard->slice->record_indices mapping and populate user->shard index.
Returns nested dict: shard_id -> slice_id -> [dataset_indices]
"""
shard_buckets: Dict[int, List[int]] = {i: [] for i in range(self.num_shards)}
for idx in range(len(self.dataset)):
sample = self.dataset[idx]
uid = sample[self.user_id_field]
shard_id = self._deterministic_shard(uid)
shard_buckets[shard_id].append(idx)
# Slice each shard into num_slices segments
shard_slice_map: Dict[int, Dict[int, List[int]]] = {}
        for shard_id, indices in shard_buckets.items():
            slice_size = max(1, len(indices) // self.num_slices)
            slices = {}
            for s in range(self.num_slices):
                # Last slice absorbs the remainder so no record is silently dropped
                end = (s + 1) * slice_size if s < self.num_slices - 1 else len(indices)
                slices[s] = indices[s * slice_size:end]
            shard_slice_map[shard_id] = slices
            # Populate reverse index: user_id -> location
            for slice_id, slice_indices in slices.items():
                ckpt_path = os.path.join(
                    self.checkpoint_dir,
                    f"shard_{shard_id}_slice_{slice_id}.pt"
                )
                for record_idx in slice_indices:
                    uid = self.dataset[record_idx][self.user_id_field]
                    # First slice containing the user wins: rollback must start at
                    # the earliest checkpoint touched by that user's data
                    entry = self.index.setdefault(uid, ShardIndex(
                        shard_id=shard_id,
                        slice_id=slice_id,
                        checkpoint_path=ckpt_path,
                    ))
                    # Track only this user's records, not the whole slice
                    entry.record_indices.append(record_idx)
return shard_slice_map
def get_erasure_target(self, user_id: str) -> Optional[ShardIndex]:
"""Resolve an erasure request to its shard/slice location."""
return self.index.get(user_id)
    def get_retrain_subset(
        self, shard_slice_map: Dict[int, Dict[int, List[int]]],
        target: ShardIndex, erase_indices: List[int]
    ) -> Subset:
        """
        Build the retrain dataset for a shard starting from the target slice,
        excluding the records flagged for erasure.
        """
        erase_set = set(erase_indices)  # O(1) membership checks
        # Collect all indices in the shard from the target slice onward,
        # excluding erased records
        retrain_indices = [
            idx
            for slice_id in range(target.slice_id, self.num_slices)
            for idx in shard_slice_map[target.shard_id][slice_id]
            if idx not in erase_set
        ]
        return Subset(self.dataset, retrain_indices)
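The deterministic hash in _deterministic_shard is what keeps this index stable: rebuilding the index never migrates untouched users between shards. A self-contained sketch of that property (shard count and user IDs are illustrative):

```python
import hashlib

def assign_shard(user_id: str, num_shards: int) -> int:
    # Mirrors SISAPartitioner._deterministic_shard: SHA-256 of the user ID,
    # reduced modulo the shard count. No randomness, no dependence on
    # dataset order or ingestion time.
    digest = hashlib.sha256(user_id.encode()).hexdigest()
    return int(digest, 16) % num_shards

NUM_SHARDS = 8  # illustrative value
assignments = {f"user_{i}": assign_shard(f"user_{i}", NUM_SHARDS) for i in range(1000)}

# Stability: re-computing yields identical placement, so re-building the index
# after an erasure never reshuffles users who were not part of the request.
assert all(assign_shard(uid, NUM_SHARDS) == s for uid, s in assignments.items())

# Rough load balance across shards (SHA-256 output is close to uniform)
counts = [list(assignments.values()).count(s) for s in range(NUM_SHARDS)]
```

A random or round-robin assignment would break this guarantee: any re-ingestion could move a user to a different shard, invalidating the checkpoint mapping.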
Pro-Tip: Shard count is a performance-compliance tradeoff. More shards reduce retraining cost per erasure but increase aggregation overhead and can reduce model accuracy on low-frequency token distributions. For transformer fine-tuning on domain-specific corpora, 8–16 shards with 4–8 slices each is a practical starting point before profiling.
Per-erasure retraining cost falls as shard count rises, but so does aggregation variance. Calibrate shard count against your erasure SLA and acceptable accuracy budget.
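A back-of-envelope model of that calibration, assuming erasure targets fall uniformly at random across slices and compute scales with records processed: rolling back to the slice boundary before the target and retraining forward touches (R + 1) / 2 of a shard's R slices on average, i.e. a fraction (R + 1) / (2·R·S) of total training compute for S shards.

```python
def expected_retrain_fraction(num_shards: int, num_slices: int) -> float:
    """Expected fraction of total training compute consumed by one erasure,
    assuming the erased record is uniformly distributed over slices."""
    # A target in slice r (0-indexed) forces retraining slices r..R-1,
    # so E[slices retrained] = R - E[r] = (R + 1) / 2
    expected_slices = (num_slices + 1) / 2
    return expected_slices / (num_slices * num_shards)

for shards, slices in [(1, 1), (8, 4), (16, 8)]:
    print(shards, slices, expected_retrain_fraction(shards, slices))
# (1, 1) is the degenerate centralized case: every erasure is a full retrain
```

The 8-shard, 4-slice starting point above works out to under 8% of a full retrain per erasure in this model; real numbers also depend on aggregation and checkpoint I/O overhead, which this sketch ignores.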
Mitigating Catastrophic Forgetting in Pre-trained Layers
Catastrophic forgetting in transformer unlearning manifests as measurable performance degradation on non-unlearned token distributions—empirically observed at 3–7% in standard benchmarks when unlearning updates are applied without weight constraints. In transformers, attention heads encode long-range semantic dependencies; naive weight updates during unlearning can corrupt these representations for entirely unrelated downstream tasks.
The mechanistic cause: gradient updates targeting PII-containing token distributions propagate through shared weight matrices (particularly Q, K, V projections in multi-head attention) that also encode general language structure. Selective freezing of attention heads is required to prevent model collapse on unrelated semantic tasks during unlearning updates.
The student-teacher distillation approach is the most controlled mitigation: the "teacher" is the original frozen model; the "student" begins from the same weights and updates to maximize KL-divergence from the teacher on the forget set while minimizing it on the retain set (Trustworthy AI, 2026). This drives the student to "unlearn" the target distributions while staying anchored to the teacher on everything else.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoModelForCausalLM
from typing import Set
def selective_unlearn_step(
    student_model: AutoModelForCausalLM,
    teacher_model: AutoModelForCausalLM,
    forget_batch: dict,
    retain_batch: dict,
    frozen_layer_names: Set[str],
    optimizer: AdamW,
    forget_weight: float = 1.0,
    retain_weight: float = 0.5,
) -> dict:
"""
Single unlearning gradient step with selective weight freezing.
forget_weight: scales the forget loss (ascent on PII data)
retain_weight: scales the retain loss (anchors non-PII distributions)
frozen_layer_names: set of parameter name substrings to freeze,
e.g. {'layer.0.attention', 'embeddings'}
"""
# Freeze specified layers to protect general representations
for name, param in student_model.named_parameters():
if any(frozen in name for frozen in frozen_layer_names):
param.requires_grad_(False) # prevent gradient flow into frozen heads
else:
param.requires_grad_(True)
student_model.train()
teacher_model.eval()
    # --- Forget loss: gradient ascent on PII data (batch must include labels) ---
    forget_inputs = {k: v.to(student_model.device) for k, v in forget_batch.items()}
    forget_outputs = student_model(**forget_inputs)
    # Negate the loss so the optimizer step *increases* error on PII tokens
    forget_loss = -forget_outputs.loss
# --- Retain loss: minimize KL-divergence from teacher on clean data ---
retain_inputs = {k: v.to(student_model.device) for k, v in retain_batch.items()}
with torch.no_grad():
teacher_logits = teacher_model(**retain_inputs).logits
student_retain_outputs = student_model(**retain_inputs)
student_logits = student_retain_outputs.logits
# KL-divergence anchors student to teacher on retain distribution
retain_loss = F.kl_div(
F.log_softmax(student_logits, dim=-1),
F.softmax(teacher_logits, dim=-1),
reduction="batchmean",
)
total_loss = forget_weight * forget_loss + retain_weight * retain_loss
optimizer.zero_grad()
total_loss.backward()
# Gradient clipping prevents destabilizing large weight updates
torch.nn.utils.clip_grad_norm_(
[p for p in student_model.parameters() if p.requires_grad], max_norm=1.0
)
optimizer.step()
return {
"forget_loss": forget_loss.item(),
"retain_loss": retain_loss.item(),
"total_loss": total_loss.item(),
}
Layer freezing strategy matters: freeze the bottom N transformer blocks (which encode general syntactic and semantic features) and permit updates only in the upper layers and task-specific heads. For BERT-class models, freezing the bottom 8 of 12 layers is a reasonable baseline. For decoder-only architectures (GPT-family), always freeze the embedding layers and the bottom third of transformer blocks.
Technical Warning: Do not freeze all attention heads. Freezing only the bottom layers while allowing upper-layer Q/K/V updates gives the unlearning optimizer enough degrees of freedom to suppress PII-specific activations without overwriting foundational representations. Full attention freeze with only MLP updates produces under-erasing—the model retains PII-correlated attention patterns even when MLP outputs shift.
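A small helper for constructing the frozen_layer_names set under that strategy. The parameter-name substrings follow GPT-2 conventions (transformer.wte, transformer.h.<i>.) and are assumptions; inspect model.named_parameters() to adapt them to your architecture:

```python
from typing import Set

def gpt2_frozen_layer_names(num_layers: int) -> Set[str]:
    """Build a frozen_layer_names set that freezes embeddings plus the bottom
    third of transformer blocks, using GPT-2-style parameter names."""
    frozen = {"transformer.wte", "transformer.wpe"}  # token + position embeddings
    for i in range(num_layers // 3):                 # bottom third of blocks
        # Trailing dot prevents 'h.1' from also matching 'h.10' under
        # substring matching in selective_unlearn_step
        frozen.add(f"transformer.h.{i}.")
    return frozen

names = gpt2_frozen_layer_names(12)
print(sorted(names))  # blocks 0-3 frozen for a 12-layer model
```

The trailing-dot detail matters because the freezing loop above matches by substring; without it, freezing "layer 1" silently freezes layers 10 through 19 as well.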
Quantifying Performance Degradation on Token Distributions
Measuring unlearning efficacy requires two distinct validation signals: erasure verification (confirming the target data's influence is gone) and utility preservation (confirming non-target performance is intact). KL-divergence provides a scalar metric for both, with alpha and beta controls helping calibrate the divergence tolerance relative to model performance (HuggingFace Blog, 2026).
KL-divergence between original model distribution P and unlearned model distribution Q:
D_KL(P || Q) = Σ P(x) · log(P(x) / Q(x))
Values below 0.05 are typically considered safe for deployment—the distributions are close enough that the unlearned model's behavioral change is within acceptable operational variance. Values above 0.15 signal that the unlearning update has perturbed the model beyond what is attributable to the erasure target alone.
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from typing import Tuple
def compute_token_kl_divergence(
original_model: AutoModelForCausalLM,
unlearned_model: AutoModelForCausalLM,
test_loader: DataLoader, # must be clean, non-PII test data
device: str = "cuda",
max_batches: int = 200,
) -> Tuple[float, float]:
"""
Compute mean and max KL-divergence between original and unlearned model
output distributions over a clean test set.
Returns (mean_kl, max_kl). Deployment gate: mean_kl < 0.05.
"""
original_model.eval()
unlearned_model.eval()
kl_values = []
with torch.no_grad():
for batch_idx, batch in enumerate(test_loader):
if batch_idx >= max_batches:
break
inputs = {k: v.to(device) for k, v in batch.items()
if k in ("input_ids", "attention_mask")}
orig_logits = original_model(**inputs).logits # [B, T, V]
unlearn_logits = unlearned_model(**inputs).logits
            # Per-token KL: sum over the vocab dim -> [B, T]
            kl = F.kl_div(
                F.log_softmax(unlearn_logits, dim=-1),  # Q (unlearned) as log-probs
                F.softmax(orig_logits, dim=-1),         # P (original) as probs
                reduction="none",
            ).sum(dim=-1)
            # Average over real tokens only; padding positions would dilute the score
            mask = inputs.get("attention_mask")
            if mask is not None:
                per_seq = (kl * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
            else:
                per_seq = kl.mean(dim=-1)
            kl_values.extend(per_seq.cpu().numpy().tolist())
mean_kl = float(np.mean(kl_values))
max_kl = float(np.max(kl_values))
return mean_kl, max_kl
def run_validation_framework(
original_model: AutoModelForCausalLM,
unlearned_model: AutoModelForCausalLM,
retain_test_loader: DataLoader,
forget_test_loader: DataLoader,
deployment_threshold: float = 0.05,
) -> dict:
"""
Validation logic flow: check utility preservation and erasure effectiveness.
Utility check: KL on retain set should be BELOW threshold (model unchanged).
Erasure check: KL on forget set should be HIGH (model no longer fits PII data).
"""
retain_mean_kl, retain_max_kl = compute_token_kl_divergence(
original_model, unlearned_model, retain_test_loader
)
forget_mean_kl, forget_max_kl = compute_token_kl_divergence(
original_model, unlearned_model, forget_test_loader
)
utility_pass = retain_mean_kl < deployment_threshold
# For erasure verification, higher KL on forget set indicates successful unlearning
erasure_verified = forget_mean_kl > deployment_threshold
return {
"retain_mean_kl": retain_mean_kl,
"retain_max_kl": retain_max_kl,
"forget_mean_kl": forget_mean_kl,
"forget_max_kl": forget_max_kl,
"utility_preservation_pass": utility_pass,
"erasure_verification_pass": erasure_verified,
"deployment_approved": utility_pass and erasure_verified,
}
The validation framework requires a clean, non-PII test set maintained in strict isolation from the training corpus. Running divergence analysis against a contaminated test set violates data privacy constraints and produces misleading metrics—the KL score will reflect PII-correlated token patterns in the test set, not genuine model behavior on general distributions.
Pro-Tip: Run the KL validation framework after every unlearning event, not just on batch cycles. Cumulative unlearning drift—multiple erasure events compounding weight perturbations—can push a model past the 0.05 threshold even when each individual erasure passed validation. Track retain_mean_kl as a time-series metric in your model monitoring system.
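A minimal sketch of that tracking (class and field names are illustrative): per-event KL is measured against the previous model version, while cumulative drift is measured against the original baseline, and the baseline figure is what the deployment gate applies to.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class KLDriftTracker:
    """Append-only record of retain-set KL per erasure event. Escalates when
    cumulative drift against the original baseline crosses the gate."""
    threshold: float = 0.05
    events: List[dict] = field(default_factory=list)

    def record(self, event_id: str, kl_vs_previous: float, kl_vs_baseline: float) -> bool:
        """Append one erasure event; return True if deployment may proceed."""
        escalate = kl_vs_baseline >= self.threshold
        self.events.append({
            "event_id": event_id,
            "kl_vs_previous": kl_vs_previous,
            "kl_vs_baseline": kl_vs_baseline,
            "escalate": escalate,
        })
        return not escalate

tracker = KLDriftTracker()
# Three erasures, each individually small, compounding against the baseline:
ok1 = tracker.record("erase-001", 0.02, 0.02)
ok2 = tracker.record("erase-002", 0.02, 0.039)
ok3 = tracker.record("erase-003", 0.02, 0.056)  # cumulative drift crosses 0.05
print(ok1, ok2, ok3)  # True True False
```

The third event passes a naive per-event check (0.02 vs the previous version) yet fails the baseline gate, which is the failure mode the Pro-Tip describes.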
Establishing a GDPR-Compliant Data Governance Pipeline
GDPR Article 17 mandates erasure upon request. For models that cannot be retrained from scratch, machine unlearning is currently the only technically viable compliance path. The right to be forgotten creates the core architectural problem (DEV, 2026): AI systems must separate the influence of data from model weights—a separation that never existed in standard training pipelines and must be engineered in from the start.
The governance pipeline must be able to trace every training record back to a user-ID and shard identifier before training begins. Retroactive mapping is not feasible: you cannot reconstruct which weights were influenced by which user's data after the fact without the original training index.
Checkpointing Infrastructure Components:
| Component | Role | Failure Mode Without It |
|---|---|---|
| User-ID ↔ Shard/Slice Index | Maps erasure requests to specific checkpoint locations | Cannot target erasure; must retrain full model |
| Slice-Level Checkpoint Store | Preserves model state at each SISA slice boundary | Re-training from shard start instead of slice start; higher compute cost |
| Immutable Training Audit Log | Records which user IDs contributed to which training run version | Cannot prove erasure to regulators; legal exposure |
| Cryptographic Hash Registry | SHA-256 or BLAKE3 hash of each data slice before and after erasure | No verifiable proof that specific data was excluded from updated weights |
| Erasure Request Queue | Serializes and prioritizes deletion events with SLA tracking | Race conditions between concurrent erasure events; data corruption risk |
| Clean Validation Holdout Store | Non-PII test corpus for post-erasure KL validation | Cannot validate model utility without re-introducing PII risk |
| Model Version Registry | Links deployed model artifacts to their erasure event history | Deployed model may pre-date erasure; compliance gap |
flowchart TD
A[PII Data Ingestion] --> B[Tokenization & User-ID Tagging]
B --> C[Shard Assignment\nDeterministic Hash]
C --> D[Slice Partitioning]
D --> E[Training with Slice Checkpointing]
E --> F[Checkpoint Store\nSlice Boundaries]
E --> G[Cryptographic Hash\nof Each Slice]
G --> H[Hash Registry\nImmutable Log]
F --> I[Deployed Model\nVersion Registry]
I --> J{GDPR Article 17\nErasure Request}
J --> K[Lookup User-ID\nin Shard Index]
K --> L[Rollback to Pre-Slice\nCheckpoint]
L --> M[Retrain Shard\nExcluding Target Records]
M --> N[Post-Erasure\nKL Validation]
N --> O{KL < 0.05?}
O -->|Yes| P[Update Model\nVersion Registry]
O -->|No| Q[Escalate:\nFull Shard Retrain]
P --> R[Update Hash Registry\nAudit Entry]
R --> S[Legal Verification\nCertificate Issued]
The pipeline must operate with strict PII isolation at every stage. Validation holdout stores must be constructed from synthetic or consented non-PII data—running post-erasure validation against real user data to check erasure quality is a compliance violation in itself.
Operationalizing Machine Unlearning in Regulated Environments
At scale, managing training shard identifiers is the dominant operational challenge. In a system ingesting millions of user records, the shard index becomes a critical data asset: it must be durable, consistent, version-controlled, and queryable under SLA. A shard index stored in a mutable database without transactional guarantees will produce incorrect erasure targeting under concurrent write load. Use append-only ledger storage (e.g., Apache Iceberg tables with snapshot isolation, or a dedicated event-sourcing store) for the user-ID to shard mapping.
Regulatory frameworks in financial services and healthcare commonly mandate audit log retention of seven years or more. This means your unlearning event records—including the original shard state, the erasure request metadata, and the post-erasure validation results—must survive infrastructure migrations, vendor changes, and model deprecations. Decouple audit storage from your model training infrastructure. Managing the complex intersections of GDPR Article 17 and Article 15 compliance requires robust, immutable audit trails to defend AI state decisions during regulatory inquiries (Dark Reading, 2026).
Legal verification requires cryptographic proof that a specific data slice was excluded from the updated model state. Hash the data slice contents before erasure, hash the retrained checkpoint, and record both in an immutable registry. A compliance auditor or data protection authority can then verify that the hash of the erased slice does not appear in the training manifest for the post-erasure model version.
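A sketch of that proof flow using Python's standard hashlib (the certificate schema and function names are illustrative, not a standard): hash the slice before erasure, hash the retrained checkpoint, verify the erased slice's hash is absent from the new model's training manifest, and emit a certificate record for the immutable registry.

```python
import hashlib
import json
from datetime import datetime, timezone

def sha256_hex(data: bytes) -> str:
    return hashlib.sha256(data).hexdigest()

def build_erasure_certificate(
    erasure_event_id: str,
    erased_slice_bytes: bytes,          # serialized slice contents, pre-erasure
    retrained_checkpoint_bytes: bytes,  # post-erasure shard checkpoint
    post_erasure_manifest: list,        # slice hashes the new model WAS trained on
) -> dict:
    """Link erasure event -> pre-erasure slice hash -> post-erasure checkpoint
    hash. An auditor re-derives the slice hash and checks the manifest."""
    slice_hash = sha256_hex(erased_slice_bytes)
    if slice_hash in post_erasure_manifest:
        raise ValueError("erased slice still present in post-erasure manifest")
    return {
        "erasure_event_id": erasure_event_id,
        "erased_slice_sha256": slice_hash,
        "post_checkpoint_sha256": sha256_hex(retrained_checkpoint_bytes),
        "manifest_sha256": sha256_hex(json.dumps(sorted(post_erasure_manifest)).encode()),
        "issued_at": datetime.now(timezone.utc).isoformat(),
    }

cert = build_erasure_certificate(
    "erase-001",
    erased_slice_bytes=b"slice-7 records for user_42",
    retrained_checkpoint_bytes=b"shard-2 retrained checkpoint",
    post_erasure_manifest=[sha256_hex(b"slice-5"), sha256_hex(b"slice-6")],
)
```

Hashing the sorted manifest as a whole gives the auditor a single value to compare against the registry entry, rather than re-checking each slice hash individually.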
Technical Warning: GDPR Article 17 and the right of access (Article 15) are the two most operationally demanding compliance vectors for ML systems, and they interact: a data subject can first request access (Article 15) to verify their data was used, then request erasure (Article 17). Your governance pipeline must handle both workflows from the same shard index without one workflow's execution corrupting the other's audit trail.
Erasure Auditability Checklist:
- [ ] User-ID to shard/slice mapping stored in append-only, transactionally consistent storage
- [ ] Slice-level checkpoints written with content-addressed naming (hash-based filenames)
- [ ] Pre-erasure data slice hashed and recorded in immutable registry before deletion
- [ ] Erasure request logged with timestamp, requesting party, legal basis, and SLA deadline
- [ ] Post-erasure retrain triggered automatically on SLA schedule with job ID linked to request
- [ ] KL-divergence validation run against clean holdout set; results appended to audit record
- [ ] Post-erasure model checkpoint hashed and registered as new model version
- [ ] Legal verification certificate generated: maps erasure event ID → pre-hash → post-checkpoint hash
- [ ] Audit log written to long-term retention storage (7-year minimum for regulated verticals)
- [ ] Erasure completion notification sent to requesting party with certificate reference ID
- [ ] Shard index updated to reflect removed records; old entries marked tombstoned, not deleted
- [ ] Cumulative KL drift tracked across all erasure events for the model version; escalation threshold set
The shard index tombstone pattern is critical: marking erased records as tombstoned (rather than physically deleting the index entry) preserves the audit trail while preventing those records from being included in future retrain cycles. Physical deletion of the index entry would make it impossible to prove, after the fact, that the record was ever present and subsequently erased.
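A minimal sketch of the pattern (class names illustrative): erasure flips a tombstone flag and records the event ID; future retrain cycles read only active entries, while audit lookups see everything.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class IndexEntry:
    shard_id: int
    slice_id: int
    tombstoned: bool = False            # erased, but entry retained for audit
    erasure_event_id: Optional[str] = None

class TombstoneIndex:
    """User-ID index with tombstones: erasure marks the entry rather than
    deleting it, preserving proof that the record existed and was removed."""
    def __init__(self) -> None:
        self._entries: Dict[str, IndexEntry] = {}

    def add(self, user_id: str, shard_id: int, slice_id: int) -> None:
        self._entries[user_id] = IndexEntry(shard_id, slice_id)

    def tombstone(self, user_id: str, erasure_event_id: str) -> None:
        entry = self._entries[user_id]  # KeyError if never present: nothing to erase
        entry.tombstoned = True
        entry.erasure_event_id = erasure_event_id

    def active_users(self) -> List[str]:
        """Users eligible for inclusion in future retrain cycles."""
        return [u for u, e in self._entries.items() if not e.tombstoned]

    def audit(self, user_id: str) -> Optional[IndexEntry]:
        """Full history lookup, tombstoned or not; the audit trail survives."""
        return self._entries.get(user_id)

idx = TombstoneIndex()
idx.add("user_1", 0, 2)
idx.add("user_2", 1, 0)
idx.tombstone("user_1", "erase-001")
print(idx.active_users())                    # ['user_2']
print(idx.audit("user_1").erasure_event_id)  # 'erase-001'
```

In production this structure would live in the append-only ledger store described above, with the tombstone write itself logged as an audit event.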
Organizations operating under multiple regulatory frameworks (GDPR + HIPAA, for example) should expect that erasure SLAs, retention windows, and verification standards will differ between them. Build the governance pipeline to be configurable per regulatory context rather than hardcoding GDPR-specific parameters.
Keywords: SISA (Sharded, Isolated, Sliced, Aggregated), Right to be Forgotten, Catastrophic Forgetting, PII Masking, Influence Functions, Gradient-based Unlearning, Weight Divergence Analysis, GDPR Article 17, Transformer Attention Heads, Checkpointing Infrastructure