大模型推理优化实战:从 KV Cache、量化到批处理吞吐提升的工程方法
做大模型应用时,很多团队第一版系统都能“跑起来”,但很快就会遇到同一类问题:
- 首 token 太慢
- 并发一上来就炸显存
- 吞吐上不去,机器成本居高不下
- 量化后速度没明显提升,反而精度还掉了
- 批处理一开,尾延迟变得不可控
这些问题背后,其实不是某一个“魔法参数”没调好,而是推理链路里多个环节共同决定的:Prefill / Decode 阶段特性、KV Cache 占用、算子精度、批处理策略、调度方式、显存带宽瓶颈。
这篇文章我会用一个更偏工程落地的角度,把几种最常用、最有效的优化方法串起来:
- 理解推理瓶颈到底在哪里
- 用 KV Cache 降低重复计算
- 用量化降低显存与带宽压力
- 用动态批处理提升整体吞吐
- 建立一套可验证、可排查、可上线的优化流程
如果你已经会用 Transformers 或 vLLM 跑模型,这篇内容会帮助你从“能用”走到“用得更省、更稳、更快”。
背景与问题
先统一一个现实认知:大模型推理优化,很多时候不是算力不够,而是显存和带宽不够。
以自回归生成模型为例,每次生成一个新 token,都要读取历史上下文对应的 Key/Value。上下文越长,读取越多。结果就是:
- Prefill 阶段:一次性处理整段输入,计算密集
- Decode 阶段:每次只生成少量 token,但频繁访问 KV Cache,往往更受内存带宽限制
这就导致一个常见现象:
- GPU 利用率看起来不低
- 但 tokens/s 还是上不去
- 一加并发,显存迅速爆掉
一个典型线上症状
假设你部署了一个 7B 模型:
- 输入长度:2048
- 输出长度:256
- 并发请求:16
你可能会看到:
- 第一轮响应时间(TTFT)很高
- 生成阶段单 token 延迟波动明显
- 同一卡上不同请求互相抢占资源
- 长上下文请求把短请求拖慢
这时如果只盯着 nvidia-smi 看显存和利用率,通常是不够的。真正要拆的是:
- 模型权重占多少
- KV Cache 占多少
- 每次 batch 的 token 结构是什么
- 请求混合后,prefill 和 decode 是否互相干扰
- 当前是 compute-bound 还是 memory-bound
前置知识
阅读本文前,建议你至少熟悉这些概念:
- Transformer 自注意力机制
- 自回归生成
- GPU 显存与带宽的基本区别
- FP16 / BF16 / INT8 / INT4 的基本含义
- Python 基础,能运行 Hugging Face 相关代码
不要求你自己写 CUDA,但需要知道:很多优化不是改模型结构,而是改“怎么喂给 GPU”。
环境准备
下面的示例以 Python 为主,建议准备:
- Python 3.10+
- PyTorch 2.2+
- transformers 4.44+
- accelerate
- bitsandbytes
- vllm(用于高吞吐部署实验)
- 一张支持 CUDA 的 NVIDIA GPU
安装示例:
pip install torch transformers accelerate bitsandbytes vllm
如果你在本地卡上实验,优先选一个体量适中的指令模型,避免环境问题掩盖核心结论。
核心原理
这一节我们先把三件事讲透:KV Cache、量化、批处理。理解它们之间的关系,比记住几个命令更重要。
1. 推理阶段拆解:Prefill 与 Decode
自回归生成大致分两段:
-
Prefill
- 把整段输入一次性过模型
- 生成第一批 attention 所需的 K/V
- 通常计算量大,但并行度高
-
Decode
- 每次基于已有上下文再生成一个 token
- 使用缓存的 K/V,避免重复计算全部历史 token
- 每步计算小,但访问缓存频繁
可以把它理解成:
- Prefill 像“把书先翻一遍做索引”
- Decode 像“后面每问一个问题,就拿着索引快速查”
如果没有 KV Cache,decode 时每个新 token 都要把整个历史重新算一遍,复杂度会急剧变差。
flowchart LR
A[输入 Prompt] --> B[Prefill: 全量计算注意力]
B --> C[生成 KV Cache]
C --> D[Decode Step 1]
D --> E[复用历史 KV]
E --> F[Decode Step 2]
F --> G[持续生成直到结束]
2. KV Cache 为什么重要
在 Transformer 的每一层里,历史 token 会形成对应的 Key 和 Value。生成下一个 token 时,Query 只需要与历史 K/V 做注意力计算。
所以 KV Cache 的本质是:
- 用显存换算力
- 避免重复计算历史上下文的 K/V
- 极大降低 decode 的重复开销
但它的副作用也很明显:
- 上下文越长,KV Cache 越大
- batch 越大,KV Cache 成比例增长
- 长对话、多轮会话、流式输出会持续占用显存
一个常用的近似估算方式:
KV Cache 大小 ≈ 2 × 层数 × hidden_size × 序列长度 × batch_size × bytes_per_element
这里的 2 表示 K 和 V 两份缓存。
工程上你会发现: 模型权重不一定是最大头,长上下文 + 多并发时,KV Cache 常常才是显存杀手。
3. 量化为什么有时“省了显存,但没快多少”
很多同学第一次上量化,预期是:
- 显存减半
- 吞吐翻倍
现实通常没这么理想。原因在于:
- 权重量化主要减少模型参数存储与访存
- KV Cache 不一定同步量化
- decode 阶段往往是内存带宽瓶颈,不完全是算术瓶颈
- 某些量化实现存在反量化开销
- 某些硬件对低比特算子支持不充分
所以量化的收益要分开看:
- 权重量化:降低模型加载显存、提升可部署模型尺寸
- 激活/Cache 量化:进一步压缩运行中内存
- 端到端速度提升:取决于框架、算子融合、硬件支持、batch 结构
4. 批处理为什么能提吞吐,但也会拉高延迟
批处理的核心思想是:把多个请求凑成一个 batch,让 GPU 吃得更饱。
好处:
- 提高设备利用率
- 降低单请求摊销成本
- 提升整体 tokens/s
代价:
- 请求需要等待凑批
- 长短请求混合会产生 padding 或调度不均
- 尾延迟可能上升
工程里最常用的不是“固定 batch”,而是:
- 动态批处理
- 按 token 数而不是按请求数控制 batch
- 将 prefill 与 decode 分开调度
sequenceDiagram
participant U1 as 请求1
participant U2 as 请求2
participant S as 调度器
participant G as GPU
U1->>S: 到达
U2->>S: 到达
S->>G: 组成动态批次进行 Prefill
G-->>S: 返回首 token 所需状态
S->>G: 组成 Decode 批次
G-->>S: 连续生成 token
S-->>U1: 返回结果流
S-->>U2: 返回结果流
从工程视角看三类瓶颈
在真正优化前,我建议先判断自己主要卡在哪一类。
1. 显存瓶颈
表现:
- 模型刚加载就接近满显存
- 并发稍高就 OOM
- 上下文一变长就崩
优先手段:
- 权重量化
- 限制最大上下文长度
- 限制同时活跃 session
- 使用支持更高效 KV 管理的推理框架
2. 带宽瓶颈
表现:
- decode 阶段 tokens/s 很低
- GPU 算力没有完全打满
- 长上下文性能恶化明显
优先手段:
- 降低 KV Cache 占用
- 使用 paged attention / 更高效 cache 管理
- 提升 batch 结构质量
- 尽量减少无意义长上下文
3. 调度瓶颈
表现:
- 单请求性能还行,但一到线上就抖
- 吞吐不稳定
- 尾延迟非常难看
优先手段:
- 动态批处理
- prefill / decode 分离
- 请求分类队列
- 限流、超时和最大输出长度控制
实战代码(可运行)
这一节我给出两个层次的示例:
- 用 Transformers 演示 KV Cache 与量化的基本用法
- 用 vLLM 演示更贴近生产的高吞吐推理
说明:代码可运行,但你需要根据本机显卡和模型权限,替换成合适模型。
实战一:用 Transformers 验证 KV Cache 的效果
示例目标
- 对比
use_cache=True/False - 测试生成耗时
- 观察长上下文下的差异
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=dtype,
device_map="auto",
trust_remote_code=True
)
model.eval()
prompt = "请用通俗语言解释什么是 KV Cache,并举一个生活中的类比。" * 20
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
def benchmark(use_cache: bool, max_new_tokens: int = 64, warmup: int = 1, runs: int = 3):
for _ in range(warmup):
_ = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=use_cache
)
if device == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(runs):
start = time.time()
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=use_cache
)
if device == "cuda":
torch.cuda.synchronize()
times.append(time.time() - start)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return sum(times) / len(times), text
avg_cache, _ = benchmark(True)
avg_no_cache, _ = benchmark(False)
print(f"use_cache=True 平均耗时: {avg_cache:.3f}s")
print(f"use_cache=False 平均耗时: {avg_no_cache:.3f}s")
print(f"加速比: {avg_no_cache / avg_cache:.2f}x")
你应该关注什么
不要只看“总耗时”,最好再看:
- 输入 token 数
- 输出 token 数
- TTFT
- 平均 decode token latency
如果只是粗暴比较总耗时,容易把 prefill 和 decode 混在一起,结论会失真。
实战二:4bit 量化加载模型
如果你的 GPU 显存比较紧,4bit 是很常见的第一步。下面用 bitsandbytes 做一个简单示例。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
prompt = "请总结量化推理的优点、代价,以及适用场景。"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=128,
do_sample=False,
use_cache=True
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
这段代码的实际意义
它解决的是:
- 模型装得下
- 单卡可部署更大模型
- 显存留更多给 KV Cache 和 batch
它不保证一定“显著更快”。如果你上线前要做选择,我建议至少对比这三组:
- FP16/BF16
- INT8
- INT4
并分别观察:
- 首 token 延迟
- decode tokens/s
- 显存占用
- 任务精度损失
实战三:计算 KV Cache 的近似占用
很多问题不是“感觉显存不够”,而是没有提前估算。下面给一个简单估算函数。
def estimate_kv_cache_gb(
batch_size: int,
seq_len: int,
num_layers: int,
hidden_size: int,
bytes_per_elem: int = 2
) -> float:
total_bytes = 2 * batch_size * seq_len * num_layers * hidden_size * bytes_per_elem
return total_bytes / (1024 ** 3)
examples = [
{"batch_size": 1, "seq_len": 4096, "num_layers": 32, "hidden_size": 4096},
{"batch_size": 8, "seq_len": 4096, "num_layers": 32, "hidden_size": 4096},
{"batch_size": 16, "seq_len": 8192, "num_layers": 32, "hidden_size": 4096},
]
for x in examples:
gb = estimate_kv_cache_gb(**x)
print(f"{x} => 约 {gb:.2f} GB")
怎么用这个估算结果
当你准备把并发从 8 拉到 32 时,先别直接上线。先问自己:
- 权重已经占多少显存?
- 还要预留多少给激活、碎片、框架开销?
- 峰值上下文长度到底是多少?
- 多轮会话会不会让 cache 长时间不释放?
我自己做容量评估时,通常不会把显存算到 95% 才放心,而是会留出更保守的安全边界。
实战四:用 vLLM 进行高吞吐推理
如果你要更贴近生产场景,vLLM 是非常值得试的。它的价值不只是“换个 API”,而是它在 KV 管理、调度、动态批处理 方面做了很多工程优化。
启动服务
python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen2-7B-Instruct \
--dtype float16 \
--max-model-len 4096 \
--gpu-memory-utilization 0.9
用 OpenAI 兼容接口调用
from openai import OpenAI
client = OpenAI(
api_key="EMPTY",
base_url="http://127.0.0.1:8000/v1"
)
resp = client.chat.completions.create(
model="Qwen/Qwen2-7B-Instruct",
messages=[
{"role": "system", "content": "你是一个简洁的技术助手。"},
{"role": "user", "content": "解释 KV Cache 如何影响大模型推理吞吐。"}
],
temperature=0.0,
max_tokens=128
)
print(resp.choices[0].message.content)
压测脚本示例
下面给一个简单并发压测脚本,用来粗看吞吐。
import time
import threading
from openai import OpenAI
client = OpenAI(
api_key="EMPTY",
base_url="http://127.0.0.1:8000/v1"
)
N = 16
results = []
def worker(i):
start = time.time()
resp = client.chat.completions.create(
model="Qwen/Qwen2-7B-Instruct",
messages=[
{"role": "user", "content": f"请用 100 字解释批处理对推理吞吐的影响。请求编号{i}"}
],
temperature=0.0,
max_tokens=128
)
elapsed = time.time() - start
results.append((i, elapsed, len(resp.choices[0].message.content)))
threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
start_all = time.time()
for t in threads:
t.start()
for t in threads:
t.join()
total = time.time() - start_all
for item in sorted(results):
print(f"req={item[0]}, latency={item[1]:.3f}s, chars={item[2]}")
print(f"total time={total:.3f}s, qps≈{N/total:.2f}")
为什么这比单纯 Transformers 更适合服务化
因为服务化问题的关键不只是“生成”,而是:
- 多请求如何共享 GPU
- 不同长度请求如何调度
- KV Cache 如何分页管理
- 如何降低内存碎片
- 如何让吞吐和延迟达到可接受平衡
flowchart TD
A[请求进入] --> B[按到达时间进入队列]
B --> C[动态批处理]
C --> D[Prefill 调度]
D --> E[KV Cache 分配/分页管理]
E --> F[Decode 调度]
F --> G[流式返回]
E --> H[请求结束后回收缓存]
逐步验证清单
我建议不要一上来就做“大而全优化”,而是按下面顺序逐步验证。
第一步:建立基线
记录最基础指标:
- 模型版本
- 精度类型
- 最大上下文
- 单请求 TTFT
- 单请求 tokens/s
- 并发 1/4/8/16 下吞吐与 P95 延迟
- 显存占用峰值
第二步:只开 KV Cache
验证:
- decode latency 是否明显下降
- 显存上升是否可接受
第三步:只开量化
验证:
- 模型是否更容易装下
- 精度退化是否在业务容忍范围内
- decode 吞吐是否真的提升,而不是“只有显存下降”
第四步:引入动态批处理
验证:
- 总吞吐是否上升
- P95/P99 是否恶化
- 长短请求混合时是否出现明显抖动
第五步:组合优化
验证组合效果,而不是单项最佳:
- KV Cache + 量化
- KV Cache + 动态批处理
- 量化 + 更长上下文
- 全部叠加后的稳定性
常见坑与排查
这部分非常重要。很多优化失败,不是原理不对,而是踩了实现细节。
坑 1:量化后显存降了,但速度没提升
现象:
- 4bit 能跑更大模型
- 但 tokens/s 并没有明显变快
原因可能是:
- 当前瓶颈在 KV Cache 或内存带宽,不在权重访存
- 低比特算子没有被硬件高效支持
- 反量化开销抵消了一部分收益
- batch 太小,吞吐提升不明显
排查建议:
- 对比不同 batch、不同上下文长度下的性能
- 记录 prefill 和 decode 各自耗时
- 检查框架是否真的启用了高效量化内核
坑 2:开了 batch,P95 延迟变差很多
现象:
- 总体 tokens/s 上升
- 但用户抱怨“有时特别慢”
常见原因:
- 凑批等待时间过长
- 长请求拖住短请求
- prefill 与 decode 混跑互相干扰
排查建议:
- 给请求按输入长度分桶
- 限制单批最大 token 数
- 将极长请求放到单独队列
- 监控队列等待时间,而不是只看模型计算时间
坑 3:明明设置了最大上下文,还是 OOM
现象:
- 理论上显存够
- 实际运行却偶发 OOM
原因可能是:
- 多轮会话缓存未及时释放
- CUDA 内存碎片
- 框架额外缓存和中间张量
- 突发并发使活跃 token 数超预期
排查建议:
- 监控活跃 session 数与活跃 token 数
- 主动设置更保守的
max_model_len - 定期观察
torch.cuda.memory_summary() - 对长会话做超时回收或截断策略
坑 4:首 token 很慢,但后续输出还行
现象:
- TTFT 高
- decode token latency 正常
通常说明:
- prefill 成本太高
- 输入 prompt 太长
- 检索拼接上下文过多
- 系统 prompt 设计臃肿
解决思路:
- 精简系统提示词
- 对 RAG 上下文做截断与重排
- 缓存公共前缀 prompt
- 复用 prefix cache(如果框架支持)
坑 5:同样的模型,离线压测和线上表现差很多
原因常见于:
- 线上请求长度分布更极端
- 用户输出长度不可控
- 多租户流量导致负载抖动
- 流式输出、网络开销、日志采样影响端到端延迟
建议:
- 用真实流量分布回放
- 不要只做“固定 512 输入 + 128 输出”的理想压测
- 线上线下统一指标口径
stateDiagram-v2
[*] --> 基线测试
基线测试 --> 单项优化验证
单项优化验证 --> 组合优化
组合优化 --> 压测观察
压测观察 --> 上线灰度
上线灰度 --> 稳定运行
压测观察 --> 回退排查
上线灰度 --> 回退排查
回退排查 --> 单项优化验证
安全/性能最佳实践
这一节我把比较实用的建议收敛成清单,适合上线前逐项核对。
1. 对输出长度设置硬上限
这是最简单、最有效的稳态控制手段之一。
建议控制:
max_new_tokens- 每租户并发上限
- 总活跃请求数
- 总活跃 token 数
否则某些“超长输出”请求会把整个服务拖垮。
2. 对输入长度做分层治理
不要让所有请求都走同一条路径。可以分为:
- 短请求:实时交互优先
- 中等请求:正常动态批处理
- 超长请求:低优先级或单独队列
这样可以显著减少长尾延迟污染。
3. 优先监控 token 级指标,而不是只看 QPS
大模型服务里,QPS 很容易误导。真正更有效的是:
- TTFT
- TPOT(time per output token)
- 输入 token 总量
- 输出 token 总量
- 活跃 KV token 数
- P50/P95/P99 延迟
- OOM 次数
- 请求取消率
4. 为量化保留回退方案
量化收益很大,但并不总是稳定。上线时最好保留:
- FP16/BF16 版本
- INT8 版本
- INT4 版本
当你发现:
- 特定任务精度下降明显
- 特定 GPU 上性能异常
- 某版本驱动/内核不兼容
就能快速切回。
5. 谨慎对待“极限显存利用率”
我个人不建议把 gpu-memory-utilization 或显存占用打得太满。因为线上环境会有:
- 流量抖动
- 请求长度抖动
- 内存碎片
- 框架缓存波动
保守一些,往往比“纸面吞吐更高”更值钱。
6. 把优化目标拆成两个,而不是一个
常见误区是只追一个数字,比如“tokens/s 越高越好”。
实际上更合理的是同时设两个目标:
- 用户体验目标:TTFT、P95 latency
- 资源效率目标:总 tokens/s、单卡成本
因为很多优化能提升吞吐,但会明显伤害交互体验。你的系统是 API 服务还是离线批处理,取舍会不同。
7. 对公共前缀做缓存
如果你的业务有稳定的系统提示词、模板化上下文、固定工具定义,那么前缀缓存非常值得做。
适用场景:
- 统一 system prompt
- 多轮对话带相同角色设定
- RAG 前有固定指令模板
- 代码助手中常见规则模板
这类优化对 TTFT 往往很有帮助。
一个可落地的优化决策顺序
如果你问我:“工程上最实用的顺序是什么?”我通常会这么做:
-
先测基线
- 不然你根本不知道优化值不值
-
先上 KV Cache
- 这是 decode 提速的基本盘
-
再做权重量化
- 优先解决“装不下”和“显存太紧”的问题
-
再做动态批处理
- 重点提升整体吞吐
-
最后调队列、限流、请求分桶
- 这是把线上波动压下来的关键
也就是说,先解决“能稳定跑”,再解决“跑得更省”。很多团队反过来,一开始就盯着极限吞吐,最后反而被稳定性拖住。
总结
把大模型推理优化落到工程上,可以抓住一句话:
KV Cache 解决重复计算,量化解决显存压力,批处理解决吞吐效率,而真正的上线效果取决于调度与边界控制。
你可以把本文的重点记成这几条:
- KV Cache 是 decode 提速核心,但会显著占用显存
- 量化首先解决“装得下”,其次才是“跑更快”
- 批处理能提升吞吐,但不加约束会伤害尾延迟
- 线上优化要按 token 和阶段拆指标,不要只看总耗时
- prefill、decode、KV 管理、调度策略必须一起看
如果你正在做第一版优化,我建议从下面这份最小闭环开始:
- 建立单请求与并发基线
- 开启 KV Cache,对比 decode 性能
- 尝试 4bit/8bit 量化,对比显存与精度
- 用 vLLM 做动态批处理压测
- 加上最大输入、最大输出、活跃 token 上限
- 用真实流量分布复测 P95/P99
这样走下来,你基本就能从“模型能跑”提升到“服务可用、成本可控”。
如果只给一句最终建议,那就是:
别把推理优化当成单点技术,而要把它当成“显存、带宽、调度、请求分布”共同作用的系统工程。