General-purpose LLMs trained on web-scale corpora carry an architectural liability into high-stakes legal applications: their embedding spaces treat "force majeure" and "force" as semantically proximate. That proximity is not a rounding error—it directly corrupts judgment prediction, contract analysis, and statutory retrieval. The TermGPT framework addresses this through multi-level contrastive fine-tuning, delivering over 15% improvement in term discrimination accuracy on legal benchmarks. This article dissects exactly how it works and how to implement it.
The Challenge: Why General-Purpose LLMs Fail in Legal Semantics
The isotropy problem is not a metaphor—it has a quantifiable signature. In an isotropic embedding space, the cosine similarity between two arbitrary token vectors approaches a non-zero constant regardless of semantic relationship:
$$S(\mathbf{x}, \mathbf{y}) = \frac{\mathbf{x} \cdot \mathbf{y}}{\|\mathbf{x}\|\,\|\mathbf{y}\|}$$
When this value remains high across unrelated tokens—say, 0.87 between "tort" and "tort reform" versus 0.84 between "tort" and "mortgage"—the model has effectively lost resolution. Related and unrelated pairs are separated by hundredths of a point: the pairwise similarity distribution is compressed into a narrow band, so ranking by cosine carries almost no discriminative signal. The consequence for LLM optimization in legal contexts is severe: retrieval systems cannot distinguish precedent-critical terms from adjacent vocabulary, and classification heads receive near-identical input representations for semantically distinct legal concepts.
Technical Warning: Isotropy is often invisible to standard perplexity metrics. A model can achieve low perplexity on legal corpora while maintaining a completely degenerate embedding geometry. Always audit embedding distributions with PCA variance-explained ratios or average cosine similarity matrices before assuming a fine-tuned model has resolved this problem.
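The audit described in the warning takes only a few lines. Below is a minimal sketch (plain NumPy; the function name and the example thresholds are illustrative, not from the TermGPT paper) that reports the mean off-diagonal cosine similarity and the variance share captured by the top principal components—a high mean cosine combined with variance concentrated in a handful of components is the degenerate-geometry signature to look for.

```python
import numpy as np

def audit_embedding_geometry(emb: np.ndarray, k: int = 10) -> dict:
    """Audit an [N, H] embedding matrix for geometric collapse.
    Returns the mean off-diagonal cosine similarity and the share
    of total variance explained by the top-k principal components."""
    normed = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sim = normed @ normed.T
    n = sim.shape[0]
    # Exclude the diagonal: self-similarity is always 1.0
    mean_cos = (sim.sum() - n) / (n * (n - 1))
    # PCA via SVD of the mean-centered matrix
    centered = emb - emb.mean(axis=0, keepdims=True)
    svals = np.linalg.svd(centered, compute_uv=False)
    var = svals ** 2
    topk_ratio = var[:k].sum() / var.sum()
    return {"mean_cosine": float(mean_cos), f"top{k}_var_ratio": float(topk_ratio)}
```

Run this before and after fine-tuning: a healthy space shows mean cosine near zero with variance spread across many components, while a collapsed space shows both numbers pushed high.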
The mechanics behind isotropy are well-understood. During pretraining on general corpora, frequently co-occurring tokens pull each other's representations into dense, overlapping regions. Legal terminology is low-frequency in web-scale datasets—terms like "subrogation," "estoppel," or "mens rea" appear orders of magnitude less frequently than general English tokens. The model never receives sufficient gradient signal to push their representations into distinct regions of the embedding manifold. Standard supervised fine-tuning on legal documents partially addresses surface-level task performance but does not structurally reorganize the embedding space—it adjusts the task head while leaving the encoder geometry largely intact.
As the TermGPT paper characterizes it: "LLMs often suffer from the isotropy problem, where token embeddings are distributed too uniformly, resulting in diluted semantic resolution for domain-specific terminology."
Mitigation requires a loss objective that explicitly penalizes embedding collapse—contrastive loss applied at both the token and sentence levels simultaneously.
Architecting TermGPT: Moving Beyond Sentence-Level Fine-Tuning
Most contrastive learning adaptations for NLP operate at the sentence level, treating the [CLS] token representation as the optimization target. TermGPT identifies this as insufficient for legal domains where a single sentence can contain multiple high-stakes terms that must be individually discriminated. The architecture therefore operates on two simultaneous levels.
flowchart TD
A[Raw Legal Corpus] --> B[Sentence Graph Construction]
B --> C{Node Classification}
C -->|High-term-density| D[Term Anchor Nodes]
C -->|Contextual| E[Context Nodes]
D --> F[Token-Level Contrastive Objective]
E --> G[Sentence-Level Contrastive Objective]
F --> H[Multi-Level Loss Aggregation]
G --> H
H --> I[LLM Encoder Gradient Update]
I --> J[Anisotropic Embedding Space]
subgraph Sampling["Positive/Negative Sampling"]
B --> K[Edge Weights via Lexical Overlap]
K --> L[Positive Pairs: High-Weight Edges]
K --> M[Negative Pairs: Low-Weight / Cross-Domain Edges]
end
L --> F
M --> F
L --> G
M --> G
Figure 1: TermGPT multi-level contrastive framework. Graph construction drives both positive and negative sample generation; losses are computed in parallel at token and sentence levels before aggregation.
At the sentence level, the framework pulls together representations of sentences that share high semantic overlap within the legal domain while repelling semantically dissimilar sentences—a standard SimCSE-style objective but with domain-aware pair generation rather than random dropout augmentation.
At the token level, the framework applies a secondary contrastive objective directly on individual term embeddings extracted from the encoder's intermediate layers. For each identified legal term within a sentence, the model computes a per-token contrastive loss that pushes the term's embedding closer to contextually similar usages (positive samples from the sentence graph) and away from general-domain or cross-domain usages (negative samples).
"We devise a multi-level contrastive learning approach at both the sentence and token levels, enhancing global contextual understanding and fine-grained term discrimination." — TermGPT (AAAI'26)
This dual-objective design is what makes the 15% accuracy improvement structurally achievable rather than incidental. The sentence-level objective preserves coherent global context representation; the token-level objective enforces the local geometric separation that isotropy destroys. Deploying either objective in isolation yields diminishing returns—the sentence-level head cannot enforce per-term precision, and token-level objectives without sentence-level grounding produce fragmented representations.
The implementation requires PyTorch 2.0+ and CUDA 12.0+ to handle the memory and parallelism demands of computing both objectives simultaneously within a single forward pass.
Step-by-Step Implementation: Building the Legal Sentence Graph
The sentence graph is the structural backbone of the entire framework. Nodes represent sentences from the legal corpus; edges represent semantic relatedness weighted by shared legal terminology and lexical overlap. The graph topology directly determines the quality of positive/negative sample pairs.
import networkx as nx
import torch
from transformers import AutoTokenizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from typing import List, Tuple
def build_legal_sentence_graph(
sentences: List[str],
tokenizer_name: str = "nlpaueb/legal-bert-base-uncased",
positive_threshold: float = 0.75,
negative_threshold: float = 0.20,
) -> Tuple[nx.Graph, dict]:
"""
Constructs a sentence graph where edge weights reflect
TF-IDF cosine similarity over legal-domain vocabulary.
High-weight edges produce positive pairs; low-weight edges
produce hard negative pairs, which are more informative
than random negatives for contrastive training.
"""
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# TF-IDF captures term specificity better than raw overlap
# for legal corpora where stopwords are highly frequent
    vectorizer = TfidfVectorizer(
        tokenizer=tokenizer.tokenize,
        token_pattern=None,  # a custom tokenizer is supplied; silences sklearn's unused-pattern warning
        lowercase=True,
        max_features=50000,
    )
tfidf_matrix = vectorizer.fit_transform(sentences)
sim_matrix = cosine_similarity(tfidf_matrix) # shape: [N, N]
G = nx.Graph()
G.add_nodes_from(range(len(sentences)))
positive_pairs = []
negative_pairs = []
for i in range(len(sentences)):
for j in range(i + 1, len(sentences)):
weight = float(sim_matrix[i, j])
G.add_edge(i, j, weight=weight)
if weight >= positive_threshold:
# Positive pair: sentences sharing significant legal terminology
positive_pairs.append((i, j))
elif weight <= negative_threshold:
# Hard negatives: low overlap but not zero—
# these are more effective than pure random negatives
negative_pairs.append((i, j))
pair_index = {"positive": positive_pairs, "negative": negative_pairs}
return G, pair_index
def prepare_contrastive_tensors(
sentences: List[str],
pair_index: dict,
tokenizer_name: str = "nlpaueb/legal-bert-base-uncased",
max_length: int = 128,
) -> dict:
"""
Tokenizes sentence pairs into PyTorch tensors ready for
injection into the contrastive training loop. Anchor,
positive, and negative tensors are aligned by index.
"""
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def encode_pair_list(pairs: List[Tuple[int, int]]) -> dict:
anchors = [sentences[i] for i, _ in pairs]
counterparts = [sentences[j] for _, j in pairs]
anchor_enc = tokenizer(
anchors,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors="pt",
)
counter_enc = tokenizer(
counterparts,
padding="max_length",
truncation=True,
max_length=max_length,
return_tensors="pt",
)
return {"anchor": anchor_enc, "counterpart": counter_enc}
return {
"positive_tensors": encode_pair_list(pair_index["positive"]),
"negative_tensors": encode_pair_list(pair_index["negative"]),
}
The TF-IDF vectorizer uses the domain-specific tokenizer vocabulary deliberately—this ensures that subword tokens unique to legal language (e.g., a WordPiece split of "subrogation" into pieces like "subroga" and "##tion") contribute to the similarity calculation rather than being absorbed into generic stopword distributions. Hard negatives, defined here as pairs with similarity ≤ 0.20, are more training-efficient than random negatives because they force the model to resolve subtle but real distinctions rather than trivially separating unrelated topics.
Pro-Tip: For production corpora exceeding 500K sentences, replace the dense `cosine_similarity` matrix computation with approximate nearest-neighbor indexing (e.g., FAISS) to keep graph construction tractable. The dense O(N²) computation becomes a bottleneck above this scale even on high-memory GPU nodes.
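Short of adding a new dependency, the same bottleneck can be relieved by streaming the similarity computation in row chunks and keeping only each node's top-k edges, so the full N×N matrix is never materialized. The generator below is an illustrative stand-in (not the paper's code); in production, the chunked matrix product is exactly the step a FAISS `IndexFlatIP` or `IndexIVFFlat` search would replace.

```python
import numpy as np

def topk_edges_chunked(tfidf: np.ndarray, k: int = 10, chunk: int = 1024):
    """Yield (i, j, weight) for each node's k strongest neighbors
    without materializing the full N x N similarity matrix.
    Rows are assumed L2-normalized so dot product == cosine."""
    n = tfidf.shape[0]
    for start in range(0, n, chunk):
        # One chunk of rows against all columns: [chunk_rows, N]
        block = tfidf[start:start + chunk] @ tfidf.T
        for row_offset, sims in enumerate(block):
            i = start + row_offset
            sims[i] = -np.inf  # exclude the self-edge
            neighbors = np.argpartition(sims, -k)[-k:]
            for j in neighbors:
                yield i, int(j), float(sims[j])
```

Feeding these sparse edges into the same `nx.Graph` and thresholding logic preserves the positive/hard-negative pair semantics while keeping memory linear in N.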
Negative Sample Management and VRAM Mitigation
Contrastive learning performance scales directly with negative sample quantity and quality. The TermGPT framework's dual-level objectives compound this: each forward pass computes token-level and sentence-level contrastive losses, both requiring their full negative sample batches to be resident in VRAM simultaneously.
import torch
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoModel
import torch.nn as nn
class TermGPTContrastiveTrainer:
def __init__(self, model_name: str, learning_rate: float = 2e-5):
self.model = AutoModel.from_pretrained(
model_name,
# Gradient checkpointing trades ~30% compute overhead
# for proportional VRAM reduction on the activation cache
)
self.model.gradient_checkpointing_enable()
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=learning_rate
)
        # BF16 preferred over FP16 on Ampere+ GPUs: its wider
        # dynamic range prevents loss spikes during contrastive
        # objective training and makes loss scaling unnecessary,
        # so the scaler is disabled (its calls become pass-throughs).
        # Re-enable it if falling back to FP16 on older hardware.
        self.scaler = GradScaler(enabled=False)
        self.temperature = 0.07  # Standard InfoNCE temperature
def contrastive_loss(
self,
anchor_emb: torch.Tensor,
positive_emb: torch.Tensor,
negative_emb: torch.Tensor,
) -> torch.Tensor:
"""
InfoNCE loss. anchor_emb: [B, H], positive_emb: [B, H],
negative_emb: [B * K, H] where K = negatives per anchor.
"""
# Normalize to unit sphere—critical for isotropy correction
anchor_emb = nn.functional.normalize(anchor_emb, dim=-1)
positive_emb = nn.functional.normalize(positive_emb, dim=-1)
negative_emb = nn.functional.normalize(negative_emb, dim=-1)
pos_sim = torch.sum(anchor_emb * positive_emb, dim=-1) / self.temperature
neg_sim = torch.matmul(anchor_emb, negative_emb.T) / self.temperature
logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1)
labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
return nn.functional.cross_entropy(logits, labels)
    def train_step(self, batch: dict) -> float:
        """
        `batch` maps "anchor", "positive", and "negative" to
        tokenizer output dicts (input_ids, attention_mask, ...)
        already moved to the training device.
        """
        self.optimizer.zero_grad()
        # autocast with bfloat16 on CUDA 12.0+ / Ampere+
        with autocast(dtype=torch.bfloat16):
            anchor_out = self.model(**batch["anchor"]).last_hidden_state
            pos_out = self.model(**batch["positive"]).last_hidden_state
            neg_out = self.model(**batch["negative"]).last_hidden_state
# Sentence-level: use [CLS] token (index 0)
sentence_loss = self.contrastive_loss(
anchor_out[:, 0, :], pos_out[:, 0, :], neg_out[:, 0, :]
)
# Token-level: mean-pool identified term token positions
# In production, replace slice with actual term position masks
term_anchor = anchor_out[:, 1:4, :].mean(dim=1)
term_pos = pos_out[:, 1:4, :].mean(dim=1)
term_neg = neg_out[:, 1:4, :].mean(dim=1)
token_loss = self.contrastive_loss(term_anchor, term_pos, term_neg)
# Weighted combination—token loss weighted higher
# to prioritize isotropy correction at term level
total_loss = 0.4 * sentence_loss + 0.6 * token_loss
self.scaler.scale(total_loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
return total_loss.item()
Memory Constraint: On an A100 80GB with batch size 32 and K=64 negatives per anchor, dual-level contrastive training consumes approximately 58–68GB VRAM depending on sequence length. Gradient checkpointing reduces this by ~25%. If VRAM remains insufficient, implement in-batch negative sharing via `torch.distributed` across multiple GPUs using `DistributedDataParallel`—this effectively multiplies the negative pool size by the number of GPUs without proportional per-device memory cost.
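One autograd subtlety hides in that recommendation: `torch.distributed.all_gather` does not propagate gradients, so naively gathering embeddings silently detaches the shared negative pool from the graph. The common fix—gather, then re-insert the local rank's grad-carrying slice—can be sketched as follows (the function name is illustrative; it degrades to a no-op in single-process runs):

```python
import torch
import torch.distributed as dist

def gather_negatives(local_emb: torch.Tensor) -> torch.Tensor:
    """All-gather embeddings so every rank sees the full cross-device
    negative pool. all_gather returns tensors detached from the
    autograd graph, so the local slice is re-inserted afterward to
    keep gradients flowing through this rank's own embeddings."""
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return local_emb  # single-process: nothing to gather
    gathered = [torch.zeros_like(local_emb) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_emb)
    gathered[dist.get_rank()] = local_emb  # restore grad-carrying slice
    return torch.cat(gathered, dim=0)
```

Call this on the negative embeddings just before `contrastive_loss`; with W GPUs each holding K negatives, every rank then contrasts against W·K negatives at no extra per-device activation cost.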
Evaluating Term Discrimination Accuracy
The JecQA benchmark (Chinese judicial examination QA) provides a rigorous legal-domain evaluation surface because it requires genuine term discrimination—questions cannot be answered through general reasoning alone without precise understanding of specific legal provisions and terminology. TermGPT demonstrates measurably superior F1 performance against vanilla fine-tuning baselines on this benchmark.
A complete evaluation loop should measure three distinct embedding-quality signals:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import f1_score
from typing import List
def evaluate_term_discrimination(
model: AutoModel,
tokenizer: AutoTokenizer,
term_pairs: List[dict], # [{"term": str, "context_pos": str, "context_neg": str, "label": int}]
device: str = "cuda",
) -> dict:
"""
Computes F1 and mean cosine similarity gap between
positive and negative context pairs for each legal term.
A larger cosine gap indicates better isotropy resolution.
"""
model.eval()
model.to(device)
all_preds = []
all_labels = []
cosine_gaps = []
with torch.no_grad():
for item in term_pairs:
enc_pos = tokenizer(
item["context_pos"], return_tensors="pt",
truncation=True, max_length=128
).to(device)
enc_neg = tokenizer(
item["context_neg"], return_tensors="pt",
truncation=True, max_length=128
).to(device)
enc_term = tokenizer(
item["term"], return_tensors="pt",
truncation=True, max_length=32
).to(device)
term_emb = model(**enc_term).last_hidden_state[:, 0, :]
pos_emb = model(**enc_pos).last_hidden_state[:, 0, :]
neg_emb = model(**enc_neg).last_hidden_state[:, 0, :]
sim_pos = torch.nn.functional.cosine_similarity(term_emb, pos_emb).item()
sim_neg = torch.nn.functional.cosine_similarity(term_emb, neg_emb).item()
# Prediction: positive context should score higher
pred = 1 if sim_pos > sim_neg else 0
all_preds.append(pred)
all_labels.append(item["label"])
cosine_gaps.append(sim_pos - sim_neg)
return {
"f1": f1_score(all_labels, all_preds, average="macro"),
"mean_cosine_gap": float(np.mean(cosine_gaps)),
"pct_correct": float(np.mean([p == l for p, l in zip(all_preds, all_labels)])),
}
Track three metrics in parallel: macro F1 on the classification task, mean cosine similarity gap between positive and negative context pairs (directly measures isotropy resolution), and pairwise accuracy. A contrastive-tuned model that has genuinely resolved the isotropy problem will show mean cosine gaps of 0.15–0.30 where a vanilla fine-tuned model produces gaps of 0.02–0.08. That gap differential is the mechanism behind the >15% accuracy improvement—it is not an abstract benchmark number but a direct consequence of embedding geometry correction.
Strategic Integration into Legal Tech Pipelines
A TermGPT-tuned encoder must replace the static embedding model at the retrieval layer of any RAG pipeline serving legal use cases. Plugging the tuned model in as a drop-in replacement is insufficient—the retrieval logic must also account for the now-meaningful cosine distance gradients that the encoder produces.
import torch
from transformers import AutoTokenizer, AutoModel
from typing import List, Tuple
import numpy as np
class LegalRAGPipeline:
def __init__(
self,
encoder_path: str, # Path to TermGPT-tuned checkpoint
document_store: list, # Pre-chunked legal documents
device: str = "cuda",
top_k: int = 5,
):
self.tokenizer = AutoTokenizer.from_pretrained(encoder_path)
self.encoder = AutoModel.from_pretrained(encoder_path).to(device).eval()
self.device = device
self.top_k = top_k
# Pre-compute and cache document embeddings at init time
self.doc_embeddings, self.docs = self._index_documents(document_store)
def _encode(self, texts: List[str]) -> torch.Tensor:
enc = self.tokenizer(
texts, padding=True, truncation=True,
max_length=256, return_tensors="pt"
).to(self.device)
with torch.no_grad():
out = self.encoder(**enc).last_hidden_state[:, 0, :]
# Normalize: mandatory post-tuning to maintain cosine metric validity
return torch.nn.functional.normalize(out, dim=-1)
def _index_documents(
self, documents: List[str]
) -> Tuple[torch.Tensor, List[str]]:
batch_size = 64
all_embs = []
for i in range(0, len(documents), batch_size):
batch = documents[i : i + batch_size]
all_embs.append(self._encode(batch).cpu())
return torch.cat(all_embs, dim=0), documents
def retrieve(self, query: str) -> List[Tuple[str, float]]:
query_emb = self._encode([query]).cpu() # [1, H]
# Dot product on normalized vectors == cosine similarity
scores = torch.matmul(query_emb, self.doc_embeddings.T).squeeze(0)
top_indices = torch.topk(scores, k=self.top_k).indices.tolist()
return [(self.docs[i], float(scores[i])) for i in top_indices]
def generate_context(self, query: str) -> str:
retrieved = self.retrieve(query)
# Filter by minimum cosine threshold—prevents diluted
# context injection that causes hallucinated citations
filtered = [(doc, score) for doc, score in retrieved if score > 0.65]
return "\n\n".join(doc for doc, _ in filtered)
Pro-Tip: The `score > 0.65` threshold is not arbitrary—with a properly tuned TermGPT encoder, retrieved documents scoring below this on legal queries are frequently from adjacent domains (general contract language, financial regulation) rather than directly applicable precedent. Calibrate this threshold on a held-out validation set of labeled query-document pairs from your target jurisdiction.
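That calibration step reduces to a simple threshold sweep once labeled pairs exist. A hedged sketch (the function and grid are illustrative, not part of TermGPT): it scans candidate cutoffs across the observed score range and keeps the one that maximizes F1 on the validation pairs.

```python
import numpy as np

def calibrate_retrieval_threshold(scores, labels, grid=None):
    """Pick the cosine cutoff that maximizes F1 on labeled
    query-document pairs (label 1 = relevant, 0 = irrelevant)."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    if grid is None:
        grid = np.linspace(scores.min(), scores.max(), 101)
    best_t, best_f1 = float(grid[0]), -1.0
    for t in grid:
        preds = (scores >= t).astype(int)
        tp = int(((preds == 1) & (labels == 1)).sum())
        fp = int(((preds == 1) & (labels == 0)).sum())
        fn = int(((preds == 0) & (labels == 1)).sum())
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), f1
    return best_t, best_f1
```

With a well-separated encoder the sweep typically lands inside the score gap between in-domain and adjacent-domain documents; re-run it after every checkpoint update rather than assuming the cutoff transfers.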
The cosine threshold filter directly addresses the hallucinated citation problem that plagues generic RAG deployments in legal contexts. When the embedding model cannot discriminate "tortious interference" from general tort discussion, the retriever surfaces mixed-relevance chunks that the generation model synthesizes into structurally plausible but legally incorrect citations. The TermGPT encoder's anisotropic geometry makes that threshold meaningful—a vanilla encoder's compressed cosine range renders such filtering nearly ineffective.
Future-Proofing Legal AI Architecture
Embedding anisotropy is not a one-time problem to solve—it is a recurrence risk. As regulatory corpora evolve (new statutory instruments, updated case law, jurisdictional precedent shifts), the sentence graph that drives contrastive training becomes stale. A graph built on 2024 regulatory documents will not capture the semantic neighborhoods of terminology introduced in 2025 legislative cycles. Models running on stale graphs silently regress toward isotropy in the newly introduced terminology regions while maintaining accuracy on historical terms.
Long-term embedding stability requires treating graph reconstruction as a first-class infrastructure concern, not a one-time preprocessing step.
Scaling Checklist for Engineering Leads:
- [ ] Graph reconstruction cadence: Schedule full sentence graph rebuilds quarterly, or triggered by corpus additions exceeding 15% of original training set size
- [ ] Embedding drift monitoring: Implement cosine similarity distribution tracking in production; alert on mean pairwise similarity increases >0.05 across any 30-day window
- [ ] Negative sample refresh: Re-mine hard negatives after each graph rebuild; stale negatives become easy negatives and degrade contrastive signal
- [ ] Checkpoint versioning: Tag each model checkpoint with its graph construction date and corpus snapshot hash for auditability
- [ ] Distributed training readiness: Verify `DistributedDataParallel` configuration scales to ≥4 GPUs before the corpus exceeds 1M sentences; in-batch negative sharing becomes mandatory above this scale
- [ ] BF16 verification on target hardware: Confirm CUDA 12.0+ and Ampere/Hopper GPU availability before committing to BF16 mixed precision; fall back to FP16 with loss scaling on Volta-generation hardware
- [ ] Benchmark regression gates: Block deployment of any new checkpoint that does not match or exceed the previous checkpoint's JecQA F1 and mean cosine gap metrics
- [ ] RAG threshold recalibration: After each model update, re-evaluate the cosine retrieval threshold on the held-out validation set; do not assume threshold stability across checkpoint versions
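The embedding drift monitoring item in the checklist reduces to a small scheduled computation once term embeddings are snapshotted per release. A sketch applying the checklist's own 0.05 rule (function names and the snapshot format are illustrative, not prescribed by the paper):

```python
import numpy as np

def mean_pairwise_cosine(emb: np.ndarray) -> float:
    """Mean off-diagonal cosine similarity of an [N, H] matrix."""
    normed = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sim = normed @ normed.T
    n = sim.shape[0]
    return float((sim.sum() - n) / (n * (n - 1)))

def drift_alert(baseline_emb: np.ndarray, current_emb: np.ndarray,
                max_increase: float = 0.05) -> dict:
    """Flag isotropy regression: alert when the mean pairwise cosine
    similarity has risen by more than `max_increase` since baseline."""
    base = mean_pairwise_cosine(baseline_emb)
    cur = mean_pairwise_cosine(current_emb)
    return {"baseline": base, "current": cur,
            "delta": cur - base, "alert": (cur - base) > max_increase}
```

Running this against a fixed probe set of legal terms after each corpus addition catches the silent regression described above before it reaches retrieval quality dashboards.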
The 15% term discrimination improvement that TermGPT achieves is meaningful precisely because it is grounded in structural geometry correction rather than task-specific overfitting. That structural advantage persists across downstream tasks—the same encoder that improves judgment prediction also improves contract clause retrieval and statutory interpretation without task-specific retraining. Maintaining that advantage over time requires treating the sentence graph and its contrastive objectives as living infrastructure components of the legal AI stack, subject to the same versioning, monitoring, and refresh discipline applied to any production data pipeline.
Keywords: Isotropy Problem, Contrastive Learning, Sentence Graph Construction, TermGPT, Token-level Embeddings, GPU Memory Optimization, JecQA Benchmark, PyTorch DistributedDataParallel, HuggingFace Transformers, Negative Sampling Efficiency