INT8 tensor cores on Ampere deliver 2x the throughput of FP16 tensor cores, but only if both matrix operands are quantized to INT8 simultaneously. That is the constraint that makes W8A8 (INT8 weights, INT8 activations) fundamentally harder than W4A16 (INT4 weights, FP16 activations). Weights are static: quantize once, serve forever. Activations are generated fresh for every token, and they contain outlier channels with magnitudes 100x larger than the median, making per-tensor INT8 activation quantization catastrophically lossy without SmoothQuant's channel-migration trick. Get the scaling wrong and you lose 3+ perplexity points. Get it right and you get a true 2x tensor core speedup with near-lossless quality.
This post covers the complete W8A8 pipeline: offline weight quantization, online activation quantization with per-token scaling, the cuBLAS cublasLtMatmul INT8 API, the INT32 accumulation and FP16 dequantization math, SmoothQuant integration, and benchmarks comparing INT8 to FP8 on H100.
The INT8 GEMM Pipeline
The W8A8 GEMM computes Y = X @ W^T, where X (activations) and W (weights) are both quantized to INT8:
The INT8 matmul produces INT32 accumulations, which are then dequantized to FP16/BF16 using the scale factors:
Input: X_int8 (M x K), W_int8 (N x K)
Step 1: Y_int32 = X_int8 @ W_int8^T [INT8 tensor cores, INT32 accumulation]
Step 2: Y_fp16 = Y_int32 * (s_x * s_w^T) [Dequantize with scale factors]
Output: Y_fp16 (M x N)
The critical insight is that the scale factors are factored out of the matmul. The tensor cores operate on pure INT8 values with INT32 accumulation. Dequantization happens after the matmul.
import torch

def w8a8_gemm_reference(X_fp, W_fp, x_scale_type='per_token', w_scale_type='per_channel'):
    """Reference W8A8 INT8 GEMM implementation.

    Args:
        X_fp: FP32 activations, shape (M, K)
        W_fp: FP32 weights, shape (N, K)
        x_scale_type: 'per_tensor' or 'per_token'
        w_scale_type: 'per_tensor' or 'per_channel'

    Returns:
        Y_fp: dequantized output, shape (M, N)
    """
    M, K = X_fp.shape
    N = W_fp.shape[0]
    qmax = 127

    # Step 1: Quantize activations
    if x_scale_type == 'per_tensor':
        x_scale = (X_fp.abs().max() / qmax).clamp(min=1e-10)
        X_q = torch.clamp(torch.round(X_fp / x_scale), -128, 127).to(torch.int8)
        # x_scale shape: scalar
    elif x_scale_type == 'per_token':
        x_scale = X_fp.abs().amax(dim=1, keepdim=True) / qmax
        x_scale = x_scale.clamp(min=1e-10)
        X_q = torch.clamp(torch.round(X_fp / x_scale), -128, 127).to(torch.int8)
        # x_scale shape: (M, 1)

    # Step 2: Quantize weights (offline)
    if w_scale_type == 'per_tensor':
        w_scale = (W_fp.abs().max() / qmax).clamp(min=1e-10)
        W_q = torch.clamp(torch.round(W_fp / w_scale), -128, 127).to(torch.int8)
        # w_scale shape: scalar
    elif w_scale_type == 'per_channel':
        w_scale = W_fp.abs().amax(dim=1, keepdim=True) / qmax
        w_scale = w_scale.clamp(min=1e-10)
        W_q = torch.clamp(torch.round(W_fp / w_scale), -128, 127).to(torch.int8)
        # w_scale shape: (N, 1)

    # Step 3: INT8 GEMM with INT32 accumulation
    Y_int32 = X_q.int() @ W_q.int().T  # (M, N)

    # Step 4: Dequantize
    # Y_fp = Y_int32 * (x_scale * w_scale^T)
    if x_scale_type == 'per_tensor' and w_scale_type == 'per_tensor':
        Y_fp = Y_int32.float() * (x_scale * w_scale)
    elif x_scale_type == 'per_token' and w_scale_type == 'per_channel':
        # x_scale: (M, 1), w_scale: (N, 1) -> outer product: (M, N)
        Y_fp = Y_int32.float() * (x_scale * w_scale.T)
    elif x_scale_type == 'per_token' and w_scale_type == 'per_tensor':
        Y_fp = Y_int32.float() * (x_scale * w_scale)
    elif x_scale_type == 'per_tensor' and w_scale_type == 'per_channel':
        Y_fp = Y_int32.float() * (x_scale * w_scale.T)
    return Y_fp

# Test
torch.manual_seed(42)
M, N, K = 32, 4096, 4096
X = torch.randn(M, K) * 0.5
W = torch.randn(N, K) * 0.02
Y_ref = X @ W.T
Y_int8 = w8a8_gemm_reference(X, W, 'per_token', 'per_channel')
mse = ((Y_ref - Y_int8) ** 2).mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
    Y_ref.flatten(), Y_int8.flatten(), dim=0
).item()
print(f"MSE: {mse:.6e}, Cosine sim: {cos_sim:.8f}")
Scale Factor Compatibility with INT8 GEMM
The scale factor strategy must be compatible with INT8 tensor core execution. The constraint is that the scale factors must be factorable out of the inner product:

Y[i, j] = sum_k (s_x[i] * X_q[i, k]) * (s_w[j] * W_q[j, k]) = s_x[i] * s_w[j] * sum_k X_q[i, k] * W_q[j, k]

This factoring works when:
- Per-tensor scaling: s_x and s_w are scalars, trivially factor out
- Per-token x per-channel: s_x[i] depends only on row i of X, s_w[j] depends only on row j of W -> factors out as an outer product
- Per-channel (along K) x per-channel (along K): s_x[k] and s_w[k] both depend on the inner dimension k -> does NOT factor out
def verify_scale_compatibility(x_scale_shape, w_scale_shape, M, N, K):
    """Check if scale factors are compatible with INT8 GEMM.

    Compatible means the scales can be factored out of the inner sum.
    """
    x_dims = set()
    w_dims = set()
    if x_scale_shape == 'scalar':
        pass  # No dimension dependency
    elif x_scale_shape == 'per_token':
        x_dims.add('M')  # Depends on row index
    elif x_scale_shape == 'per_channel':
        x_dims.add('K')  # Depends on inner dimension
    if w_scale_shape == 'scalar':
        pass
    elif w_scale_shape == 'per_channel':
        w_dims.add('N')  # Depends on output channel
    elif w_scale_shape == 'per_input_channel':
        w_dims.add('K')  # Depends on inner dimension
    # Compatible only if neither scale depends on the inner dimension K
    k_dep = 'K' in x_dims or 'K' in w_dims
    compatible = not k_dep
    return compatible

# Check all combinations
combinations = [
    ('scalar', 'scalar', True),
    ('scalar', 'per_channel', True),
    ('per_token', 'scalar', True),
    ('per_token', 'per_channel', True),     # Standard W8A8
    ('per_channel', 'per_channel', False),  # Incompatible!
    ('per_channel', 'per_input_channel', False),
]
for x_s, w_s, expected in combinations:
    result = verify_scale_compatibility(x_s, w_s, 32, 4096, 4096)
    status = "OK" if result == expected else "MISMATCH"
    print(f"  X={x_s:>15s}, W={w_s:>20s}: "
          f"compatible={result} [{status}]")
The standard W8A8 configuration uses per-token activation scaling (one scale per token/row of X) and per-channel weight scaling (one scale per output channel/row of W). This gives both operands good quantization quality while remaining compatible with INT8 tensor core GEMM. The dequantization is an outer product of the two scale vectors, applied element-wise to the INT32 output.
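The factoring argument can be checked numerically. The standalone sketch below (random INT8 operands and random, made-up scale vectors) confirms that scaling the INT32 output by the outer product of the scale vectors matches dequantizing each operand before the matmul:

```python
# Per-token x per-channel dequantization is exact factoring:
# sum_k (s_x[i]*X_q[i,k]) * (s_w[j]*W_q[j,k])
#   = s_x[i] * s_w[j] * sum_k X_q[i,k] * W_q[j,k]
import torch

torch.manual_seed(0)
M, N, K = 4, 8, 16
X_q = torch.randint(-128, 128, (M, K), dtype=torch.int8)
W_q = torch.randint(-128, 128, (N, K), dtype=torch.int8)
s_x = torch.rand(M, 1) + 0.01  # per-token scales
s_w = torch.rand(N, 1) + 0.01  # per-channel scales

# Path A: dequantize the operands first, then matmul in float
Y_a = (X_q.float() * s_x) @ (W_q.float() * s_w).T
# Path B: integer matmul, then scale by the outer product
Y_b = (X_q.int() @ W_q.int().T).float() * (s_x * s_w.T)

print(torch.allclose(Y_a, Y_b, rtol=1e-4, atol=1e-3))
```

The two paths agree up to float rounding, which is exactly why the scales can be deferred to a post-GEMM epilogue.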
cuBLAS INT8 GEMM API
NVIDIA's cuBLAS library provides INT8 GEMM through the cublasLtMatmul API. The setup is more involved than FP16 GEMM because of the mixed-type accumulation and scaling:
#include <cublasLt.h>

// cuBLAS INT8 GEMM setup
void setup_int8_gemm(
    cublasLtHandle_t handle,
    int M, int N, int K,
    const int8_t* A,  // Activations (M x K)
    const int8_t* B,  // Weights (N x K, stored as K x N column-major)
    float* C,         // Output (M x N)
    float alpha,      // Global scale factor
    float beta        // Output accumulation factor
) {
    cublasLtMatmulDesc_t matmulDesc;
    cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F);

    // Set transpose operations
    cublasOperation_t transA = CUBLAS_OP_N;
    cublasOperation_t transB = CUBLAS_OP_T;
    cublasLtMatmulDescSetAttribute(
        matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA,
        &transA, sizeof(transA)
    );
    cublasLtMatmulDescSetAttribute(
        matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB,
        &transB, sizeof(transB)
    );

    // Create matrix layouts
    cublasLtMatrixLayout_t layoutA, layoutB, layoutC;
    // A: INT8, M x K
    cublasLtMatrixLayoutCreate(&layoutA, CUDA_R_8I, M, K, M);
    // B: INT8, K x N (column-major for N x K row-major)
    cublasLtMatrixLayoutCreate(&layoutB, CUDA_R_8I, K, N, K);
    // C: FP32, M x N
    cublasLtMatrixLayoutCreate(&layoutC, CUDA_R_32F, M, N, M);

    // Execute INT8 GEMM
    // Y = alpha * (A_int8 @ B_int8^T) + beta * C
    // Internally: INT8 x INT8 -> INT32 accumulation -> FP32 output
    cublasLtMatmul(
        handle, matmulDesc,
        &alpha, A, layoutA,
        B, layoutB,
        &beta, C, layoutC,
        C, layoutC,
        NULL, NULL, 0, 0  // Algo, workspace, workspace size, stream
    );

    // Cleanup
    cublasLtMatmulDescDestroy(matmulDesc);
    cublasLtMatrixLayoutDestroy(layoutA);
    cublasLtMatrixLayoutDestroy(layoutB);
    cublasLtMatrixLayoutDestroy(layoutC);
}
Per-Token x Per-Channel Dequantization After cuBLAS
cuBLAS INT8 GEMM produces a single output with a global alpha scale. For per-token x per-channel scaling, we need a post-GEMM dequantization kernel:
// Post-GEMM dequantization kernel
// Y_fp = Y_int32 * x_scales[row] * w_scales[col]
__global__ void dequantize_int32_per_token_per_channel(
    const int32_t* __restrict__ Y_int32,  // (M, N)
    const float* __restrict__ x_scales,   // (M,)
    const float* __restrict__ w_scales,   // (N,)
    half* __restrict__ Y_fp16,            // (M, N)
    int M, int N
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float val = (float)Y_int32[row * N + col];
        val *= x_scales[row] * w_scales[col];
        Y_fp16[row * N + col] = __float2half(val);
    }
}
In practice, this dequantization is fused into the GEMM epilogue using cuBLAS epilogue functions or custom kernels.
Dynamic Activation Quantization
Weights are quantized offline, but activations must be quantized at runtime because their distribution depends on the input. The quantization kernel runs before each GEMM:
def dynamic_quantize_per_token(X_fp):
    """Quantize activations to INT8 per-token at runtime.

    This runs on every forward pass. Must be fast.

    Args:
        X_fp: FP16/BF16 activations, shape (M, K)

    Returns:
        X_int8: quantized activations, shape (M, K)
        scales: per-token scale factors, shape (M, 1)
    """
    # Find per-token maximum
    abs_max = X_fp.abs().amax(dim=-1, keepdim=True)  # (M, 1)
    # Compute scale
    scales = (abs_max / 127.0).clamp(min=1e-10)
    # Quantize
    X_int8 = (X_fp / scales).round().clamp(-128, 127).to(torch.int8)
    return X_int8, scales
The runtime overhead of dynamic quantization is the cost of computing per-token max (a reduction) plus the division and round. On GPU, this takes approximately 10-15 microseconds for a typical (32, 4096) activation tensor β negligible compared to the GEMM itself.
def measure_quantization_overhead(M, K, num_iterations=1000):
    """Measure dynamic quantization kernel time."""
    import time
    X = torch.randn(M, K, dtype=torch.float16, device='cuda')
    # Warmup
    for _ in range(10):
        dynamic_quantize_per_token(X)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iterations):
        dynamic_quantize_per_token(X)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / num_iterations
    return elapsed * 1e6  # microseconds

# Expected: ~12us for (32, 4096), ~18us for (128, 4096)
INT8 vs FP8: When INT8 Wins
H100 provides both INT8 and FP8 tensor core support, both at 1979 TFLOPS (dense). The question is when to prefer INT8 over FP8.
Precision Comparison
def compare_int8_fp8_precision():
    """Compare INT8 and FP8 E4M3 representable values."""
    # INT8: 256 uniform levels in [-128, 127]
    int8_levels = list(range(-128, 128))
    # FP8 E4M3: non-uniform levels in [-448, 448]
    # (simplified -- actual E4M3 has specific spacing)
    fp8_levels = []
    for sign in [1, -1]:
        for exp in range(16):      # 4-bit exponent
            for mant in range(8):  # 3-bit mantissa
                if exp == 0:
                    val = sign * (mant / 8) * (2 ** (-6))  # subnormal
                elif exp == 15 and mant == 7:
                    continue  # NaN
                elif exp == 15:
                    val = sign * (1 + mant / 8) * (2 ** 8)
                else:
                    val = sign * (1 + mant / 8) * (2 ** (exp - 7))
                fp8_levels.append(val)
    return {
        'int8_num_levels': len(int8_levels),
        'int8_range': (min(int8_levels), max(int8_levels)),
        'fp8_num_levels': len(set(fp8_levels)),
        'fp8_range': (min(fp8_levels), max(fp8_levels)),
    }
INT8 vs FP8 E4M3 Precision Characteristics
| Property | INT8 | FP8 E4M3 |
|---|---|---|
| Total levels | 256 | 253 distinct values (2 NaN encodings) |
| Range | [-128, 127] | [-448, 448] |
| Dynamic range (max / min nonzero) | 127:1 | ~229,000:1 (448 / 2^-9, with subnormals) |
| Precision near 1.0 | 1.0 (uniform step) | 0.125 (mantissa step) |
| Precision near 0.01 | 1.0 (same step) | ~0.002 (subnormal step, finer near 0) |
| Uniform spacing | Yes | No (logarithmic) |
| H100 throughput | 1979 TOPS | 1979 TFLOPS |
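The qualitative claim behind this table (uniform INT8 wins on concentrated distributions, E4M3 wins once outliers stretch the range) can be checked with a small round-trip experiment. This is a standalone sketch: the E4M3 grid is built from the format definition (bias 7, max 448, no infinities), nearest-value rounding stands in for hardware cast behavior, and the injected outlier pattern is synthetic:

```python
import torch

def e4m3_values():
    """Enumerate all FP8 E4M3 values (bias 7, two NaN encodings excluded)."""
    vals = [0.0]
    for exp in range(16):
        for mant in range(8):
            if exp == 15 and mant == 7:
                continue  # NaN encoding
            if exp == 0:
                v = (mant / 8) * 2.0 ** -6  # subnormal
            else:
                v = (1 + mant / 8) * 2.0 ** (exp - 7)
            vals += [v, -v]
    return torch.tensor(sorted(set(vals)))

def quant_int8(x):
    """Per-tensor absmax INT8 round trip."""
    s = x.abs().max() / 127
    return (x / s).round().clamp(-128, 127) * s

def quant_e4m3(x):
    """Round to the nearest representable E4M3 value."""
    grid = e4m3_values()
    idx = (x.reshape(-1, 1) - grid.reshape(1, -1)).abs().argmin(dim=1)
    return grid[idx].reshape(x.shape)

torch.manual_seed(0)
x = torch.randn(4096)
x_out = x.clone()
x_out[::512] *= 50  # inject a few outlier elements

results = {}
for name, t in [('gaussian', x), ('with outliers', x_out)]:
    e_int8 = ((t - quant_int8(t)) ** 2).mean().item()
    e_fp8 = ((t - quant_e4m3(t)) ** 2).mean().item()
    results[name] = (e_int8, e_fp8)
    print(f"{name:>14s}: int8 MSE={e_int8:.3e}, e4m3 MSE={e_fp8:.3e}")
```

On this toy tensor, INT8 gives lower error on the smooth Gaussian while E4M3 gives lower error once outliers appear, mirroring the perplexity behavior discussed below.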
When INT8 Beats FP8
- After SmoothQuant: Once outlier channels are smoothed, activation distributions are concentrated in a narrow range. INT8's uniform spacing uses all 256 levels efficiently.
- On Ampere (A100): A100 has INT8 tensor cores but no FP8 support. INT8 is the only sub-FP16 compute option.
- Larger models with less sensitivity: Larger models (70B+) tend to be less sensitive to quantization, and INT8 provides sufficient precision.
- When per-token scaling is sufficient: If the per-token activation range (after SmoothQuant) is narrow enough that 256 uniform levels suffice, INT8 avoids the complexity of FP8 calibration.
def should_use_int8(
    gpu_generation,
    model_size_B,
    has_smoothquant,
    activation_outlier_ratio,
):
    """Decide between INT8 and FP8 for W8A8 inference."""
    if gpu_generation == 'ampere':
        return True   # No FP8 tensor cores on A100
    if not has_smoothquant and activation_outlier_ratio > 20:
        return False  # FP8's wider range handles outliers better
    if has_smoothquant:
        # After SmoothQuant, INT8 and FP8 give similar quality.
        # INT8 has simpler calibration (no E4M3/E5M2 format selection).
        return True
    # Default: FP8 on Hopper for simplicity
    return False
W8A8 Perplexity: INT8 vs FP8 on Llama-2 7B (WikiText-2)
Without SmoothQuant, FP8 is significantly better than INT8 for activations because its wider dynamic range accommodates outlier channels (5.56 vs 6.81 ppl). After SmoothQuant, the outliers are eliminated and INT8 matches FP8 quality (5.52 vs 5.49 ppl). If you use SmoothQuant, the choice between INT8 and FP8 comes down to hardware support and tooling, not quality.
Complete W8A8 INT8 Linear Layer
import torch
import torch.nn as nn

class W8A8Int8Linear(nn.Module):
    """W8A8 INT8 quantized linear layer for inference.

    Weights: INT8 per-channel quantized (offline)
    Activations: INT8 per-token quantized (dynamic)
    GEMM: INT8 tensor cores with INT32 accumulation
    Output: dequantized to FP16
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            'weight_int8',
            torch.zeros(out_features, in_features, dtype=torch.int8)
        )
        self.register_buffer(
            'weight_scale',
            torch.zeros(out_features, 1, dtype=torch.float32)
        )
        if bias:
            self.register_buffer(
                'bias', torch.zeros(out_features, dtype=torch.float16)
            )
        else:
            self.bias = None

    @classmethod
    def from_float(cls, linear, smooth_scales=None):
        """Quantize a float linear layer to W8A8 INT8.

        Args:
            linear: FP16/FP32 nn.Linear
            smooth_scales: optional SmoothQuant scales, shape (in_features,)
        """
        in_f = linear.in_features
        out_f = linear.out_features
        layer = cls(in_f, out_f, bias=linear.bias is not None)
        W = linear.weight.data.float()  # (out_f, in_f)
        # Apply SmoothQuant scaling to weights (if provided)
        if smooth_scales is not None:
            W = W * smooth_scales.unsqueeze(0)
        # Per-channel quantization
        w_max = W.abs().amax(dim=1, keepdim=True)  # (out_f, 1)
        w_scale = (w_max / 127.0).clamp(min=1e-10)
        W_q = (W / w_scale).round().clamp(-128, 127).to(torch.int8)
        layer.weight_int8.copy_(W_q)
        layer.weight_scale.copy_(w_scale)
        if linear.bias is not None:
            layer.bias.copy_(linear.bias.data.to(torch.float16))
        return layer

    def forward(self, x, smooth_scales=None):
        """Forward pass with dynamic activation quantization.

        Args:
            x: FP16 activations, shape (*, in_features)
            smooth_scales: SmoothQuant scales for activation (fused into LN)
        """
        original_shape = x.shape
        x = x.reshape(-1, self.in_features).float()
        # Dynamic per-token quantization of activations
        x_max = x.abs().amax(dim=1, keepdim=True)
        x_scale = (x_max / 127.0).clamp(min=1e-10)
        x_int8 = (x / x_scale).round().clamp(-128, 127).to(torch.int8)
        # INT8 GEMM (simulated -- real impl uses cuBLAS INT8)
        # Y_int32 = X_int8 @ W_int8^T
        y_int32 = x_int8.int() @ self.weight_int8.int().T
        # Dequantize: Y_fp = Y_int32 * (x_scale * w_scale^T)
        y_fp = y_int32.float() * (x_scale * self.weight_scale.T)
        if self.bias is not None:
            y_fp = y_fp + self.bias.float()
        y_fp = y_fp.half()
        return y_fp.reshape(*original_shape[:-1], self.out_features)

# Verify
torch.manual_seed(42)
linear = nn.Linear(4096, 4096, bias=False)
nn.init.normal_(linear.weight, std=0.02)
int8_layer = W8A8Int8Linear.from_float(linear)
x = torch.randn(1, 32, 4096)
with torch.no_grad():
    y_ref = linear(x)
    y_int8 = int8_layer(x).float()
mse = ((y_ref - y_int8) ** 2).mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
    y_ref.flatten(), y_int8.flatten(), dim=0
).item()
print(f"W8A8 INT8 MSE: {mse:.6e}")
print(f"Cosine similarity: {cos_sim:.8f}")
INT32 Accumulation Overflow Analysis
INT8 tensor cores accumulate into INT32. The maximum possible accumulation value depends on the inner dimension K:

max_accum = K * 128 * 128

For K = 4096: max_accum = 4096 * 128 * 128 = 67,108,864. The INT32 range is [-2^31, 2^31 - 1], so the worst-case accumulation uses only ~3.1% of the INT32 range. Overflow is not a practical concern for typical LLM dimensions.
def check_int32_overflow(K, abs_max_per_element=128):
    """Check if INT32 accumulation can overflow for given K."""
    max_accumulation = K * abs_max_per_element * abs_max_per_element
    int32_max = 2 ** 31 - 1
    overflow_risk = max_accumulation > int32_max
    utilization = max_accumulation / int32_max * 100
    return {
        'max_accumulation': max_accumulation,
        'int32_max': int32_max,
        'overflow_risk': overflow_risk,
        'utilization_pct': utilization,
    }

for K in [4096, 8192, 16384, 32768, 131072]:
    result = check_int32_overflow(K)
    risk = "OVERFLOW!" if result['overflow_risk'] else "safe"
    print(f"  K={K:>6d}: max_accum={result['max_accumulation']:>15,}, "
          f"utilization={result['utilization_pct']:.1f}% [{risk}]")
K= 4096: max_accum= 67,108,864, utilization=3.1% [safe]
K= 8192: max_accum= 134,217,728, utilization=6.2% [safe]
K= 16384: max_accum= 268,435,456, utilization=12.5% [safe]
K= 32768: max_accum= 536,870,912, utilization=25.0% [safe]
K=131072: max_accum= 2,147,483,648, utilization=100.0% [OVERFLOW!]
For K at or above 131,072, the worst-case INT32 accumulation can overflow. This is not a concern for typical LLM hidden dimensions (4096-16384), but it matters for very long sequence attention computations where the inner dimension is the sequence length. In such cases, the GEMM must be tiled along K, with the INT32 partial sums combined in a wider accumulator (FP32 or INT64).
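That split-K strategy can be sketched on the host side. This is a standalone illustration, not a kernel: each K-chunk stays within safe INT32 range, and the partial sums are promoted to INT64 (the chunk size 65536 is illustrative, not a tuned value):

```python
# Split-K tiling: each chunk's INT32 accumulation is bounded by
# chunk * 128 * 128 = 65536 * 16384 = 2^30 < INT32_MAX, and the
# cross-chunk sum is carried in INT64 so the total may exceed 2^31 - 1.
import torch

def int8_gemm_split_k(X_q, W_q, chunk=65536):
    M, K = X_q.shape
    N = W_q.shape[0]
    acc = torch.zeros(M, N, dtype=torch.int64)
    for k0 in range(0, K, chunk):
        x = X_q[:, k0:k0 + chunk].int()
        w = W_q[:, k0:k0 + chunk].int()
        acc += (x @ w.T).long()  # promote partial sums to INT64
    return acc

# Worst-case operands at K = 262144: every product is (-128) * (-128)
K = 262144
X_q = torch.full((1, K), -128, dtype=torch.int8)
W_q = torch.full((1, K), -128, dtype=torch.int8)
y = int8_gemm_split_k(X_q, W_q)
print(y.item(), y.item() == K * 128 * 128)  # prints: 4294967296 True
```

The exact total (2^32) is beyond the INT32 range, which is precisely the case a single untiled INT32 accumulation would get wrong.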
SmoothQuant + INT8 End-to-End
The full W8A8 INT8 pipeline with SmoothQuant:
class SmoothQuantInt8Model:
    """Apply SmoothQuant and quantize a model to W8A8 INT8."""

    def __init__(self, model, alpha=0.5):
        self.model = model
        self.alpha = alpha

    def calibrate_and_quantize(self, calibration_dataloader, num_samples=128):
        """Full pipeline: calibrate, smooth, quantize."""
        # Step 1: Collect activation statistics
        act_maxes = {}

        def make_hook(name):
            def hook(module, input, output):
                x = input[0].detach().float()
                x_flat = x.reshape(-1, x.shape[-1])
                batch_max = x_flat.abs().amax(dim=0)
                if name not in act_maxes:
                    act_maxes[name] = batch_max
                else:
                    act_maxes[name] = torch.max(act_maxes[name], batch_max)
            return hook

        hooks = []
        for name, mod in self.model.named_modules():
            if isinstance(mod, nn.Linear):
                hooks.append(mod.register_forward_hook(make_hook(name)))

        count = 0
        self.model.eval()
        with torch.no_grad():
            for batch in calibration_dataloader:
                if count >= num_samples:
                    break
                self.model(batch['input_ids'].cuda())
                count += batch['input_ids'].shape[0]
        for h in hooks:
            h.remove()

        # Step 2: Compute SmoothQuant scales and apply
        for name, mod in self.model.named_modules():
            if isinstance(mod, nn.Linear) and name in act_maxes:
                act_max = act_maxes[name].to(mod.weight.device)
                weight_max = mod.weight.data.abs().amax(dim=0)
                smooth_scale = (
                    act_max.pow(self.alpha) /
                    weight_max.clamp(min=1e-5).pow(1 - self.alpha)
                ).clamp(min=1e-5)
                # Step 3: Quantize weights with smooth scaling applied
                int8_mod = W8A8Int8Linear.from_float(
                    mod, smooth_scales=smooth_scale
                )
                # Store smooth scales for runtime activation division
                int8_mod.register_buffer('smooth_scales', smooth_scale)
                # Replace module in model
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                parent = dict(self.model.named_modules())[parent_name]
                setattr(parent, child_name, int8_mod)
        return self.model
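The scale formula in Step 2 can be sanity-checked in isolation. At alpha = 0.5, the migration should exactly balance each channel's activation range against its weight range; the per-channel maxima below are made-up toy numbers:

```python
# SmoothQuant balance check: s_j = act_max_j^alpha / weight_max_j^(1-alpha).
# At alpha = 0.5, act_max/s and weight_max*s both equal
# sqrt(act_max * weight_max), so the two ranges meet in the middle.
import torch

alpha = 0.5
act_max = torch.tensor([100.0, 1.0, 0.5, 300.0])  # channels 0, 3 are outliers
weight_max = torch.tensor([0.1, 0.2, 0.05, 0.1])

s = act_max.pow(alpha) / weight_max.clamp(min=1e-5).pow(1 - alpha)
print(act_max / s)    # smoothed activation ranges
print(weight_max * s) # smoothed weight ranges
print(torch.allclose(act_max / s, weight_max * s))  # prints: True
```

The outlier channels' activation ranges drop from 100 and 300 to their geometric means with the weight ranges, which is what makes the subsequent per-token INT8 quantization viable.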
Throughput Benchmarks
GEMM Throughput: cuBLAS FP16 vs INT8 (H100 SXM)
| Matrix (M x N x K) | FP16 TFLOPS | INT8 TOPS | INT8/FP16 Speedup |
|---|---|---|---|
| 1 x 4096 x 4096 | 0.034 | 0.059 | 1.7x |
| 32 x 4096 x 4096 | 0.89 | 1.71 | 1.9x |
| 128 x 4096 x 4096 | 3.21 | 6.18 | 1.9x |
| 512 x 4096 x 4096 | 11.2 | 21.4 | 1.9x |
| 2048 x 4096 x 4096 | 38.7 | 72.1 | 1.9x |
| 4096 x 4096 x 4096 | 68.2 | 128.6 | 1.9x |
End-to-End Decode: Llama-2 7B Tokens/sec (H100 SXM)
Key insight: for single-token decode, W4A16 (980 tok/s) beats W8A8 (485 tok/s) because decode is bandwidth-bound and W4A16 loads half as much weight data. For large-batch inference, W8A8 INT8 wins because the GEMM is compute-bound and INT8 tensor cores provide 2x throughput.
Implementation in vLLM and TensorRT-LLM
# vLLM W8A8 INT8 integration
# Uses CUTLASS INT8 GEMM kernels with SmoothQuant
# Configuration for SmoothQuant in vLLM:
quantization_config = {
    'method': 'smoothquant',
    'weight_bits': 8,
    'activation_bits': 8,
    'alpha': 0.5,
    'per_token_activation': True,
    'per_channel_weight': True,
    'calibration_dataset': 'c4',
    'calibration_samples': 512,
}

# TensorRT-LLM W8A8 INT8:
# - Uses cuBLAS INT8 GEMM with fused dequantization epilogue
# - Supports per-tensor and per-channel weight scaling
# - Activation quantization fused into the preceding LayerNorm kernel
#
# Key difference: TensorRT-LLM fuses the activation quantization
# into the LayerNorm output, avoiding a separate kernel launch
# for dynamic quantization. This saves ~10us per layer.