ChatGPT generates 10 million conversations per day. Each “regenerate” click is implicit negative feedback — the user rejected the response. Each “thumbs up” is explicit positive feedback. This signal is free, real-time, and reflects actual user preferences instead of annotator proxies. The data flywheel captures this production signal, scrubs PII, converts it to training data, and retrains the model. The result: ChatGPT improves every week without hiring additional annotators. The challenge is engineering: you must process 10M logs per day, remove sensitive data in under 100ms, and retrain without catastrophic forgetting.
This post covers the complete data flywheel pipeline: log collection, PII scrubbing, signal extraction, preference pair construction, quality filtering, and online DPO integration.
Production Log Collection Architecture
Log Schema Design
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional
class InteractionSignal(Enum):
THUMBS_UP = "thumbs_up"
THUMBS_DOWN = "thumbs_down"
REGENERATE = "regenerate"
COPY = "copy"
EDIT = "edit"
ABANDON = "abandon"
CONTINUE = "continue"
SHARE = "share"
@dataclass
class ProductionLog:
"""Single interaction log entry."""
request_id: str
timestamp: datetime
prompt: str
response: str
model_version: str
latency_ms: float
token_count_prompt: int
token_count_response: int
signals: list = field(default_factory=list)
session_id: str = ""
user_segment: str = "unknown"
temperature: float = 0.7
top_p: float = 1.0
is_multi_turn: bool = False
turn_index: int = 0
@dataclass
class ImplicitPreference:
"""Preference pair derived from production signals."""
prompt: str
chosen: str
rejected: str
signal_source: str
confidence: float
timestamp: datetime
class LogCollector:
"""
Collects production logs from inference servers.
Designed for high-throughput async collection
with batched writes to object storage.
"""
def __init__(self, config):
self.buffer_size = config.get("buffer_size", 10000)
self.flush_interval_s = config.get("flush_interval_s", 60)
self.buffer = []
self.storage_backend = config.get("storage", "s3")
self.partition_key = config.get(
"partition_key", "date"
)
def ingest(self, log_entry):
"""
Ingest a single log entry.
Applies immediate PII detection before buffering.
Entries flagged as high-PII-risk are quarantined
rather than buffered.
"""
pii_risk = self._quick_pii_check(log_entry.prompt)
if pii_risk > 0.8:
self._quarantine(log_entry)
return
self.buffer.append(log_entry)
if len(self.buffer) >= self.buffer_size:
self._flush()
def _quick_pii_check(self, text):
"""
Fast regex-based PII detection.
Checks for common patterns: SSN, credit card,
email, phone, IP address. Returns risk score 0-1.
Full NER-based detection happens downstream.
"""
import re
patterns = {
"ssn": r"\b\d{3}-\d{2}-\d{4}\b",
"credit_card": r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b",
"email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
"phone": r"\b\d{3}[-.]?\d{3}[-.]?\d{4}\b",
"ip_address": r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",
}
matches = 0
for pattern_name, pattern in patterns.items():
if re.search(pattern, text):
matches += 1
return min(matches / 2.0, 1.0)
def _quarantine(self, log_entry):
"""Send to quarantine storage for manual review."""
pass
def _flush(self):
"""Write buffer to object storage."""
pass
Production logs contain user data. Every pipeline component must handle PII: regex-based fast filters at ingestion, NER-based deep scrubbing before storage, and differential privacy noise injection before any data enters training. GDPR Article 17 (right to erasure) requires the ability to delete specific user interactions from all downstream datasets.
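The erasure requirement is mostly a bookkeeping problem: every derived artifact must be traceable back to the request it came from. Below is a minimal sketch of that bookkeeping; the `ErasureIndex` name and shape are illustrative, not part of the pipeline above.

```python
from collections import defaultdict

class ErasureIndex:
    """Maps each user to the request_ids derived from their data, so a
    GDPR Article 17 request can be fanned out to every downstream
    dataset that stored those interactions."""

    def __init__(self):
        self._by_user = defaultdict(set)

    def record(self, user_id, request_id):
        self._by_user[user_id].add(request_id)

    def erase(self, user_id):
        """Return every request_id to purge, then forget the user."""
        return sorted(self._by_user.pop(user_id, set()))

idx = ErasureIndex()
idx.record("u1", "req-001")
idx.record("u1", "req-002")
idx.record("u2", "req-003")
purge = idx.erase("u1")  # request_ids to delete from all datasets
```

Downstream stores (the replay buffer, preference-pair datasets) would then filter out the returned request_ids before the next training run.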
Signal Extraction from User Behavior
class SignalExtractor:
"""
Extract training signals from user behavior patterns.
Different signals have different reliability levels.
Thumbs-up/down is explicit but sparse (1-5% of interactions).
Regeneration is implicit but high-signal (user was unsatisfied).
Copy events suggest the response was useful.
"""
SIGNAL_WEIGHTS = {
InteractionSignal.THUMBS_UP: 1.0,
InteractionSignal.THUMBS_DOWN: -1.0,
InteractionSignal.REGENERATE: -0.7,
InteractionSignal.COPY: 0.6,
InteractionSignal.EDIT: 0.3,
InteractionSignal.ABANDON: -0.4,
InteractionSignal.CONTINUE: 0.5,
InteractionSignal.SHARE: 0.8,
}
SIGNAL_CONFIDENCE = {
InteractionSignal.THUMBS_UP: 0.95,
InteractionSignal.THUMBS_DOWN: 0.90,
InteractionSignal.REGENERATE: 0.80,
InteractionSignal.COPY: 0.60,
InteractionSignal.EDIT: 0.50,
InteractionSignal.ABANDON: 0.40,
InteractionSignal.CONTINUE: 0.55,
InteractionSignal.SHARE: 0.85,
}
def compute_interaction_score(self, log_entry):
"""
Compute a scalar quality score from all signals
on a single interaction.
"""
if not log_entry.signals:
return 0.0, 0.0
weighted_sum = 0.0
total_confidence = 0.0
for signal in log_entry.signals:
weight = self.SIGNAL_WEIGHTS.get(signal, 0.0)
confidence = self.SIGNAL_CONFIDENCE.get(signal, 0.5)
weighted_sum += weight * confidence
total_confidence += confidence
if total_confidence == 0:
return 0.0, 0.0
score = weighted_sum / total_confidence
avg_confidence = total_confidence / len(log_entry.signals)
return score, avg_confidence
def extract_preference_pairs(self, session_logs):
"""
Extract preference pairs from regeneration events.
When a user regenerates, the original response is
'rejected' and the final accepted response is 'chosen'.
This produces DPO-compatible training pairs.
"""
pairs = []
sorted_logs = sorted(
session_logs, key=lambda x: x.timestamp
)
i = 0
while i < len(sorted_logs):
log = sorted_logs[i]
if InteractionSignal.REGENERATE in log.signals:
rejected = log.response
# Find the accepted response (next non-regenerated)
j = i + 1
while (
j < len(sorted_logs)
and InteractionSignal.REGENERATE
in sorted_logs[j].signals
):
j += 1
if j < len(sorted_logs):
chosen = sorted_logs[j].response
pairs.append(
ImplicitPreference(
prompt=log.prompt,
chosen=chosen,
rejected=rejected,
signal_source="regeneration",
confidence=0.8,
timestamp=log.timestamp,
)
)
i += 1
return pairs
def extract_from_edits(self, log_entry, edited_text):
"""
When a user edits a response, the original is
'rejected' and the edited version is 'chosen'.
Lower confidence than regeneration because edits
may be minor formatting changes.
"""
if not edited_text or edited_text == log_entry.response:
return None
# Compute edit distance to gauge significance
import difflib
ratio = difflib.SequenceMatcher(
None, log_entry.response, edited_text
).ratio()
# If edit is trivial (> 95% similar), skip
if ratio > 0.95:
return None
confidence = min(0.9, 1.0 - ratio + 0.3)
return ImplicitPreference(
prompt=log_entry.prompt,
chosen=edited_text,
rejected=log_entry.response,
signal_source="user_edit",
confidence=confidence,
timestamp=log_entry.timestamp,
)
Signal Type Reliability for Preference Extraction
| Signal Type | Frequency (% of interactions) | Precision as Preference | Recall | F1 | Notes |
|---|---|---|---|---|---|
| Thumbs down + Regenerate | 2-5% | 0.92 | 0.15 | 0.26 | Highest quality but rare |
| Regenerate only | 8-15% | 0.80 | 0.35 | 0.49 | User explicitly unsatisfied |
| User edit | 3-8% | 0.70 | 0.20 | 0.31 | Significance depends on edit distance |
| Copy event | 15-25% | 0.60 | 0.50 | 0.55 | Positive signal, high coverage |
| Session abandon | 20-30% | 0.40 | 0.60 | 0.48 | Noisy -- many reasons for abandonment |
| Thumbs up | 1-3% | 0.95 | 0.08 | 0.15 | Highest precision, very sparse |
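The weights and confidences combine through a confidence-weighted mean, exactly as in `compute_interaction_score` above; here is a self-contained version with a trimmed-down signal set for illustration:

```python
SIGNAL_WEIGHTS = {"thumbs_up": 1.0, "regenerate": -0.7, "copy": 0.6}
SIGNAL_CONFIDENCE = {"thumbs_up": 0.95, "regenerate": 0.80, "copy": 0.60}

def interaction_score(signals):
    """Confidence-weighted mean of signal weights, plus the average
    confidence, mirroring SignalExtractor.compute_interaction_score."""
    if not signals:
        return 0.0, 0.0
    weighted = sum(SIGNAL_WEIGHTS[s] * SIGNAL_CONFIDENCE[s] for s in signals)
    total_conf = sum(SIGNAL_CONFIDENCE[s] for s in signals)
    return weighted / total_conf, total_conf / len(signals)

score, conf = interaction_score(["copy", "thumbs_up"])
```

A copy plus a thumbs-up scores about 0.85 at confidence 0.775; a lone regenerate scores -0.7, so mixed sessions net out toward zero.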
PII Scrubbing at Scale
Multi-Stage PII Pipeline
import re
from dataclasses import dataclass
@dataclass
class PIIDetection:
entity_type: str
start: int
end: int
confidence: float
original_text: str
class PIIScrubber:
"""
Multi-stage PII removal pipeline.
Stage 1: Regex patterns for structured PII
(SSN, credit card, phone, email).
Stage 2: NER model for unstructured PII
(names, addresses, organizations).
Stage 3: Contextual PII detection
(medical conditions linked to identifiers).
Stage 4: k-anonymity verification on the output.
"""
REGEX_PATTERNS = {
"SSN": r"\b\d{3}-\d{2}-\d{4}\b",
"CREDIT_CARD": (
r"\b(?:4\d{3}|5[1-5]\d{2}|3[47]\d{2}|6(?:011|5\d{2}))"
r"[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{1,4}\b"
),
"EMAIL": (
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b"
),
"PHONE_US": r"\b(?:\+1[\s-]?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b",
"IP_ADDRESS": (
r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}"
r"(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b"
),
"DATE_OF_BIRTH": (
r"\b(?:0[1-9]|1[0-2])[/-](?:0[1-9]|[12]\d|3[01])"
r"[/-](?:19|20)\d{2}\b"
),
}
REPLACEMENT_MAP = {
"SSN": "[SSN_REDACTED]",
"CREDIT_CARD": "[CC_REDACTED]",
"EMAIL": "[EMAIL_REDACTED]",
"PHONE_US": "[PHONE_REDACTED]",
"IP_ADDRESS": "[IP_REDACTED]",
"DATE_OF_BIRTH": "[DOB_REDACTED]",
"PERSON": "[NAME_REDACTED]",
"LOCATION": "[LOCATION_REDACTED]",
"ORGANIZATION": "[ORG_REDACTED]",
}
def __init__(self, ner_model=None):
self.ner_model = ner_model
def scrub_regex(self, text):
"""Stage 1: Fast regex-based PII removal."""
detections = []
for entity_type, pattern in self.REGEX_PATTERNS.items():
for match in re.finditer(pattern, text):
detections.append(
PIIDetection(
entity_type=entity_type,
start=match.start(),
end=match.end(),
confidence=0.95,
original_text=match.group(),
)
)
# Apply replacements in reverse order to preserve indices
detections.sort(key=lambda d: d.start, reverse=True)
scrubbed = text
for det in detections:
replacement = self.REPLACEMENT_MAP.get(
det.entity_type, "[REDACTED]"
)
scrubbed = (
scrubbed[:det.start]
+ replacement
+ scrubbed[det.end:]
)
return scrubbed, detections
def scrub_ner(self, text):
"""
Stage 2: NER-based PII removal.
Uses a fine-tuned NER model to detect names,
addresses, and other unstructured PII.
Slower than regex but catches cases like
'My neighbor John Smith at 123 Oak Street'.
"""
if self.ner_model is None:
return text, []
entities = self.ner_model.predict(text)
detections = []
for entity in entities:
if entity["label"] in ("PERSON", "LOCATION", "ORG"):
detections.append(
PIIDetection(
entity_type=entity["label"],
start=entity["start"],
end=entity["end"],
confidence=entity["score"],
original_text=text[
entity["start"]:entity["end"]
],
)
)
detections.sort(key=lambda d: d.start, reverse=True)
scrubbed = text
for det in detections:
if det.confidence > 0.7:
replacement = self.REPLACEMENT_MAP.get(
det.entity_type, "[REDACTED]"
)
scrubbed = (
scrubbed[:det.start]
+ replacement
+ scrubbed[det.end:]
)
return scrubbed, detections
def scrub_full_pipeline(self, text):
"""Run all PII scrubbing stages sequentially."""
all_detections = []
# Stage 1: Regex
text, regex_dets = self.scrub_regex(text)
all_detections.extend(regex_dets)
# Stage 2: NER
text, ner_dets = self.scrub_ner(text)
all_detections.extend(ner_dets)
return text, all_detections
PII scrubbing is not optional. A single leaked SSN or medical record in training data creates legal liability under HIPAA, GDPR, and CCPA. The regex stage catches 85-90% of structured PII. The NER stage catches an additional 5-8%. The remaining 2-7% requires manual review or conservative blanking strategies that redact any text matching high-risk patterns even at lower confidence.
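One refinement worth noting, not part of the pipeline above: the CREDIT_CARD regex also matches sixteen-digit order numbers and similar IDs. A Luhn checksum (the standard card-number check digit) distinguishes plausible card numbers from arbitrary digit runs, so a match that fails Luhn can be treated as lower-confidence and routed to review rather than silently redacted or ignored:

```python
def luhn_valid(number: str) -> bool:
    """Luhn checksum: confirms a regex match is a plausible card
    number before redacting (cuts false positives on order IDs)."""
    digits = [int(c) for c in number if c.isdigit()]
    if len(digits) < 13:
        return False
    total = 0
    for i, d in enumerate(reversed(digits)):
        if i % 2 == 1:  # double every second digit from the right
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return total % 10 == 0
```

Under a conservative-blanking policy the failing matches would still be redacted; the checksum only decides which confidence bucket the detection lands in.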
Quality Filtering Production Data
Automated Quality Scoring
import hashlib
import numpy as np
from collections import defaultdict
class ProductionDataFilter:
"""
Filters production logs for training quality.
Not all production interactions produce good training data.
Filter criteria:
- Minimum response length (avoid trivial completions)
- Language quality (avoid garbled or truncated outputs)
- Deduplication (production traffic has heavy repeats)
- Topic diversity (avoid over-representing popular queries)
- Safety (exclude toxic/harmful content)
"""
def __init__(self, config):
self.min_response_tokens = config.get(
"min_response_tokens", 50
)
self.max_response_tokens = config.get(
"max_response_tokens", 4096
)
self.dedup_threshold = config.get(
"dedup_threshold", 0.85
)
self.topic_cap = config.get("topic_cap", 1000)
self.seen_hashes = set()
self.topic_counts = defaultdict(int)
def filter_batch(self, logs):
"""
Filter a batch of production logs.
Returns (accepted, rejected_with_reasons).
"""
accepted = []
rejected = []
for log in logs:
reasons = self._check_quality(log)
if not reasons:
accepted.append(log)
else:
rejected.append((log, reasons))
return accepted, rejected
def _check_quality(self, log):
"""
Run all quality checks. Returns list of failure
reasons (empty list means passed all checks).
"""
reasons = []
# Length check
if log.token_count_response < self.min_response_tokens:
reasons.append(
f"Too short: {log.token_count_response} tokens"
)
if log.token_count_response > self.max_response_tokens:
reasons.append(
f"Too long: {log.token_count_response} tokens"
)
# Deduplication
content_hash = hashlib.md5(
(log.prompt + log.response).encode()
).hexdigest()
if content_hash in self.seen_hashes:
reasons.append("Duplicate")
else:
self.seen_hashes.add(content_hash)
# Repetition detection
if self._has_excessive_repetition(log.response):
reasons.append("Excessive repetition")
# Truncation detection
if self._is_truncated(log.response):
reasons.append("Truncated response")
return reasons
def _has_excessive_repetition(self, text):
"""
Detect repetitive text patterns.
Checks n-gram repetition ratio. If any 4-gram
appears more than 10% of total 4-grams, flag it.
"""
words = text.split()
if len(words) < 20:
return False
ngram_counts = defaultdict(int)
for i in range(len(words) - 3):
ngram = " ".join(words[i:i+4])
ngram_counts[ngram] += 1
total_ngrams = len(words) - 3
max_count = max(ngram_counts.values())
return max_count / total_ngrams > 0.10
def _is_truncated(self, text):
"""
Detect truncated responses.
Heuristics: ends mid-sentence, ends with
incomplete code block, ends with ellipsis
from max-token cutoff.
"""
text = text.strip()
if not text:
return True
# Ends mid-word (no final punctuation or closing)
last_char = text[-1]
if last_char.isalpha() and len(text) > 100:
# Check if last sentence is incomplete
last_period = text.rfind(".")
if last_period > 0:
remaining = text[last_period + 1:].strip()
if len(remaining.split()) > 15:
return True
# An odd number of ``` fences means an unclosed code block
fence_count = text.count("```")
if fence_count % 2 != 0:
return True
return False
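The repetition check reduces to a single ratio: the share of all n-grams claimed by the most frequent n-gram. A standalone sketch of the same heuristic:

```python
from collections import Counter

def repetition_ratio(text, n=4):
    """Fraction of all n-grams taken by the most frequent n-gram;
    values above ~0.10 indicate degenerate looping output."""
    words = text.split()
    if len(words) < n + 1:
        return 0.0
    grams = Counter(
        " ".join(words[i:i + n]) for i in range(len(words) - n + 1)
    )
    return max(grams.values()) / sum(grams.values())

looping = "the same phrase again " * 10
clean = (
    "production traffic contains many distinct prompts whose wording "
    "rarely repeats within any single model response"
)
```

The looping string scores around 0.27 and gets flagged; the varied sentence scores under 0.10 and passes.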
Diversity Sampling
class DiversitySampler:
"""
Ensure topic diversity in flywheel data.
Production traffic is heavily skewed: 20% of query
types account for 80% of volume. Without diversity
sampling, the model over-trains on popular topics
and degrades on the long tail.
"""
def __init__(self, embedding_model, n_clusters=500):
self.embedding_model = embedding_model
self.n_clusters = n_clusters
self.cluster_counts = defaultdict(int)
self.cluster_cap = 200
self.cluster_centers = None
def fit_clusters(self, sample_prompts):
"""
Fit topic clusters on a sample of prompts.
Uses k-means on prompt embeddings to discover
topic clusters. Run once per week on a random
sample of 100K prompts.
"""
from sklearn.cluster import MiniBatchKMeans
embeddings = self.embedding_model.encode(
sample_prompts
)
kmeans = MiniBatchKMeans(
n_clusters=self.n_clusters,
batch_size=1000,
)
kmeans.fit(embeddings)
self.cluster_centers = kmeans.cluster_centers_
return kmeans
def should_accept(self, prompt):
"""
Accept/reject based on cluster saturation.
If this prompt's cluster has reached its cap,
accept with probability inversely proportional
to over-representation.
"""
if self.cluster_centers is None:
return True
embedding = self.embedding_model.encode([prompt])
distances = np.linalg.norm(
self.cluster_centers - embedding, axis=1
)
cluster_id = int(np.argmin(distances))
current_count = self.cluster_counts[cluster_id]
if current_count < self.cluster_cap:
self.cluster_counts[cluster_id] += 1
return True
# Over-represented: accept with decaying probability
accept_prob = self.cluster_cap / (current_count + 1)
if np.random.random() < accept_prob:
self.cluster_counts[cluster_id] += 1
return True
return False
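The acceptance rule is worth seeing in isolation: below the cap everything is accepted; past the cap the probability decays as cap / (count + 1), so the expected count in a saturated cluster grows only logarithmically with traffic. A sketch of the same rule as a pure function:

```python
def accept_probability(cluster_count, cap=200):
    """Probability of accepting one more sample from a cluster,
    as in DiversitySampler.should_accept: certain below the cap,
    then decaying harmonically."""
    if cluster_count < cap:
        return 1.0
    return cap / (cluster_count + 1)
```

At ten times the cap (count 1,999), only one sample in ten is kept, so a cluster receiving 80% of traffic cannot dominate the training mix.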
Quality Filter Impact on Downstream Training
| Filter Stage | Data Retained (%) | DPO Win Rate After Training | Avg Response Quality (1-5) | Training Tokens |
|---|---|---|---|---|
| No filtering (raw logs) | 100% | 48% | 3.2 | 50B |
| + Length filter | 72% | 50% | 3.4 | 36B |
| + Deduplication | 45% | 52% | 3.5 | 22.5B |
| + Repetition filter | 41% | 53% | 3.6 | 20.5B |
| + Diversity sampling | 28% | 56% | 3.8 | 14B |
| + Quality scoring | 15% | 59% | 4.1 | 7.5B |
Preference Pair Construction
From Implicit Signals to DPO Pairs
class PreferencePairBuilder:
"""
Construct DPO-compatible preference pairs from
production signals.
DPO requires (prompt, chosen, rejected) triples.
Sources:
1. Regeneration: original = rejected, final = chosen
2. Thumbs up/down on A/B tests
3. User edits: original = rejected, edited = chosen
4. Model comparison: serve 2 models, user picks one
"""
def __init__(self, reward_model=None):
self.reward_model = reward_model
self.min_confidence = 0.6
def build_from_regenerations(self, session_logs):
"""
Extract pairs from regeneration chains.
A regeneration chain: user sends prompt, gets response A,
regenerates to get B, regenerates again to get C, then
accepts C. This produces pairs:
(prompt, C, B) with confidence 0.8
(prompt, C, A) with confidence 0.72 (earlier rejections
are down-weighted; the user may not have read them fully)
"""
pairs = []
chains = self._find_regeneration_chains(session_logs)
for chain in chains:
accepted = chain[-1] # Last response is accepted
for i, rejected_log in enumerate(chain[:-1]):
# Confidence decreases for earlier rejections
# (user may not have read them fully)
position_factor = 1.0 - (
0.1 * (len(chain) - 2 - i)
)
confidence = 0.8 * max(position_factor, 0.5)
pairs.append(
ImplicitPreference(
prompt=rejected_log.prompt,
chosen=accepted.response,
rejected=rejected_log.response,
signal_source="regeneration_chain",
confidence=confidence,
timestamp=accepted.timestamp,
)
)
return pairs
def _find_regeneration_chains(self, session_logs):
"""Group logs into regeneration chains by prompt."""
chains = []
current_chain = []
for log in sorted(
session_logs, key=lambda x: x.timestamp
):
if not current_chain:
current_chain.append(log)
continue
# Same prompt means regeneration
if log.prompt == current_chain[-1].prompt:
current_chain.append(log)
else:
if len(current_chain) > 1:
chains.append(current_chain)
current_chain = [log]
if len(current_chain) > 1:
chains.append(current_chain)
return chains
def build_from_ab_tests(self, ab_test_logs):
"""
Extract pairs from A/B test results.
In A/B testing, users see responses from two models
and pick the better one. This is the highest-quality
signal source but requires infrastructure support.
"""
pairs = []
for test in ab_test_logs:
if test.winner is None:
continue
chosen = (
test.response_a
if test.winner == "A"
else test.response_b
)
rejected = (
test.response_b
if test.winner == "A"
else test.response_a
)
pairs.append(
ImplicitPreference(
prompt=test.prompt,
chosen=chosen,
rejected=rejected,
signal_source="ab_test",
confidence=0.95,
timestamp=test.timestamp,
)
)
return pairs
def validate_pairs(self, pairs):
"""
Validate preference pairs using a reward model.
If the reward model disagrees with the implicit
signal, the pair is likely noisy. Remove pairs
where reward model scores the 'rejected' higher
by a large margin.
"""
if self.reward_model is None:
return pairs
validated = []
for pair in pairs:
chosen_score = self.reward_model.score(
pair.prompt, pair.chosen
)
rejected_score = self.reward_model.score(
pair.prompt, pair.rejected
)
# If reward model agrees, keep the pair
if chosen_score > rejected_score:
pair.confidence = min(
pair.confidence * 1.1, 1.0
)
validated.append(pair)
elif chosen_score > rejected_score - 0.5:
# Slight disagreement: keep but lower confidence
pair.confidence *= 0.7
if pair.confidence >= self.min_confidence:
validated.append(pair)
# Large disagreement: discard the pair
return validated
Reward model validation of preference pairs is critical. In production, 15-25% of implicit preference pairs are noisy (the rejected response is actually better by objective measures). Without validation, these noisy pairs degrade DPO training. A reward model filter removes most noise at the cost of discarding 10-15% of valid pairs.
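The keep/down-weight/discard logic in validate_pairs can be stated as a small pure function over the two reward scores (same thresholds and multipliers as the class above):

```python
def validate(confidence, chosen_score, rejected_score,
             margin=0.5, min_confidence=0.6):
    """Adjusted confidence for a preference pair, or None to discard
    it, mirroring PreferencePairBuilder.validate_pairs."""
    if chosen_score > rejected_score:
        return min(confidence * 1.1, 1.0)   # RM agrees: boost
    if chosen_score > rejected_score - margin:
        adjusted = confidence * 0.7         # mild disagreement: dampen
        return adjusted if adjusted >= min_confidence else None
    return None                             # strong disagreement: drop
```

Note the double filter on mild disagreements: confidence is multiplied by 0.7 and the result must still clear the floor, so only pairs that arrived with high confidence survive a reward-model quibble.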
Online DPO Integration
Incremental Training from Production Data
import numpy as np
class OnlineDPOTrainer:
"""
Incrementally train on production preference pairs.
Full retraining on the complete dataset is expensive.
Online DPO accumulates production preference pairs
in a replay buffer and trains in small batches,
mixing new production data with retained offline data.
"""
def __init__(self, model, tokenizer, config):
self.model = model
self.tokenizer = tokenizer
self.learning_rate = config.get("lr", 1e-6)
self.beta = config.get("dpo_beta", 0.1)
self.replay_buffer_size = config.get(
"replay_buffer_size", 100000
)
self.batch_size = config.get("batch_size", 8)
self.offline_ratio = config.get("offline_ratio", 0.3)
self.replay_buffer = []
self.offline_data = []
self.update_count = 0
def add_pairs(self, pairs):
"""
Add new preference pairs to the replay buffer.
Oldest pairs are evicted when buffer is full.
Pairs are stored with their confidence scores
for weighted sampling.
"""
for pair in pairs:
entry = {
"prompt": pair.prompt,
"chosen": pair.chosen,
"rejected": pair.rejected,
"confidence": pair.confidence,
"timestamp": pair.timestamp,
"source": pair.signal_source,
}
self.replay_buffer.append(entry)
# Evict oldest if over capacity
if len(self.replay_buffer) > self.replay_buffer_size:
self.replay_buffer = self.replay_buffer[
-self.replay_buffer_size:
]
def sample_batch(self):
"""
Sample a training batch mixing online and offline data.
Offline data provides stability (prevents forgetting).
Online data provides adaptation (tracks distribution shift).
Confidence-weighted sampling prioritizes high-quality pairs.
"""
n_offline = int(self.batch_size * self.offline_ratio)
n_online = self.batch_size - n_offline
# Sample offline data uniformly
offline_batch = []
if self.offline_data and n_offline > 0:
indices = np.random.choice(
len(self.offline_data),
size=min(n_offline, len(self.offline_data)),
replace=False,
)
offline_batch = [self.offline_data[i] for i in indices]
# Sample online data weighted by confidence
online_batch = []
if self.replay_buffer and n_online > 0:
confidences = np.array([
entry["confidence"]
for entry in self.replay_buffer
])
probs = confidences / confidences.sum()
indices = np.random.choice(
len(self.replay_buffer),
size=min(n_online, len(self.replay_buffer)),
replace=False,
p=probs,
)
online_batch = [
self.replay_buffer[i] for i in indices
]
return offline_batch + online_batch
def compute_dpo_loss(self, batch):
"""
Compute DPO loss for a batch.
L_DPO = -E[log sigma(beta * (log pi(y_w|x)/pi_ref(y_w|x)
- log pi(y_l|x)/pi_ref(y_l|x)))]
where y_w is chosen, y_l is rejected, pi is the
current policy, pi_ref is the reference model.
"""
total_loss = 0.0
for entry in batch:
# Tokenize
chosen_ids = self.tokenizer.encode(
entry["prompt"] + entry["chosen"]
)
rejected_ids = self.tokenizer.encode(
entry["prompt"] + entry["rejected"]
)
# Forward pass (simplified)
chosen_logprob = self._compute_logprob(chosen_ids)
rejected_logprob = self._compute_logprob(
rejected_ids
)
chosen_ref_logprob = self._compute_ref_logprob(
chosen_ids
)
rejected_ref_logprob = self._compute_ref_logprob(
rejected_ids
)
# DPO loss
chosen_reward = (
self.beta
* (chosen_logprob - chosen_ref_logprob)
)
rejected_reward = (
self.beta
* (rejected_logprob - rejected_ref_logprob)
)
# Stable form of -log(sigmoid(x)) = log(1 + exp(-x))
loss = np.logaddexp(
0.0, -(chosen_reward - rejected_reward)
)
# Weight by confidence
loss *= entry.get("confidence", 1.0)
total_loss += loss
return total_loss / len(batch)
def _compute_logprob(self, token_ids):
"""Compute log probability under current policy."""
return 0.0 # Placeholder
def _compute_ref_logprob(self, token_ids):
"""Compute log probability under reference model."""
return 0.0 # Placeholder
def train_step(self):
"""Execute one training step."""
batch = self.sample_batch()
if not batch:
return None
loss = self.compute_dpo_loss(batch)
self.update_count += 1
return {
"loss": loss,
"update_count": self.update_count,
"buffer_size": len(self.replay_buffer),
"batch_online_ratio": 1.0 - self.offline_ratio,
}
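The per-pair loss is compact enough to pull out as a standalone function; writing it with `logaddexp` keeps it numerically stable even when the margin is large and `exp()` would overflow. A sketch with scalar log-probs:

```python
import numpy as np

def dpo_loss(logp_chosen, logp_rejected,
             ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """Per-pair DPO loss -log(sigmoid(beta * margin)), computed as
    logaddexp(0, -beta * margin) so large margins never overflow."""
    margin = (logp_chosen - ref_logp_chosen) - (
        logp_rejected - ref_logp_rejected
    )
    return np.logaddexp(0.0, -beta * margin)
```

At zero margin the loss is log 2; it decays toward zero as the policy prefers the chosen response more strongly than the reference model does.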
Online DPO Win Rate Over Time (vs Static Model)
(Figure: win rate versus a static model, measured at days 0 through 90, for three configurations: online DPO on production pairs, weekly batch retraining, and no flywheel.)
Distribution Shift Detection
Monitoring and Adaptation
import numpy as np
from collections import deque
class DistributionShiftDetector:
"""
Detect when production query distribution shifts
significantly from training distribution.
Shift detection triggers re-sampling of the offline
data mix and can increase the online data ratio
in the DPO trainer.
"""
def __init__(self, embedding_model, window_size=10000):
self.embedding_model = embedding_model
self.window_size = window_size
self.reference_embeddings = None
self.current_window = deque(maxlen=window_size)
self.shift_history = []
def set_reference(self, reference_prompts):
"""
Set reference distribution from training data.
Compute mean and covariance of embeddings
from the training set.
"""
embeddings = self.embedding_model.encode(
reference_prompts
)
self.reference_mean = np.mean(embeddings, axis=0)
self.reference_cov = np.cov(embeddings.T)
# Regularize covariance for numerical stability
self.reference_cov += (
np.eye(self.reference_cov.shape[0]) * 1e-6
)
self.reference_cov_inv = np.linalg.inv(
self.reference_cov
)
def add_production_sample(self, prompt):
"""Add a production prompt to the sliding window."""
embedding = self.embedding_model.encode([prompt])[0]
self.current_window.append(embedding)
def compute_shift_score(self):
"""
Compute distribution shift using Mahalanobis distance
between current window mean and reference mean.
Also computes Maximum Mean Discrepancy (MMD) as a
secondary metric.
"""
if len(self.current_window) < 100:
return None
current = np.array(list(self.current_window))
current_mean = np.mean(current, axis=0)
# Mahalanobis distance
diff = current_mean - self.reference_mean
mahal_dist = np.sqrt(
diff @ self.reference_cov_inv @ diff
)
# MMD with RBF kernel (simplified)
n_ref = min(1000, len(current))
ref_sample = self.reference_mean.reshape(1, -1)
prod_sample = current[:n_ref]
mmd = self._compute_mmd(ref_sample, prod_sample)
return {
"mahalanobis_distance": float(mahal_dist),
"mmd": float(mmd),
"window_size": len(self.current_window),
"shift_detected": mahal_dist > 3.0 or mmd > 0.1,
}
def _compute_mmd(self, x, y, bandwidth=1.0):
"""
Maximum Mean Discrepancy with RBF kernel.
MMD^2 = E[k(x,x')] - 2*E[k(x,y)] + E[k(y,y')]
"""
def rbf_kernel(a, b):
diff = np.expand_dims(a, 1) - np.expand_dims(b, 0)
sq_dist = np.sum(diff ** 2, axis=-1)
return np.exp(-sq_dist / (2 * bandwidth ** 2))
k_xx = rbf_kernel(x, x)
k_yy = rbf_kernel(y, y)
k_xy = rbf_kernel(x, y)
mmd_sq = (
np.mean(k_xx)
- 2 * np.mean(k_xy)
+ np.mean(k_yy)
)
return max(0.0, mmd_sq) ** 0.5
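A quick sanity check of the MMD estimator on synthetic data. This is a full-sample variant of `_compute_mmd` (the class method above compares against the reference mean only, as a simplification):

```python
import numpy as np

def mmd_rbf(x, y, bandwidth=1.0):
    """Biased MMD estimate with an RBF kernel over two sample sets."""
    def k(a, b):
        sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-sq / (2 * bandwidth ** 2))
    val = k(x, x).mean() - 2 * k(x, y).mean() + k(y, y).mean()
    return max(0.0, val) ** 0.5

rng = np.random.default_rng(0)
same = mmd_rbf(rng.normal(0, 1, (200, 4)), rng.normal(0, 1, (200, 4)))
shifted = mmd_rbf(rng.normal(0, 1, (200, 4)), rng.normal(3, 1, (200, 4)))
```

Two draws from the same Gaussian yield a small MMD, while a mean shift of 3 standard deviations produces a much larger value, which is the separation the shift threshold in compute_shift_score relies on.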
Complete Flywheel Pipeline
End-to-End Orchestration
import numpy as np
from collections import defaultdict
class DataFlywheelPipeline:
"""
Complete data flywheel: production logs -> training data.
Pipeline stages:
1. Collect: Ingest logs from inference servers
2. Scrub: Remove PII
3. Filter: Quality and diversity checks
4. Extract: Build preference pairs from signals
5. Validate: Reward model cross-check
6. Train: Online DPO with replay buffer
7. Deploy: A/B test new model against current
8. Monitor: Track shift, quality, and win rate
"""
def __init__(self, config):
self.collector = LogCollector(config["collector"])
self.scrubber = PIIScrubber()
self.quality_filter = ProductionDataFilter(
config["filter"]
)
self.signal_extractor = SignalExtractor()
self.pair_builder = PreferencePairBuilder()
self.dpo_trainer = OnlineDPOTrainer(
model=config["model"],
tokenizer=config["tokenizer"],
config=config["trainer"],
)
self.shift_detector = DistributionShiftDetector(
embedding_model=config["embedding_model"],
)
self.metrics = FlywheelMetrics()  # metrics aggregator, defined elsewhere
def run_daily_batch(self, raw_logs):
"""
Process one day of production logs through
the full pipeline.
"""
# Stage 1: PII scrubbing
scrubbed_logs = []
pii_stats = {"total_detections": 0, "quarantined": 0}
for log in raw_logs:
clean_prompt, prompt_dets = (
self.scrubber.scrub_full_pipeline(log.prompt)
)
clean_response, response_dets = (
self.scrubber.scrub_full_pipeline(log.response)
)
pii_stats["total_detections"] += (
len(prompt_dets) + len(response_dets)
)
log.prompt = clean_prompt
log.response = clean_response
scrubbed_logs.append(log)
# Stage 2: Quality filtering
accepted, rejected = self.quality_filter.filter_batch(
scrubbed_logs
)
# Stage 3: Signal extraction
all_pairs = []
sessions = self._group_by_session(accepted)
for session_id, session_logs in sessions.items():
pairs = (
self.signal_extractor
.extract_preference_pairs(session_logs)
)
all_pairs.extend(pairs)
# Stage 4: Validation
validated_pairs = self.pair_builder.validate_pairs(
all_pairs
)
# Stage 5: Add to DPO trainer
self.dpo_trainer.add_pairs(validated_pairs)
# Stage 6: Training steps
train_results = []
n_steps = min(
len(validated_pairs) // self.dpo_trainer.batch_size,
100,
)
for _ in range(n_steps):
result = self.dpo_trainer.train_step()
if result:
train_results.append(result)
# Stage 7: Distribution shift monitoring
for log in accepted[:1000]:
self.shift_detector.add_production_sample(
log.prompt
)
shift = self.shift_detector.compute_shift_score()
return {
"raw_logs": len(raw_logs),
"after_pii_scrub": len(scrubbed_logs),
"after_quality_filter": len(accepted),
"preference_pairs_extracted": len(all_pairs),
"preference_pairs_validated": len(validated_pairs),
"training_steps": len(train_results),
"avg_loss": (
np.mean([r["loss"] for r in train_results])
if train_results
else None
),
"distribution_shift": shift,
"pii_stats": pii_stats,
}
def _group_by_session(self, logs):
"""Group logs by session_id."""
sessions = defaultdict(list)
for log in logs:
sessions[log.session_id].append(log)
return sessions
Flywheel Pipeline: Data Funnel
(Figure: typical daily volume at each funnel stage: raw logs, post-PII scrub, post-quality filter, post-dedup, pairs extracted, pairs validated.)
Key Takeaways
The data flywheel converts production interactions into training signal. The pipeline has six stages: collection, PII scrubbing, quality filtering, preference extraction, reward model validation, and online DPO training.
The critical engineering decisions:
- PII scrubbing is the bottleneck: Multi-stage scrubbing (regex + NER) catches 93-98% of PII. The remaining 2-7% requires conservative fallback strategies. GDPR right-to-erasure means every data point must be traceable to its source for deletion.
- Implicit signals are noisy but abundant: Regeneration events produce preference pairs at 80% precision. Copy events are lower precision (60%) but 5x more common. Combining multiple signal types with confidence weighting produces the best training data.
- Quality filtering matters more than volume: Filtering production logs from 100% down to 15% of volume (removing duplicates, short responses, repetitions, and low-diversity samples) improves downstream DPO win rate from 48% to 59%.
- Online DPO adapts faster than batch retraining: Continuous training from a replay buffer with mixed online/offline data produces measurable improvements within 7 days. Batch retraining on a weekly schedule lags by 1-2 weeks.
- Distribution shift is real and measurable: Production query distributions change over time as users discover new use cases and external events drive new topics. Monitoring Mahalanobis distance and MMD between training and production distributions provides early warning of model degradation.