大模型推理性能优化实战:从量化、KV Cache 到批处理调度的工程落地指南
做大模型应用,很多团队一开始关注的是“模型能不能跑”,但真正上线后,问题马上变成了另外三个字:跑不动。
常见现象很一致:
- 首 token 很慢,用户觉得“卡”
- 并发一上来,吞吐掉得厉害
- GPU 显存明明很大,却总是不够用
- 同样的模型,别人一张卡能顶你两张卡
这篇文章不讲太多空泛概念,而是从工程落地角度,把三件最影响推理性能的事情串起来:
- 量化:先把模型“瘦身”
- KV Cache:减少重复计算
- 批处理调度:把 GPU 真正喂饱
我会尽量按“为什么、怎么做、怎么验证、怎么排查”的顺序来写。你可以把它当作一份中级工程师可直接照着做的教程。
背景与问题
在大模型推理中,性能瓶颈通常不只来自模型参数量本身,更来自下面几个环节的组合:
- 显存占用过高
- 权重占显存
- KV Cache 随上下文长度增长
- 解码阶段低效
- 生成每个 token 都要做一次前向
- 序列越长,历史信息越多
- 请求调度不合理
- 小 batch 太多,GPU 空转
- 大 batch 又可能导致尾延迟变差
- 精度与速度难平衡
- 全精度效果稳,但成本高
- 低比特量化快,但可能损失质量
很多人会把优化理解成“换个推理框架就好了”,但真实情况是:优化通常是一个系统工程。
你需要同时看:
- 模型权重如何存
- Attention 历史如何复用
- 请求如何拼批
- 长上下文如何管理
- 指标如何度量
如果只做其中一个,往往收益有限;把它们配合起来,收益才明显。
前置知识
建议你对以下内容有基本了解:
- Transformer 的自注意力机制
- PyTorch 基本使用
- Hugging Face Transformers 的模型加载方式
- GPU 显存、吞吐、延迟的基本概念
如果你已经在线上跑过一个文本生成服务,那么这篇文章会更容易对上号。
环境准备
本文示例主要基于 Python,建议环境如下:
- Python 3.10+
- PyTorch 2.0+
- transformers 4.35+
- accelerate
- bitsandbytes(做 8bit/4bit 量化)
- CUDA 可用的 NVIDIA GPU
安装示例:
pip install torch transformers accelerate bitsandbytes
如果你没有 NVIDIA GPU,量化部分的完整收益可能看不出来,但代码结构仍然可以参考。
核心原理
这一部分先把三件核心武器讲透:量化、KV Cache、批处理调度。
1. 量化:用更少 bit 表示权重
默认情况下,模型权重常见是:
- FP32:每个参数 4 字节
- FP16 / BF16:每个参数 2 字节
而量化后可以进一步压缩:
- INT8:每个参数约 1 字节
- INT4:每个参数约 0.5 字节
这会直接带来两类收益:
- 降低显存占用
- 提升带宽效率,从而间接提升推理速度
但量化不是白送的,它通常会带来:
- 精度轻微下降
- 某些层不适合激进量化
- 不同硬件支持差异明显
量化的工程判断
如果你的主要目标是:
- 先跑起来,减少显存压力:优先试 8bit
- 极限压缩,追求更低成本:试 4bit,但一定做质量回归
- 质量要求特别高:保留部分关键层为 FP16/BF16
2. KV Cache:避免重复计算历史 token
大模型生成文本时,通常是自回归过程。
也就是说,生成第 t 个 token 时,要依赖前面 1...t-1 的历史信息。
如果每次都把整段历史重新算一遍,成本会非常高。
因此推理框架通常会缓存每一层 Attention 的 Key/Value,这就是 KV Cache。
它的直觉可以理解成:
- 不再重复做历史 token 的投影
- 只为“新 token”计算增量部分
没有 KV Cache 的情况
每生成一个 token,都重新做整段序列前向,复杂度会很夸张。
有 KV Cache 的情况
每生成一个 token 时:
- 历史 K/V 直接复用
- 只计算当前 token 的 Q、K、V
- 再与历史 K/V 拼起来做 attention
这会显著降低解码阶段开销。
3. 批处理调度:让 GPU 不饿着
很多线上服务慢,不是模型太差,而是调度太粗糙。
典型低效方式:
- 每个请求单独跑
- 请求来了就立刻执行
- 不区分 prefilling 和 decoding
- 长短请求混在一起,拖慢整体
更合理的思路是:
- 把多个请求组成 batch
- 在合适的时间窗口内拼批
- 对生成阶段做动态批处理
- 控制最大 batch token 数,而不是只看 batch size
为什么“token 数”比“请求数”更重要?
因为两个请求看起来都是 batch=2,但实际成本可能完全不同:
- 请求 A:输入 32 token,输出 16 token
- 请求 B:输入 2048 token,输出 512 token
如果只按请求个数调度,你很容易把系统打爆。
实际工程里,更稳定的指标是:
max_batch_sizemax_input_tokensmax_total_tokensmax_batch_total_tokens
一图看懂三种优化的关系
flowchart TD
A[用户请求进入推理服务] --> B[Tokenizer 编码]
B --> C[批处理调度器]
C --> D[Prefill 阶段]
D --> E[生成 KV Cache]
E --> F[Decode 阶段循环]
F --> G[复用 KV Cache]
G --> H[输出 token]
D --> I[量化权重加载]
I --> F
这张图里最关键的点是:
- 量化主要作用在模型权重与显存占用
- KV Cache主要作用在 decode 阶段
- 批处理调度作用在整个请求生命周期
KV Cache 在一次请求中的生命周期
sequenceDiagram
participant U as User
participant S as Scheduler
participant M as Model
participant C as KV Cache
U->>S: 发起生成请求
S->>M: 执行 prefill
M->>C: 写入历史 K/V
loop 每个新 token
S->>M: decode 当前 token
M->>C: 读取历史 K/V
M->>C: 追加新 K/V
M-->>S: 返回 next token
end
S-->>U: 返回完整结果
如果你在线上遇到“首 token 还行,但长文本生成越来越慢”,通常要重点看:
- cache 是否真的被复用
- cache 是否频繁搬移
- batch 中是否混入了特别长的序列
从工程角度理解性能瓶颈
可以把一次生成分成两个阶段:
Prefill 阶段
- 输入整段 prompt
- 并行度较高
- 更像大矩阵计算
- 对吞吐和显存都敏感
Decode 阶段
- 一次只生成一个 token
- 计算粒度更小
- 更容易被调度、内存访问、cache 管理影响
很多人只盯着模型 FLOPs,但线上性能的关键往往是:
- Prefill 的大吞吐
- Decode 的低延迟
- KV Cache 的空间管理
- 动态 batch 的稳定性
实战代码(可运行)
下面我们用 Hugging Face 做一个可运行示例,演示:
- 普通加载
- 8bit 量化加载
- 启用 KV Cache
- 简单批处理生成
- 基础性能测试
说明:示例使用小模型,方便你本地验证。实际生产可替换成更大模型。
示例一:基线推理与 8bit 量化
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "gpt2"
def load_model_baseline():
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
return tokenizer, model, device
def load_model_8bit():
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
load_in_8bit=True,
device_map="auto"
)
model.eval()
return tokenizer, model, "cuda" if torch.cuda.is_available() else "cpu"
def run_generate(tokenizer, model, device, prompts, max_new_tokens=50):
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True
)
if device != "cpu":
inputs = {k: v.to(device) for k, v in inputs.items()}
start = time.perf_counter()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
use_cache=True,
do_sample=False
)
end = time.perf_counter()
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
return texts, end - start
if __name__ == "__main__":
prompts = [
"请用简单的话解释什么是大模型推理优化。",
"什么是KV Cache,它为什么能提升生成速度?"
]
print("=== Baseline ===")
tokenizer, model, device = load_model_baseline()
texts, latency = run_generate(tokenizer, model, device, prompts)
print(f"Latency: {latency:.3f}s")
for i, t in enumerate(texts):
print(f"[{i}] {t}\n")
if torch.cuda.is_available():
print("=== 8bit Quantized ===")
tokenizer, model, device = load_model_8bit()
texts, latency = run_generate(tokenizer, model, device, prompts)
print(f"Latency: {latency:.3f}s")
for i, t in enumerate(texts):
print(f"[{i}] {t}\n")
你该关注什么?
这个例子不是为了比较 gpt2 的绝对性能,而是让你看到一条完整路径:
- 如何切换普通加载与量化加载
- 如何在生成时显式启用
use_cache=True - 如何做最基础的端到端耗时测试
示例二:查看显存占用与吞吐
我自己排查推理问题时,经常先写这种小脚本,因为比盲猜靠谱得多。
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "gpt2"
def gpu_mem_mb():
if not torch.cuda.is_available():
return 0
return torch.cuda.memory_allocated() / 1024 / 1024
def benchmark(batch_size=4, max_new_tokens=32):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
prompts = ["请解释批处理调度的作用。"] * batch_size
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
if device != "cpu":
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
inputs = {k: v.to(device) for k, v in inputs.items()}
start_mem = gpu_mem_mb()
start = time.perf_counter()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
use_cache=True,
do_sample=False
)
if device != "cpu":
torch.cuda.synchronize()
end = time.perf_counter()
end_mem = gpu_mem_mb()
total_tokens = outputs.shape[0] * outputs.shape[1]
elapsed = end - start
tps = total_tokens / elapsed
print(f"Batch size: {batch_size}")
print(f"Elapsed: {elapsed:.3f}s")
print(f"Throughput: {tps:.2f} tokens/s")
print(f"GPU mem delta: {end_mem - start_mem:.2f} MB")
if device != "cpu":
print(f"Peak GPU mem: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f} MB")
if __name__ == "__main__":
for bs in [1, 2, 4, 8]:
benchmark(batch_size=bs)
print("-" * 40)
如何解读结果?
你可以重点看三件事:
- batch size 增长后,
tokens/s是否提升 - 延迟是否恶化过快
- peak memory 是否逼近显存上限
如果吞吐没有变好,通常说明:
- 模型太小,GPU 本来就没吃满
- 请求太短,调度开销比计算还大
- batch 拼得不合理
示例三:一个简化版动态批处理调度器
生产环境通常会用 vLLM、TGI、TensorRT-LLM 等框架,但你理解一个“简化版调度器”会非常有帮助。下面这段代码模拟“按时间窗口收集请求,再统一生成”。
import time
import queue
import threading
from dataclasses import dataclass
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "gpt2"
@dataclass
class Request:
prompt: str
result: str = ""
done: bool = False
class DynamicBatchServer:
def __init__(self, model_name=MODEL_NAME, batch_wait_ms=50, max_batch_size=4):
self.batch_wait_ms = batch_wait_ms
self.max_batch_size = max_batch_size
self.q = queue.Queue()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = AutoModelForCausalLM.from_pretrained(model_name)
self.model.to(self.device)
self.model.eval()
self.worker = threading.Thread(target=self._loop, daemon=True)
self.worker.start()
def submit(self, prompt: str) -> Request:
req = Request(prompt=prompt)
self.q.put(req)
return req
def _collect_batch(self) -> List[Request]:
batch = []
start = time.time()
while len(batch) < self.max_batch_size:
timeout = self.batch_wait_ms / 1000
remain = timeout - (time.time() - start)
if remain <= 0 and batch:
break
try:
req = self.q.get(timeout=max(remain, 0.001))
batch.append(req)
except queue.Empty:
break
return batch
def _loop(self):
while True:
batch = self._collect_batch()
if not batch:
continue
prompts = [r.prompt for r in batch]
inputs = self.tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True
).to(self.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=32,
use_cache=True,
do_sample=False
)
texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
for req, text in zip(batch, texts):
req.result = text
req.done = True
if __name__ == "__main__":
server = DynamicBatchServer(batch_wait_ms=100, max_batch_size=4)
reqs = [
server.submit("请解释什么是模型量化。"),
server.submit("请解释什么是KV Cache。"),
server.submit("请解释为什么批处理可以提升吞吐。"),
]
while not all(r.done for r in reqs):
time.sleep(0.05)
for i, r in enumerate(reqs):
print(f"[{i}] {r.result}\n")
这段代码很简化,但有几个关键工程思想:
- 不立即执行,而是短暂等待拼批
- 限制最大 batch size
- 统一 tokenizer 和 generate 调用
- 请求对象持有结果状态
真正的线上系统当然会复杂很多,还会涉及:
- 流式输出
- cancel
- 超时
- 异构长度请求分桶
- prefill/decode 分离调度
- KV Cache 回收
但理解这里的最小版本,后面看成熟框架就会更顺。
批处理调度的设计思路
这部分给你一个更接近生产的思考框架。
stateDiagram-v2
[*] --> Waiting
Waiting --> Batching: 请求到达
Batching --> Prefill: 达到时间窗或批量上限
Prefill --> Decoding
Decoding --> Decoding: 继续生成
Decoding --> Finished: 所有请求完成
Finished --> Waiting
在真实服务里,调度器最怕两类情况:
情况一:只追求吞吐
结果是:
- batch 很大
- GPU 利用率很好看
- 但单请求尾延迟很差
- 用户体验糟糕
情况二:只追求低延迟
结果是:
- 看到请求就跑
- GPU 吃不满
- 单位成本很高
- 并发稍微上来就崩
因此调度参数一般需要平衡:
batch_wait_msmax_batch_sizemax_batch_total_tokensmax_new_tokens- 长短请求是否分桶
逐步验证清单
我建议你按下面顺序做验证,不要一口气把所有优化全开。否则一旦结果不对,很难定位。
第一步:建立基线
记录以下指标:
- 单请求首 token 延迟
- 单请求全量生成耗时
- tokens/s
- 峰值显存
- 输出质量样例
第二步:只开 KV Cache
检查:
- decode 阶段是否明显加速
- 长输出场景收益是否更明显
第三步:只开量化
检查:
- 显存是否下降
- 吞吐是否提升
- 输出质量是否可接受
第四步:引入动态批处理
检查:
- 并发提升后 tokens/s 是否改善
- P95/P99 延迟是否仍在 SLA 内
第五步:组合优化
最终确认:
- 整体吞吐收益
- 长上下文稳定性
- 显存碎片情况
- 服务是否容易抖动
常见坑与排查
这部分很重要,我踩过不少。
1. 开了 KV Cache,但速度几乎没变
可能原因
- 你的测试文本太短
- 模型太小,收益不明显
- 测的是 prefill,不是 decode
- 框架虽然传了
use_cache=True,但实际路径没走到
排查建议
- 用更长的输出,比如
max_new_tokens=256 - 比较“关闭 cache”和“开启 cache”的 decode 时间
- 检查模型配置里
config.use_cache
print(model.config.use_cache)
2. 量化后反而变慢
这事并不稀奇。
可能原因
- 量化内核与硬件不匹配
- 小模型下,量化收益抵不过额外开销
- CPU/GPU 数据搬移增加
- 某些层 fallback 到低效实现
排查建议
- 换更大的模型看差异
- 确认 CUDA、bitsandbytes 版本匹配
- 用 profiler 看热点是否仍在 matmul
- 比较显存下降是否真实发生
3. batch 一大就 OOM
可能原因
- 只控制了 batch size,没控制 token 总量
- prompt 长度差异过大
- KV Cache 累积过多
- 没有及时释放已结束请求的 cache
排查建议
把调度策略从“按请求数限制”改成“按总 token 限制”。
你可以估算一下 KV Cache 大致占用:
KV Cache 显存 ≈ 层数 × 2(K/V) × batch × seq_len × hidden_size × dtype字节数
这不是精确公式,但很适合做一阶估算。
4. 明明 GPU 利用率不低,用户还是觉得慢
可能原因
- 高利用率来自大 batch,但尾延迟被拖长
- 首 token 时间太高
- 流式输出没做好
- tokenizer 或后处理成为瓶颈
排查建议
不要只看 GPU 利用率,还要看:
- TTFT(Time To First Token)
- TPOT(Time Per Output Token)
- P50 / P95 / P99 延迟
- 非模型耗时占比
5. 长上下文场景性能突然崩掉
可能原因
- KV Cache 爆炸式增长
- RoPE 扩展或长上下文配置不合理
- 显存碎片增多
- 长短请求混跑导致 decode 效率下降
排查建议
- 做长度分桶
- 限制最大上下文长度
- 为超长请求单独队列
- 定期观察 cache 命中、回收与显存峰值
安全/性能最佳实践
这一节我尽量给“能执行的建议”。
1. 先定指标,再做优化
至少要定义:
- TTFT:首 token 延迟
- TPOT:每个输出 token 平均耗时
- Throughput:tokens/s
- Peak Memory:峰值显存
- Quality Regression:质量回归结果
否则你很容易出现“感觉变快了,但其实没有”的错觉。
2. 优先做低风险优化顺序
我通常建议这个顺序:
- 确认
use_cache=True - 启用 FP16/BF16
- 尝试 8bit 量化
- 引入动态批处理
- 再评估 4bit、分页 KV Cache、连续批处理等高级优化
原因很简单:越靠前的改动,收益通常稳定、风险更小。
3. 长短请求分流
不要让一个 8K prompt 的请求和几十 token 的短请求混在同一个队列里。
更好的做法是:
- 短请求队列:追求低延迟
- 长请求队列:追求稳定吞吐
- 超长上下文单独限流
这在真实业务里很有用,尤其是聊天、摘要、RAG 混合场景。
4. 控制的是 token,而不只是 batch size
一个非常实用的原则:
- batch size 是表象
- token 总量才是资源本体
调度器最好同时限制:
- batch 内请求数
- batch 总输入 token
- batch 总生成 token 预算
5. 做好质量回归,不要只看性能
量化尤其容易出现“指标很好,但业务方不认”的情况。
建议至少准备三类回归集:
- 通用问答
- 业务专有术语
- 长上下文推理/总结
如果 4bit 让业务准确率明显下降,那省下来的卡费不一定值得。
6. 处理好缓存生命周期
KV Cache 虽然能加速,但如果管理不好,就是显存炸弹。
建议:
- 请求结束立刻释放相关 cache
- 超时请求强制回收
- 长会话设置上限
- 监控 cache 占用比例与回收延迟
7. 对外暴露安全边界
推理服务不只是性能问题,也要防止资源被打穿。建议接口层做这些限制:
- 最大输入长度
- 最大输出长度
- 最大并发数
- 单租户速率限制
- 超时中断与取消生成
否则一个异常长 prompt,就可能把整机服务拖慢。
一个实用的参数调优建议
如果你要从 0 到 1 调优一个服务,我建议这样试:
| 目标 | 重点参数 | 建议起点 |
|---|---|---|
| 降显存 | load_in_8bit / load_in_4bit | 先 8bit |
| 降 decode 延迟 | use_cache=True | 必开 |
| 提高吞吐 | max_batch_size | 从 4 开始 |
| 控制尾延迟 | batch_wait_ms | 20~100ms 试验 |
| 防 OOM | max_batch_total_tokens | 按显存压测反推 |
| 稳定质量 | do_sample=False 做基准 | 先固定解码策略 |
这个表不是银弹,但很适合作为第一次压测的起点。
方案落地时的取舍建议
如果你正在做选型,下面这个经验判断比较实用:
场景一:小团队,先求稳上线
建议:
- 用成熟推理框架
- 开启 KV Cache
- 先上 8bit 量化
- 做简单动态批处理
- 优先监控 TTFT 和显存
场景二:成本压力大,模型较大
建议:
- 认真评估 4bit
- 对长上下文做严格限额
- 引入 token 级别调度
- 做多轮压测和质量回归
场景三:高并发在线服务
建议:
- 做 continuous batching
- 长短请求分桶
- 监控 P95/P99
- 对取消、超时、cache 回收做专门治理
总结
大模型推理优化,真正有效的通常不是某一个“黑科技开关”,而是三件事协同:
- 量化解决权重显存与带宽问题
- KV Cache解决 decode 重复计算问题
- 批处理调度解决 GPU 利用率与吞吐问题
如果你让我给一个最务实的落地顺序,我会建议:
- 先建立基线指标
- 确认 KV Cache 生效
- 上 8bit 量化看显存收益
- 引入简单动态批处理
- 再根据业务场景优化长上下文、尾延迟和 cache 生命周期
最后提醒一句:
优化不是单纯追求“最快”,而是在质量、成本、延迟、吞吐之间找到适合你业务的平衡点。
很多时候,真正好的方案不是最激进的那个,而是那个上线后一周内不需要天天救火的方案。