Round-robin routing fails spectacularly for LLM inference. Request latencies vary 100x (50ms to 5000ms), and naive load balancing creates severe hot spots. Here’s how to route intelligently.

The Problem with Round-Robin

# Naive round-robin fails due to variable request sizes
requests = [
    {"prompt_len": 100, "max_tokens": 50},    # ~200ms
    {"prompt_len": 2000, "max_tokens": 500},  # ~5000ms
    {"prompt_len": 50, "max_tokens": 20},     # ~80ms
]

# Round-robin sends 1 request to each of 3 replicas
# Replica 0: 200ms
# Replica 1: 5000ms (25x longer!)
# Replica 2: 80ms

# Round-robin keeps assigning to replica 1 in turn, ignoring its backlog
# Result: replica 1 builds a deep queue and tail latency spikes
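This queueing effect is easy to see in a toy simulation. The numbers below are illustrative (matching the comments above), not measurements:

```python
durations = [200, 5000, 80, 300, 150]  # per-request cost in ms (hypothetical)

def finish_times(assign):
    """Assign each request to a replica; return total busy time per replica."""
    loads = [0, 0, 0]
    for i, d in enumerate(durations):
        r = assign(i, loads)
        loads[r] += d
    return loads

round_robin = finish_times(lambda i, loads: i % 3)
least_loaded = finish_times(lambda i, loads: loads.index(min(loads)))

print(round_robin)   # [500, 5150, 80] -> requests pile up behind the long one
print(least_loaded)  # [350, 5000, 380] -> short requests route around the busy replica
```

Even on five requests, round-robin queues a short request behind the 5000ms one; a load-aware policy leaves the slow replica alone.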

Latency Distribution with Round-Robin vs Smart Routing

                 P50       P99
Round-Robin      450 ms    8,500 ms
Smart Routing    320 ms    1,200 ms

Least-Connections Routing

Track active connections per replica:

import threading
from typing import List

class LeastConnectionsRouter:
    def __init__(self, replicas: List[str]):
        self.replicas = replicas
        self.connections = {r: 0 for r in replicas}
        self.lock = threading.Lock()
    
    def route(self) -> str:
        with self.lock:
            replica = min(self.replicas, key=lambda r: self.connections[r])
            self.connections[replica] += 1
            return replica
    
    def release(self, replica: str):
        with self.lock:
            self.connections[replica] -= 1
ℹ️ Better, Not Perfect

Least-connections improves over round-robin but doesn’t account for request size. A replica with 10 short requests may be less loaded than one with 2 long requests.
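The mismatch is easy to make concrete. In this hypothetical comparison, the connection counts and per-request costs are made up for illustration:

```python
# Replica A: 10 short requests (~200 ms each); Replica B: 2 long (~5000 ms each)
loads = {
    "A": {"connections": 10, "queued_ms": 10 * 200},  # 2,000 ms of pending work
    "B": {"connections": 2,  "queued_ms": 2 * 5000},  # 10,000 ms of pending work
}

by_connections = min(loads, key=lambda r: loads[r]["connections"])
by_queue_depth = min(loads, key=lambda r: loads[r]["queued_ms"])

print(by_connections)  # "B" -- fewer connections, but 5x more pending work
print(by_queue_depth)  # "A"
```

Least-connections picks the replica with 5x more pending work; accounting for estimated processing time picks the right one.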

Queue-Depth-Aware Routing

Account for estimated processing time:

class QueueDepthRouter:
    def __init__(self, replicas: List[str]):
        self.replicas = replicas
        self.queue_depths = {r: 0.0 for r in replicas}  # Estimated processing time
        self.lock = threading.Lock()
    
    def estimate_processing_time(self, request: dict) -> float:
        """Estimate request processing time in milliseconds."""
        prompt_len = request.get('prompt_len', 100)
        max_tokens = request.get('max_tokens', 100)
        
        # Empirical model: prefill + decode
        prefill_ms = prompt_len * 0.5  # 0.5ms per prompt token
        decode_ms = max_tokens * 20    # 20ms per output token (batch=1)
        
        return prefill_ms + decode_ms
    
    def route(self, request: dict) -> str:
        estimated_time = self.estimate_processing_time(request)
        
        with self.lock:
            # Find replica with lowest total queue depth
            replica = min(self.replicas, key=lambda r: self.queue_depths[r])
            self.queue_depths[replica] += estimated_time
            return replica
    
    def complete(self, replica: str, actual_time: float, estimated_time: float):
        with self.lock:
            # Subtract estimate, could also update estimation model
            self.queue_depths[replica] -= estimated_time
            self.queue_depths[replica] = max(0, self.queue_depths[replica])
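Stripped of locking and bookkeeping, the routing decision reduces to a greedy assignment. This standalone sketch reuses the illustrative cost constants from the class above (0.5 ms per prompt token, 20 ms per output token):

```python
def estimate_ms(req):
    # Same illustrative prefill + decode model as estimate_processing_time
    return req["prompt_len"] * 0.5 + req["max_tokens"] * 20

queue_ms = {"r0": 0.0, "r1": 0.0, "r2": 0.0}
requests = [
    {"prompt_len": 100, "max_tokens": 50},
    {"prompt_len": 2000, "max_tokens": 500},
    {"prompt_len": 50, "max_tokens": 20},
    {"prompt_len": 100, "max_tokens": 50},
]
for req in requests:
    target = min(queue_ms, key=queue_ms.get)  # lowest estimated backlog
    queue_ms[target] += estimate_ms(req)
    print(target, queue_ms)
```

The expensive request ends up isolated on one replica while the short ones share the others, which is exactly the behavior least-connections cannot guarantee.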

Weighted Routing with Feedback

Adjust weights based on actual latencies:

import random

class AdaptiveWeightedRouter:
    def __init__(self, replicas: List[str]):
        self.replicas = replicas
        self.weights = {r: 1.0 for r in replicas}
        self.ewma_latency = {r: 100.0 for r in replicas}  # EWMA of observed latency (ms)
        self.alpha = 0.1  # EWMA smoothing factor
    
    def route(self) -> str:
        # Weighted random selection (inversely proportional to latency)
        total_weight = sum(self.weights.values())
        r = random.random() * total_weight
        
        cumsum = 0
        for replica, weight in self.weights.items():
            cumsum += weight
            if r <= cumsum:
                return replica
        
        return self.replicas[-1]
    
    def record_latency(self, replica: str, latency_ms: float):
        # Update EWMA
        self.ewma_latency[replica] = (
            self.alpha * latency_ms + 
            (1 - self.alpha) * self.ewma_latency[replica]
        )
        
        # Update weights (inverse of latency)
        min_latency = min(self.ewma_latency.values())
        for r in self.replicas:
            self.weights[r] = min_latency / self.ewma_latency[r]
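To see how the feedback loop behaves, here is a standalone trace of the EWMA update with the same alpha as above and made-up latencies (a steady 100ms replica vs. one degraded to 400ms):

```python
alpha = 0.1
ewma = {"r0": 100.0, "r1": 100.0}

def record(replica, latency_ms):
    # Same EWMA update as record_latency above
    ewma[replica] = alpha * latency_ms + (1 - alpha) * ewma[replica]

for _ in range(20):
    record("r0", 100.0)  # r0 stays fast
    record("r1", 400.0)  # r1 has degraded

weights = {r: min(ewma.values()) / ewma[r] for r in ewma}
print(weights)  # r0 keeps weight 1.0; r1 decays well below 1.0
```

After 20 observations the slow replica's weight has fallen below 0.3, so it receives under a quarter of the traffic while still getting enough probes to detect recovery.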

Routing Strategy Comparison (8 replicas, mixed load)

Strategy            P50 Latency   P99 Latency   Throughput
Round-Robin         450 ms        8,500 ms      180 req/s
Least-Connections   380 ms        4,200 ms      195 req/s
Queue-Depth         320 ms        1,400 ms      210 req/s
Adaptive Weighted   310 ms        1,200 ms      215 req/s

Note: Llama-70B, 8× A100, variable request sizes

Prefix Cache Routing

For KV cache reuse, route requests that share a prompt prefix to the same replica:

class PrefixCacheAwareRouter:
    def __init__(self, replicas: List[str]):
        self.replicas = replicas
        self.prefix_cache = {}  # prefix_hash -> replica
        self.cache_size = 10000
    
    def route(self, request: dict) -> str:
        prompt = request.get('prompt', '')
        
        # Check for prefix match (first ~256 tokens, assuming ~4 chars/token)
        prefix = prompt[:1024]
        prefix_hash = hash(prefix)
        
        if prefix_hash in self.prefix_cache:
            replica = self.prefix_cache[prefix_hash]
            if self._is_healthy(replica):
                return replica
        
        # No cache hit - fall back to least-connections (helper as defined earlier)
        replica = self._least_connections_route()
        
        # Cache the prefix -> replica mapping
        if len(self.prefix_cache) >= self.cache_size:
            # Evict the oldest entry (dicts preserve insertion order)
            self.prefix_cache.pop(next(iter(self.prefix_cache)))
        self.prefix_cache[prefix_hash] = replica
        
        return replica
Cache Hit Benefits

Prefix cache routing can improve throughput by 20-40% for workloads with common system prompts or repeated queries, as KV cache is reused.
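One deployment caveat worth adding (an assumption about multi-process setups, not something the router above handles): Python's built-in hash() for strings is salted per process, so the prefix -> replica map is only consistent inside a single router process. If the map is shared across processes or restarts, a stable digest works instead:

```python
import hashlib

def prefix_key(prompt: str, prefix_chars: int = 1024) -> str:
    # Stable across processes and restarts, unlike hash(prompt[:1024])
    return hashlib.sha256(prompt[:prefix_chars].encode("utf-8")).hexdigest()

p = "You are a helpful assistant. Summarize the following document."
print(prefix_key(p) == prefix_key(p))  # True -- deterministic by construction
```

The digest is a drop-in replacement for the prefix_hash computed in route() above.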

Health-Aware Routing

Remove unhealthy replicas from rotation:

import time

class HealthAwareRouter:
    def __init__(self, replicas: List[str]):
        self.replicas = replicas
        self.health_status = {r: True for r in replicas}
        self.failure_counts = {r: 0 for r in replicas}
        self.last_check = {r: 0 for r in replicas}
        
        self.failure_threshold = 3
        self.recovery_interval = 30  # seconds
    
    def route(self) -> str:
        healthy = [r for r in self.replicas if self.health_status[r]]
        
        if not healthy:
            # All replicas unhealthy - retry the one that failed longest ago
            return min(self.replicas, key=lambda r: self.last_check[r])
        
        # Delegate to any strategy above, restricted to the healthy subset
        return self._weighted_route(healthy)
    
    def record_result(self, replica: str, success: bool):
        if success:
            self.failure_counts[replica] = 0
            self.health_status[replica] = True
        else:
            self.failure_counts[replica] += 1
            if self.failure_counts[replica] >= self.failure_threshold:
                self.health_status[replica] = False
                self.last_check[replica] = time.time()
    
    def check_recovery(self):
        """Periodically re-enable failed replicas for health check."""
        now = time.time()
        for replica in self.replicas:
            if not self.health_status[replica]:
                if now - self.last_check[replica] > self.recovery_interval:
                    self.health_status[replica] = True  # Try again
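The breaker logic in record_result can be traced in isolation. Same threshold of 3 consecutive failures, with a made-up outcome sequence:

```python
state = {"failures": 0, "healthy": True}
FAILURE_THRESHOLD = 3

def record(success: bool):
    # Success resets the count; consecutive failures trip the breaker
    if success:
        state["failures"] = 0
        state["healthy"] = True
    else:
        state["failures"] += 1
        if state["failures"] >= FAILURE_THRESHOLD:
            state["healthy"] = False

for outcome in [True, False, False, True, False, False, False]:
    record(outcome)

print(state["healthy"])  # False -- three consecutive failures tripped it
```

Note that the two failures before the success do not count: only an unbroken run of failures removes the replica, which keeps transient errors from causing flapping.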

Conclusion

Effective LLM request routing requires:

  1. Queue-depth awareness: Account for in-flight request costs
  2. Latency feedback: Adapt weights based on actual performance
  3. Prefix affinity: Route similar prompts together for cache hits
  4. Health monitoring: Remove failing replicas quickly

The combination of queue-depth routing + prefix affinity typically achieves 30-50% latency improvement over naive round-robin.