跳转到内容
123xiao | 无名键客

《大模型推理优化实战:从 KV Cache、量化到批处理吞吐提升的工程方法》

字数: 0 阅读时长: 1 分钟

大模型推理优化实战:从 KV Cache、量化到批处理吞吐提升的工程方法

做大模型应用时,很多团队第一版系统都能“跑起来”,但很快就会遇到同一类问题:

  • 首 token 太慢
  • 并发一上来就炸显存
  • 吞吐上不去,机器成本居高不下
  • 量化后速度没明显提升,反而精度还掉了
  • 批处理一开,尾延迟变得不可控

这些问题背后,其实不是某一个“魔法参数”没调好,而是推理链路里多个环节共同决定的:Prefill / Decode 阶段特性、KV Cache 占用、算子精度、批处理策略、调度方式、显存带宽瓶颈

这篇文章我会用一个更偏工程落地的角度,把几种最常用、最有效的优化方法串起来:

  1. 理解推理瓶颈到底在哪里
  2. 用 KV Cache 降低重复计算
  3. 用量化降低显存与带宽压力
  4. 用动态批处理提升整体吞吐
  5. 建立一套可验证、可排查、可上线的优化流程

如果你已经会用 Transformers 或 vLLM 跑模型,这篇内容会帮助你从“能用”走到“用得更省、更稳、更快”。


背景与问题

先统一一个现实认知:大模型推理优化,很多时候不是算力不够,而是显存和带宽不够

以自回归生成模型为例,每次生成一个新 token,都要读取历史上下文对应的 Key/Value。上下文越长,读取越多。结果就是:

  • Prefill 阶段:一次性处理整段输入,计算密集
  • Decode 阶段:每次只生成少量 token,但频繁访问 KV Cache,往往更受内存带宽限制

这就导致一个常见现象:

  • GPU 利用率看起来不低
  • 但 tokens/s 还是上不去
  • 一加并发,显存迅速爆掉

一个典型线上症状

假设你部署了一个 7B 模型:

  • 输入长度:2048
  • 输出长度:256
  • 并发请求:16

你可能会看到:

  • 第一轮响应时间(TTFT)很高
  • 生成阶段单 token 延迟波动明显
  • 同一卡上不同请求互相抢占资源
  • 长上下文请求把短请求拖慢

这时如果只盯着 nvidia-smi 看显存和利用率,通常是不够的。真正要拆的是:

  • 模型权重占多少
  • KV Cache 占多少
  • 每次 batch 的 token 结构是什么
  • 请求混合后,prefill 和 decode 是否互相干扰
  • 当前是 compute-bound 还是 memory-bound

前置知识

阅读本文前,建议你至少熟悉这些概念:

  • Transformer 自注意力机制
  • 自回归生成
  • GPU 显存与带宽的基本区别
  • FP16 / BF16 / INT8 / INT4 的基本含义
  • Python 基础,能运行 Hugging Face 相关代码

不要求你自己写 CUDA,但需要知道:很多优化不是改模型结构,而是改“怎么喂给 GPU”


环境准备

下面的示例以 Python 为主,建议准备:

  • Python 3.10+
  • PyTorch 2.2+
  • transformers 4.44+
  • accelerate
  • bitsandbytes
  • vllm(用于高吞吐部署实验)
  • 一张支持 CUDA 的 NVIDIA GPU

安装示例:

pip install torch transformers accelerate bitsandbytes vllm

如果你在本地卡上实验,优先选一个体量适中的指令模型,避免环境问题掩盖核心结论。


核心原理

这一节我们先把三件事讲透:KV Cache、量化、批处理。理解它们之间的关系,比记住几个命令更重要。

1. 推理阶段拆解:Prefill 与 Decode

自回归生成大致分两段:

  1. Prefill

    • 把整段输入一次性过模型
    • 生成第一批 attention 所需的 K/V
    • 通常计算量大,但并行度高
  2. Decode

    • 每次基于已有上下文再生成一个 token
    • 使用缓存的 K/V,避免重复计算全部历史 token
    • 每步计算小,但访问缓存频繁

可以把它理解成:

  • Prefill 像“把书先翻一遍做索引”
  • Decode 像“后面每问一个问题,就拿着索引快速查”

如果没有 KV Cache,decode 时每个新 token 都要把整个历史重新算一遍,复杂度会急剧变差。

flowchart LR
    A[输入 Prompt] --> B[Prefill: 全量计算注意力]
    B --> C[生成 KV Cache]
    C --> D[Decode Step 1]
    D --> E[复用历史 KV]
    E --> F[Decode Step 2]
    F --> G[持续生成直到结束]

2. KV Cache 为什么重要

在 Transformer 的每一层里,历史 token 会形成对应的 Key 和 Value。生成下一个 token 时,Query 只需要与历史 K/V 做注意力计算。

所以 KV Cache 的本质是:

  • 用显存换算力
  • 避免重复计算历史上下文的 K/V
  • 极大降低 decode 的重复开销

但它的副作用也很明显:

  • 上下文越长,KV Cache 越大
  • batch 越大,KV Cache 成比例增长
  • 长对话、多轮会话、流式输出会持续占用显存

一个常用的近似估算方式:

KV Cache 大小 ≈ 2 × 层数 × hidden_size × 序列长度 × batch_size × bytes_per_element

这里的 2 表示 K 和 V 两份缓存。

工程上你会发现: 模型权重不一定是最大头,长上下文 + 多并发时,KV Cache 常常才是显存杀手

3. 量化为什么有时“省了显存,但没快多少”

很多同学第一次上量化,预期是:

  • 显存减半
  • 吞吐翻倍

现实通常没这么理想。原因在于:

  • 权重量化主要减少模型参数存储与访存
  • KV Cache 不一定同步量化
  • decode 阶段往往是内存带宽瓶颈,不完全是算术瓶颈
  • 某些量化实现存在反量化开销
  • 某些硬件对低比特算子支持不充分

所以量化的收益要分开看:

  • 权重量化:降低模型加载显存、提升可部署模型尺寸
  • 激活/Cache 量化:进一步压缩运行中内存
  • 端到端速度提升:取决于框架、算子融合、硬件支持、batch 结构

4. 批处理为什么能提吞吐,但也会拉高延迟

批处理的核心思想是:把多个请求凑成一个 batch,让 GPU 吃得更饱

好处:

  • 提高设备利用率
  • 降低单请求摊销成本
  • 提升整体 tokens/s

代价:

  • 请求需要等待凑批
  • 长短请求混合会产生 padding 或调度不均
  • 尾延迟可能上升

工程里最常用的不是“固定 batch”,而是:

  • 动态批处理
  • 按 token 数而不是按请求数控制 batch
  • 将 prefill 与 decode 分开调度
sequenceDiagram
    participant U1 as 请求1
    participant U2 as 请求2
    participant S as 调度器
    participant G as GPU

    U1->>S: 到达
    U2->>S: 到达
    S->>G: 组成动态批次进行 Prefill
    G-->>S: 返回首 token 所需状态
    S->>G: 组成 Decode 批次
    G-->>S: 连续生成 token
    S-->>U1: 返回结果流
    S-->>U2: 返回结果流

从工程视角看三类瓶颈

在真正优化前,我建议先判断自己主要卡在哪一类。

1. 显存瓶颈

表现:

  • 模型刚加载就接近满显存
  • 并发稍高就 OOM
  • 上下文一变长就崩

优先手段:

  • 权重量化
  • 限制最大上下文长度
  • 限制同时活跃 session
  • 使用支持更高效 KV 管理的推理框架

2. 带宽瓶颈

表现:

  • decode 阶段 tokens/s 很低
  • GPU 算力没有完全打满
  • 长上下文性能恶化明显

优先手段:

  • 降低 KV Cache 占用
  • 使用 paged attention / 更高效 cache 管理
  • 提升 batch 结构质量
  • 尽量减少无意义长上下文

3. 调度瓶颈

表现:

  • 单请求性能还行,但一到线上就抖
  • 吞吐不稳定
  • 尾延迟非常难看

优先手段:

  • 动态批处理
  • prefill / decode 分离
  • 请求分类队列
  • 限流、超时和最大输出长度控制

实战代码(可运行)

这一节我给出两个层次的示例:

  1. 用 Transformers 演示 KV Cache 与量化的基本用法
  2. 用 vLLM 演示更贴近生产的高吞吐推理

说明:代码可运行,但你需要根据本机显卡和模型权限,替换成合适模型。


实战一:用 Transformers 验证 KV Cache 的效果

示例目标

  • 对比 use_cache=True/False
  • 测试生成耗时
  • 观察长上下文下的差异
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,
    device_map="auto",
    trust_remote_code=True
)
model.eval()

prompt = "请用通俗语言解释什么是 KV Cache,并举一个生活中的类比。" * 20
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

def benchmark(use_cache: bool, max_new_tokens: int = 64, warmup: int = 1, runs: int = 3):
    for _ in range(warmup):
        _ = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=use_cache
        )
        if device == "cuda":
            torch.cuda.synchronize()

    times = []
    for _ in range(runs):
        start = time.time()
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=use_cache
        )
        if device == "cuda":
            torch.cuda.synchronize()
        times.append(time.time() - start)

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sum(times) / len(times), text

avg_cache, _ = benchmark(True)
avg_no_cache, _ = benchmark(False)

print(f"use_cache=True  平均耗时: {avg_cache:.3f}s")
print(f"use_cache=False 平均耗时: {avg_no_cache:.3f}s")
print(f"加速比: {avg_no_cache / avg_cache:.2f}x")

你应该关注什么

不要只看“总耗时”,最好再看:

  • 输入 token 数
  • 输出 token 数
  • TTFT
  • 平均 decode token latency

如果只是粗暴比较总耗时,容易把 prefill 和 decode 混在一起,结论会失真。


实战二:4bit 量化加载模型

如果你的 GPU 显存比较紧,4bit 是很常见的第一步。下面用 bitsandbytes 做一个简单示例。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_NAME = "Qwen/Qwen2-7B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

prompt = "请总结量化推理的优点、代价,以及适用场景。"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    do_sample=False,
    use_cache=True
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

这段代码的实际意义

它解决的是:

  • 模型装得下
  • 单卡可部署更大模型
  • 显存留更多给 KV Cache 和 batch

它不保证一定“显著更快”。如果你上线前要做选择,我建议至少对比这三组:

  • FP16/BF16
  • INT8
  • INT4

并分别观察:

  • 首 token 延迟
  • decode tokens/s
  • 显存占用
  • 任务精度损失

实战三:计算 KV Cache 的近似占用

很多问题不是“感觉显存不够”,而是没有提前估算。下面给一个简单估算函数。

def estimate_kv_cache_gb(
    batch_size: int,
    seq_len: int,
    num_layers: int,
    hidden_size: int,
    bytes_per_elem: int = 2
) -> float:
    total_bytes = 2 * batch_size * seq_len * num_layers * hidden_size * bytes_per_elem
    return total_bytes / (1024 ** 3)

examples = [
    {"batch_size": 1, "seq_len": 4096, "num_layers": 32, "hidden_size": 4096},
    {"batch_size": 8, "seq_len": 4096, "num_layers": 32, "hidden_size": 4096},
    {"batch_size": 16, "seq_len": 8192, "num_layers": 32, "hidden_size": 4096},
]

for x in examples:
    gb = estimate_kv_cache_gb(**x)
    print(f"{x} => 约 {gb:.2f} GB")

怎么用这个估算结果

当你准备把并发从 8 拉到 32 时,先别直接上线。先问自己:

  • 权重已经占多少显存?
  • 还要预留多少给激活、碎片、框架开销?
  • 峰值上下文长度到底是多少?
  • 多轮会话会不会让 cache 长时间不释放?

我自己做容量评估时,通常不会把显存算到 95% 才放心,而是会留出更保守的安全边界。


实战四:用 vLLM 进行高吞吐推理

如果你要更贴近生产场景,vLLM 是非常值得试的。它的价值不只是“换个 API”,而是它在 KV 管理、调度、动态批处理 方面做了很多工程优化。

启动服务

python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen2-7B-Instruct \
  --dtype float16 \
  --max-model-len 4096 \
  --gpu-memory-utilization 0.9

用 OpenAI 兼容接口调用

from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:8000/v1"
)

resp = client.chat.completions.create(
    model="Qwen/Qwen2-7B-Instruct",
    messages=[
        {"role": "system", "content": "你是一个简洁的技术助手。"},
        {"role": "user", "content": "解释 KV Cache 如何影响大模型推理吞吐。"}
    ],
    temperature=0.0,
    max_tokens=128
)

print(resp.choices[0].message.content)

压测脚本示例

下面给一个简单并发压测脚本,用来粗看吞吐。

import time
import threading
from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",
    base_url="http://127.0.0.1:8000/v1"
)

N = 16
results = []

def worker(i):
    start = time.time()
    resp = client.chat.completions.create(
        model="Qwen/Qwen2-7B-Instruct",
        messages=[
            {"role": "user", "content": f"请用 100 字解释批处理对推理吞吐的影响。请求编号{i}"}
        ],
        temperature=0.0,
        max_tokens=128
    )
    elapsed = time.time() - start
    results.append((i, elapsed, len(resp.choices[0].message.content)))

threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
start_all = time.time()

for t in threads:
    t.start()
for t in threads:
    t.join()

total = time.time() - start_all

for item in sorted(results):
    print(f"req={item[0]}, latency={item[1]:.3f}s, chars={item[2]}")
print(f"total time={total:.3f}s, qps≈{N/total:.2f}")

为什么这比单纯 Transformers 更适合服务化

因为服务化问题的关键不只是“生成”,而是:

  • 多请求如何共享 GPU
  • 不同长度请求如何调度
  • KV Cache 如何分页管理
  • 如何降低内存碎片
  • 如何让吞吐和延迟达到可接受平衡
flowchart TD
    A[请求进入] --> B[按到达时间进入队列]
    B --> C[动态批处理]
    C --> D[Prefill 调度]
    D --> E[KV Cache 分配/分页管理]
    E --> F[Decode 调度]
    F --> G[流式返回]
    E --> H[请求结束后回收缓存]

逐步验证清单

我建议不要一上来就做“大而全优化”,而是按下面顺序逐步验证。

第一步:建立基线

记录最基础指标:

  • 模型版本
  • 精度类型
  • 最大上下文
  • 单请求 TTFT
  • 单请求 tokens/s
  • 并发 1/4/8/16 下吞吐与 P95 延迟
  • 显存占用峰值

第二步:只开 KV Cache

验证:

  • decode latency 是否明显下降
  • 显存上升是否可接受

第三步:只开量化

验证:

  • 模型是否更容易装下
  • 精度退化是否在业务容忍范围内
  • decode 吞吐是否真的提升,而不是“只有显存下降”

第四步:引入动态批处理

验证:

  • 总吞吐是否上升
  • P95/P99 是否恶化
  • 长短请求混合时是否出现明显抖动

第五步:组合优化

验证组合效果,而不是单项最佳:

  • KV Cache + 量化
  • KV Cache + 动态批处理
  • 量化 + 更长上下文
  • 全部叠加后的稳定性

常见坑与排查

这部分非常重要。很多优化失败,不是原理不对,而是踩了实现细节。

坑 1:量化后显存降了,但速度没提升

现象:

  • 4bit 能跑更大模型
  • 但 tokens/s 并没有明显变快

原因可能是:

  • 当前瓶颈在 KV Cache 或内存带宽,不在权重访存
  • 低比特算子没有被硬件高效支持
  • 反量化开销抵消了一部分收益
  • batch 太小,吞吐提升不明显

排查建议:

  • 对比不同 batch、不同上下文长度下的性能
  • 记录 prefill 和 decode 各自耗时
  • 检查框架是否真的启用了高效量化内核

坑 2:开了 batch,P95 延迟变差很多

现象:

  • 总体 tokens/s 上升
  • 但用户抱怨“有时特别慢”

常见原因:

  • 凑批等待时间过长
  • 长请求拖住短请求
  • prefill 与 decode 混跑互相干扰

排查建议:

  • 给请求按输入长度分桶
  • 限制单批最大 token 数
  • 将极长请求放到单独队列
  • 监控队列等待时间,而不是只看模型计算时间

坑 3:明明设置了最大上下文,还是 OOM

现象:

  • 理论上显存够
  • 实际运行却偶发 OOM

原因可能是:

  • 多轮会话缓存未及时释放
  • CUDA 内存碎片
  • 框架额外缓存和中间张量
  • 突发并发使活跃 token 数超预期

排查建议:

  • 监控活跃 session 数与活跃 token 数
  • 主动设置更保守的 max_model_len
  • 定期观察 torch.cuda.memory_summary()
  • 对长会话做超时回收或截断策略

坑 4:首 token 很慢,但后续输出还行

现象:

  • TTFT 高
  • decode token latency 正常

通常说明:

  • prefill 成本太高
  • 输入 prompt 太长
  • 检索拼接上下文过多
  • 系统 prompt 设计臃肿

解决思路:

  • 精简系统提示词
  • 对 RAG 上下文做截断与重排
  • 缓存公共前缀 prompt
  • 复用 prefix cache(如果框架支持)

坑 5:同样的模型,离线压测和线上表现差很多

原因常见于:

  • 线上请求长度分布更极端
  • 用户输出长度不可控
  • 多租户流量导致负载抖动
  • 流式输出、网络开销、日志采样影响端到端延迟

建议:

  • 用真实流量分布回放
  • 不要只做“固定 512 输入 + 128 输出”的理想压测
  • 线上线下统一指标口径
stateDiagram-v2
    [*] --> 基线测试
    基线测试 --> 单项优化验证
    单项优化验证 --> 组合优化
    组合优化 --> 压测观察
    压测观察 --> 上线灰度
    上线灰度 --> 稳定运行
    压测观察 --> 回退排查
    上线灰度 --> 回退排查
    回退排查 --> 单项优化验证

安全/性能最佳实践

这一节我把比较实用的建议收敛成清单,适合上线前逐项核对。

1. 对输出长度设置硬上限

这是最简单、最有效的稳态控制手段之一。

建议控制:

  • max_new_tokens
  • 每租户并发上限
  • 总活跃请求数
  • 总活跃 token 数

否则某些“超长输出”请求会把整个服务拖垮。


2. 对输入长度做分层治理

不要让所有请求都走同一条路径。可以分为:

  • 短请求:实时交互优先
  • 中等请求:正常动态批处理
  • 超长请求:低优先级或单独队列

这样可以显著减少长尾延迟污染。


3. 优先监控 token 级指标,而不是只看 QPS

大模型服务里,QPS 很容易误导。真正更有效的是:

  • TTFT
  • TPOT(time per output token)
  • 输入 token 总量
  • 输出 token 总量
  • 活跃 KV token 数
  • P50/P95/P99 延迟
  • OOM 次数
  • 请求取消率

4. 为量化保留回退方案

量化收益很大,但并不总是稳定。上线时最好保留:

  • FP16/BF16 版本
  • INT8 版本
  • INT4 版本

当你发现:

  • 特定任务精度下降明显
  • 特定 GPU 上性能异常
  • 某版本驱动/内核不兼容

就能快速切回。


5. 谨慎对待“极限显存利用率”

我个人不建议把 gpu-memory-utilization 或显存占用打得太满。因为线上环境会有:

  • 流量抖动
  • 请求长度抖动
  • 内存碎片
  • 框架缓存波动

保守一些,往往比“纸面吞吐更高”更值钱。


6. 把优化目标拆成两个,而不是一个

常见误区是只追一个数字,比如“tokens/s 越高越好”。

实际上更合理的是同时设两个目标:

  • 用户体验目标:TTFT、P95 latency
  • 资源效率目标:总 tokens/s、单卡成本

因为很多优化能提升吞吐,但会明显伤害交互体验。你的系统是 API 服务还是离线批处理,取舍会不同。


7. 对公共前缀做缓存

如果你的业务有稳定的系统提示词、模板化上下文、固定工具定义,那么前缀缓存非常值得做。

适用场景:

  • 统一 system prompt
  • 多轮对话带相同角色设定
  • RAG 前有固定指令模板
  • 代码助手中常见规则模板

这类优化对 TTFT 往往很有帮助。


一个可落地的优化决策顺序

如果你问我:“工程上最实用的顺序是什么?”我通常会这么做:

  1. 先测基线

    • 不然你根本不知道优化值不值
  2. 先上 KV Cache

    • 这是 decode 提速的基本盘
  3. 再做权重量化

    • 优先解决“装不下”和“显存太紧”的问题
  4. 再做动态批处理

    • 重点提升整体吞吐
  5. 最后调队列、限流、请求分桶

    • 这是把线上波动压下来的关键

也就是说,先解决“能稳定跑”,再解决“跑得更省”。很多团队反过来,一开始就盯着极限吞吐,最后反而被稳定性拖住。


总结

把大模型推理优化落到工程上,可以抓住一句话:

KV Cache 解决重复计算,量化解决显存压力,批处理解决吞吐效率,而真正的上线效果取决于调度与边界控制。

你可以把本文的重点记成这几条:

  • KV Cache 是 decode 提速核心,但会显著占用显存
  • 量化首先解决“装得下”,其次才是“跑更快”
  • 批处理能提升吞吐,但不加约束会伤害尾延迟
  • 线上优化要按 token 和阶段拆指标,不要只看总耗时
  • prefill、decode、KV 管理、调度策略必须一起看

如果你正在做第一版优化,我建议从下面这份最小闭环开始:

  1. 建立单请求与并发基线
  2. 开启 KV Cache,对比 decode 性能
  3. 尝试 4bit/8bit 量化,对比显存与精度
  4. 用 vLLM 做动态批处理压测
  5. 加上最大输入、最大输出、活跃 token 上限
  6. 用真实流量分布复测 P95/P99

这样走下来,你基本就能从“模型能跑”提升到“服务可用、成本可控”。

如果只给一句最终建议,那就是:

别把推理优化当成单点技术,而要把它当成“显存、带宽、调度、请求分布”共同作用的系统工程。


分享到:

上一篇
《Java开发踩坑实战:排查并修复线程池误用导致的接口雪崩与内存飙升》
下一篇
《Java 中线程池参数调优与异步任务治理实战指南》