跳转到内容
123xiao | 无名键客

《大模型推理性能实战优化:从 KV Cache、量化到批处理调度的工程方法》

字数: 0 阅读时长: 1 分钟

大模型推理性能实战优化:从 KV Cache、量化到批处理调度的工程方法

做大模型应用时,很多团队一开始都盯着“模型参数量”,但真正上线后,最先把人打懵的往往不是模型大小,而是推理性能:首 token 慢、吞吐低、显存不够、并发一上来就抖。

我自己做这类系统时,最深的感受是:推理优化不是某一个点的魔法,而是一整条链路的工程配合。你可能开了 KV Cache,却发现批处理没做好;你可能做了量化,却因为访存和调度问题没拿到预期收益;你可能吞吐上去了,但 P99 延迟反而更差。

这篇文章不谈太多“论文式原理”,而是从工程实战角度,带你把几个最常用、最有效的优化手段串起来:

  • KV Cache: 减少重复计算,尤其提升长上下文生成效率
  • 量化: 用更少显存和带宽跑更大的模型
  • 批处理调度: 在延迟与吞吐之间找到平衡点

目标是:你看完后,能自己搭一个可运行的小实验,知道先优化哪里、怎么测、出了问题怎么查。


背景与问题

先看一个典型的大模型推理请求过程。

用户输入一段 prompt,模型先做一遍 prefill(把已有上下文全部编码进注意力状态),再进入 decode(每次生成一个 token,再接着下一个)。

问题主要出在这几处:

  1. prefill 很吃算力

    • 输入越长,首次计算越重
    • 长 prompt 会拉高首 token 延迟
  2. decode 很吃访存

    • 每一步虽然只生成 1 个 token,但要不断读取历史状态
    • 并发高时,显存带宽成为瓶颈
  3. 请求长度不均匀

    • 有人问一句话,有人扔 8K prompt
    • 如果简单地“凑 batch”,容易让短请求陪长请求一起等
  4. 显存是硬约束

    • 模型权重、KV Cache、中间激活都要占显存
    • 一旦配置不合理,直接 OOM

所以,推理优化不能只盯一个指标。至少要同时关心:

  • TTFT(Time To First Token):首 token 延迟
  • TPOT(Time Per Output Token):每个输出 token 平均耗时
  • Throughput:吞吐,通常是 tokens/s 或 req/s
  • P95/P99 延迟:尾延迟
  • 显存占用:是否能稳定跑起来

前置知识与环境准备

这篇文章的代码示例用 Python,重点是帮助你建立“可验证”的优化思路,而不是绑死某个框架。

建议环境:

  • Python 3.10+
  • PyTorch 2.x
  • 一张 NVIDIA GPU(没有也能跑部分示例,但性能实验意义有限)

安装示例:

pip install torch transformers accelerate psutil

如果你想进一步做量化实验,可按需要安装:

pip install bitsandbytes

核心原理

这一节把三件事讲清楚:KV Cache 为什么快、量化为什么省、批处理调度为什么难。

1. KV Cache:省掉重复计算

Transformer 的自注意力在生成阶段,本质上每生成一个新 token,都需要跟历史 token 建立关联。

如果不使用 KV Cache,每次 decode 都要把“历史上下文”重新算一遍,代价非常高。
如果使用 KV Cache,历史 token 对应的 Key/Value 可以缓存起来,下一步只需要计算新 token 的 Query,再与缓存的 K/V 交互。

简单理解:

  • 不开 Cache:每一步都“重读整本书”
  • 开 Cache:前面读过的内容做了索引,后面只查索引

2. 量化:减小权重与访存压力

大模型推理时,很多时候不是纯算力瓶颈,而是显存带宽瓶颈
尤其 decode 阶段,经常表现为“算得不算慢,但搬数据太慢”。

量化的核心思想是:

  • 把 FP16 / BF16 权重压缩成 INT8、INT4 等更低比特形式
  • 减少模型参数占用
  • 降低访存带宽消耗
  • 让更大的模型装进有限显存

但量化不是白送的,它会带来:

  • 精度损失风险
  • 某些硬件/算子支持不一致
  • 某些场景速度提升明显,某些场景不明显

3. 批处理调度:吞吐与延迟的平衡术

批处理是服务端推理吞吐提升的关键。
多个请求一起算,GPU 利用率会更高。

但问题在于:请求不是整齐划一的。

  • 输入长度不同
  • 输出长度不同
  • 到达时间不同

如果你用最简单的静态 batch:

  • batch 内最慢的请求,会拖住其他请求
  • 长尾请求会放大尾延迟
  • 短请求体验变差

所以实际工程里常做的是:

  • 动态批处理(dynamic batching)
  • 按长度分桶(bucketing)
  • prefill / decode 分阶段调度
  • 连续批处理(continuous batching)

一图看懂三种优化点的关系

flowchart LR
    A[用户请求到达] --> B[Tokenizer]
    B --> C[Prefill阶段]
    C --> D[KV Cache建立]
    D --> E[Decode阶段]
    E --> F[输出Token]

    G[量化权重] --> C
    G --> E

    H[批处理调度] --> C
    H --> E

    I[显存限制] --> D
    I --> G
    I --> H

这张图很重要:
KV Cache、量化、批处理调度不是三条平行线,而是互相牵制的。

举个例子:

  • 量化后,权重显存下降,可能允许更大 batch
  • batch 变大后,KV Cache 占用可能又上来了
  • KV Cache 变大后,尾延迟可能因调度不当而恶化

从请求生命周期理解性能瓶颈

sequenceDiagram
    participant U as User
    participant S as Scheduler
    participant M as Model
    participant C as KV Cache

    U->>S: 提交请求(prompt)
    S->>M: prefill batch
    M->>C: 写入历史K/V
    M-->>S: 首token
    loop 每轮decode
        S->>M: 取活跃请求组成decode batch
        M->>C: 读取已有K/V并追加新K/V
        M-->>S: 返回下一个token
    end
    S-->>U: 完整输出

这也解释了一个常见现象:

  • 首 token 慢:prefill 重
  • 后续 token 稳定但不够快:decode 受 KV 读取与调度影响
  • 并发高时抖动:调度策略让不同长度请求相互拖累

实战代码(可运行)

下面做一个“小而完整”的实验,帮助你验证三件事:

  1. 开启 KV Cache 是否有帮助
  2. 不同 batch 规模对吞吐的影响
  3. 如何做一个最简单的动态调度模拟

说明:示例使用 distilgpt2,因为它轻量、容易跑通。
如果你有更强的 GPU,可以换成更大的因果语言模型。


实战一:测量 KV Cache 对生成速度的影响

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "distilgpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

prompt = "请用通俗的语言解释什么是 Transformer,并给一个实际例子。" * 8
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

def benchmark_generate(use_cache: bool, max_new_tokens: int = 64, warmup: int = 1, runs: int = 3):
    # 预热
    with torch.no_grad():
        for _ in range(warmup):
            _ = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=use_cache,
                do_sample=False
            )

    times = []
    with torch.no_grad():
        for _ in range(runs):
            if DEVICE == "cuda":
                torch.cuda.synchronize()
            start = time.perf_counter()

            _ = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=use_cache,
                do_sample=False
            )

            if DEVICE == "cuda":
                torch.cuda.synchronize()
            end = time.perf_counter()
            times.append(end - start)

    avg = sum(times) / len(times)
    print(f"use_cache={use_cache}, avg_time={avg:.4f}s, runs={times}")

benchmark_generate(use_cache=False)
benchmark_generate(use_cache=True)

你应该观察什么?

  • 在稍长 prompt、稍长生成长度下,use_cache=True 通常更快
  • 如果模型很小、输入很短,差异可能不算夸张
  • 真正的大模型、长上下文场景里,收益通常更明显

为什么这个实验有意义?

因为它把“KV Cache 有用”从概念变成了可测量事实。
做性能优化时,我很建议你先做这种最小闭环验证,别一上来就改一堆配置,最后不知道是谁起作用。


实战二:估算 KV Cache 显存占用

很多人第一次把并发开大,直接 OOM。根因往往不是模型权重,而是低估了 KV Cache 体积

一个常见的粗略估算公式:

KV Cache 大小 ≈ 2 × 层数 × batch × 序列长度 × hidden_size × 每元素字节数

其中乘以 2 是因为有 K 和 V。

下面写一个简单估算器:

def estimate_kv_cache_bytes(
    num_layers: int,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    bytes_per_element: int = 2
):
    return 2 * num_layers * batch_size * seq_len * hidden_size * bytes_per_element

def pretty_size(num_bytes: int):
    units = ["B", "KB", "MB", "GB", "TB"]
    size = float(num_bytes)
    for unit in units:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} PB"

examples = [
    {"num_layers": 24, "batch_size": 8, "seq_len": 2048, "hidden_size": 2048, "bytes_per_element": 2},
    {"num_layers": 32, "batch_size": 16, "seq_len": 4096, "hidden_size": 4096, "bytes_per_element": 2},
]

for e in examples:
    size = estimate_kv_cache_bytes(**e)
    print(e, "=>", pretty_size(size))

这个公式为什么只是“粗略估算”?

因为真实实现里还会受到这些因素影响:

  • 多头维度划分方式
  • 分页 KV Cache / 连续内存实现
  • 对齐与额外元数据
  • 张量并行、流水并行带来的副本或切分

但它足够帮你做第一轮容量规划。
工程里我一般会先用粗估算筛配置,再用实际 profiling 校正。


实战三:量化加载模型并比较显存

如果你的环境支持 bitsandbytes,可以试试 8bit 量化。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def print_cuda_mem(tag: str):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"[{tag}] allocated={allocated:.2f} MB, reserved={reserved:.2f} MB")

if DEVICE == "cuda":
    torch.cuda.empty_cache()
    print_cuda_mem("before fp16")

model_fp = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
).to(DEVICE)
model_fp.eval()
print_cuda_mem("after fp16/fp32")

del model_fp
if DEVICE == "cuda":
    torch.cuda.empty_cache()

try:
    from transformers import BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_8bit=True)
    model_int8 = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto"
    )
    model_int8.eval()
    print_cuda_mem("after int8")
except Exception as e:
    print("8bit quantization load failed:", e)
    print("请确认已安装 bitsandbytes,且 CUDA/平台兼容。")

看结果时别只看“能不能跑”

要看这几个点:

  • 显存下降了多少
  • 生成速度是否真的变快
  • 输出质量是否明显变差
  • 某些平台是否兼容失败

我踩过一个坑:有时量化后显存确实省了,但速度收益并不明显。
原因可能不是量化“没用”,而是你的瓶颈在别处,比如:

  • batch 太小,GPU 吃不满
  • prompt 太短,优化空间有限
  • 某些算子没有走到高效实现
  • 调度不合理,等待时间盖过了计算收益

实战四:用一个简化版调度器理解动态批处理

下面不直接搭完整推理服务,而是先写一个调度模拟器
它不会真的跑模型,但能帮助你理解“不同长度请求混在一起”为什么影响性能。

import random
import heapq
from dataclasses import dataclass, field

@dataclass(order=True)
class Request:
    arrival_time: float
    input_len: int = field(compare=False)
    output_len: int = field(compare=False)
    id: int = field(compare=False, default=0)

def simulate_scheduler(
    requests,
    max_batch_size=4,
    batching_window=0.01,
    prefill_cost_per_token=0.00005,
    decode_cost_per_token=0.00002
):
    """
    一个极简调度模型:
    - prefill按输入token数线性计时
    - decode按每轮活跃请求数计时
    - 同一批次的prefill一起做
    """
    requests = sorted(requests, key=lambda r: r.arrival_time)
    time_now = 0.0
    idx = 0
    finished = {}

    while idx < len(requests):
        first = requests[idx]
        time_now = max(time_now, first.arrival_time)
        batch = [first]
        idx += 1

        # batching window 内尽量收集请求
        while idx < len(requests) and len(batch) < max_batch_size:
            if requests[idx].arrival_time <= time_now + batching_window:
                batch.append(requests[idx])
                idx += 1
            else:
                break

        # prefill:由本批最大输入长度主导
        max_input_len = max(r.input_len for r in batch)
        time_now += max_input_len * prefill_cost_per_token

        # decode:每轮生成1个token,直到所有请求完成
        remaining = {r.id: r.output_len for r in batch}
        while remaining:
            active_count = len(remaining)
            time_now += active_count * decode_cost_per_token

            done_ids = []
            for rid in remaining:
                remaining[rid] -= 1
                if remaining[rid] <= 0:
                    done_ids.append(rid)

            for rid in done_ids:
                finished[rid] = time_now
                del remaining[rid]

    return finished

def build_requests(n=12, seed=42):
    random.seed(seed)
    reqs = []
    t = 0.0
    for i in range(n):
        t += random.uniform(0.0, 0.02)
        reqs.append(
            Request(
                id=i,
                arrival_time=t,
                input_len=random.choice([32, 64, 128, 256, 512]),
                output_len=random.choice([16, 32, 64, 128]),
            )
        )
    return reqs

requests = build_requests()

for batch_size in [1, 2, 4]:
    result = simulate_scheduler(
        requests,
        max_batch_size=batch_size,
        batching_window=0.01
    )
    avg_finish = sum(result.values()) / len(result)
    print(f"max_batch_size={batch_size}, avg_finish_time={avg_finish:.4f}s")

这段代码要理解什么?

  • batch_size=1:延迟可能更低,但吞吐差
  • batch_size 变大:吞吐通常更好,但短请求可能被拖慢
  • 如果输入长度差异很大,混批代价会更明显

它虽然是简化模型,但足够帮你建立调度直觉。


调度策略怎么选:一个实用思路

如果你现在要做生产服务,通常不会问“最优策略是什么”,而是问:

在我的流量分布、模型大小、延迟目标下,先上哪种策略最划算?

我的建议是分三层做。

第一层:先做最小有效优化

  • 开启 KV Cache
  • 按输入长度做简单分桶
  • 设置一个小的动态 batching window(例如 5ms~20ms)
  • 对大模型尝试 8bit / 4bit 量化

这一层投入小、收益通常直接。

第二层:根据业务目标调参

如果你更看重首 token 体验

  • 减小 batching window
  • 限制 batch size
  • 对超长 prompt 做单独队列

如果你更看重整体吞吐

  • 增大 batch size
  • 提高请求聚合时间
  • 使用 continuous batching

第三层:做分阶段调度

把请求拆成:

  • prefill 阶段
  • decode 阶段

两者分开调度,因为它们的资源特征不同:

  • prefill 更偏算力密集
  • decode 更偏显存带宽与缓存访问
stateDiagram-v2
    [*] --> Waiting
    Waiting --> PrefillQueued: 请求到达
    PrefillQueued --> Decoding: 完成prefill并建立KV Cache
    Decoding --> Decoding: 继续生成token
    Decoding --> Finished: 命中EOS或达到max_tokens
    Finished --> [*]

这也是很多高性能推理框架会采用的思路。


常见坑与排查

这一节我尽量写得“像现场”,因为很多问题不是理论不会,而是线上一跑就歪。

坑 1:开了 KV Cache,但速度几乎没变

可能原因:

  1. 输入很短、输出很短,收益本来就有限
  2. 模型太小,框架调度和 Python 开销占比高
  3. 实际生成路径没有真正启用 cache
  4. GPU 没有被吃满,瓶颈不在注意力重复计算

排查方法:

  • 打印生成配置,确认 use_cache=True
  • 用更长 prompt 和更长输出复测
  • 用 profiler 看时间花在哪
  • 对比 prefill 与 decode 分阶段耗时

坑 2:量化后显存降了,但延迟反而升高

可能原因:

  1. 量化算子实现不够高效
  2. 某些层仍在高精度上运行
  3. 小 batch 下,量化收益被其他开销掩盖
  4. CPU/GPU 间有额外数据搬运

排查方法:

  • 检查模型实际落在哪些设备上
  • 检查是否发生了 host-device copy
  • 对比不同 batch size 下的性能曲线
  • 观察 tokens/s 而不是只看单次 wall time

坑 3:批处理一开,P99 变差很多

可能原因:

  1. batching window 过大
  2. 长短请求混批严重
  3. 超长请求占据解码轮次,拖慢短请求
  4. 队列缺少优先级或长度分桶

排查方法:

  • 把请求按输入长度分桶统计
  • 单独看短请求与长请求的 P95/P99
  • 记录队列等待时间与模型计算时间
  • 调小 batch 或拆分队列验证

坑 4:显存明明“理论够”,实际还是 OOM

可能原因:

  1. KV Cache 增长被低估
  2. 框架预留显存、缓存池占用较大
  3. batch 内最长序列决定了实际张量尺寸
  4. 多个并发请求在峰值时叠加

排查方法:

import torch

def report_cuda_memory():
    if torch.cuda.is_available():
        print("allocated(MB):", torch.cuda.memory_allocated() / 1024**2)
        print("reserved(MB):", torch.cuda.memory_reserved() / 1024**2)
        print("max_allocated(MB):", torch.cuda.max_memory_allocated() / 1024**2)

report_cuda_memory()

同时建议你记录:

  • 当前 batch size
  • 最大输入长度
  • 最大生成长度
  • 活跃请求数
  • 每轮 decode 的 token 数

这几个指标一对照,OOM 往往就不神秘了。


安全/性能最佳实践

这里把“能跑”提升到“能稳定上线”。

1. 不要只测平均值,要盯尾延迟

至少同时记录:

  • 平均延迟
  • P95 / P99
  • TTFT
  • tokens/s
  • GPU 显存与利用率

平均值很好看,线上用户照样骂你,通常就是因为尾延迟。


2. 先做容量预算,再开并发

上线前至少估算:

  • 模型权重占用
  • KV Cache 峰值
  • batch 变化范围
  • 最长上下文长度
  • 最长输出长度

如果业务允许,强烈建议设置:

  • 最大输入长度
  • 最大输出长度
  • 单请求超时
  • 单租户并发上限

这既是性能策略,也是安全策略。否则恶意超长 prompt 很容易拖垮服务。


3. 对请求做长度分层

一个非常朴素但有效的策略:

  • 短请求队列:优先低延迟
  • 长请求队列:接受更高等待
  • 超长上下文:单独处理

很多时候你不需要一上来就上复杂调度器,先分层,收益就很明显。


4. 量化前先确认目标

量化通常有三类目标:

  1. 显存不够,先装得下
  2. 带宽受限,希望更快
  3. 想用更大 batch 提升吞吐

如果你的目标不明确,很容易陷入“量化了但没有想象中快”的困惑。
量化是手段,不是 KPI。


5. 为 KV Cache 预留边界

不要把显存打到极限。实际工程里建议保留缓冲区,避免:

  • 突发长请求
  • 框架内部临时分配
  • 多流并发造成峰值抖动

经验上,预留 10%~20% 安全空间通常更稳。


6. 把 prefill 和 decode 分开观测

这是我非常建议做的一件事。因为这两个阶段优化手段并不完全相同:

  • prefill 慢:看 prompt 长度、batch、算力利用率
  • decode 慢:看 KV Cache、访存带宽、调度活跃请求数

如果你把它们混成一个总耗时,很难真正找到瓶颈。


逐步验证清单

如果你打算自己做一轮优化,我建议按这个顺序来,不容易乱。

第一步:建立基线

记录当前:

  • 单请求 TTFT
  • 单请求 tokens/s
  • batch=1 的显存占用
  • 高并发下 P95/P99

第二步:开启 KV Cache

验证:

  • 长上下文下是否明显加速
  • decode 阶段耗时是否下降
  • 显存是否按预期增长

第三步:尝试量化

验证:

  • 模型是否稳定加载
  • 显存下降多少
  • tokens/s 是否提升
  • 生成质量是否可接受

第四步:加动态 batching

验证:

  • 吞吐是否提升
  • TTFT 是否被拉长
  • 长短请求是否互相拖累

第五步:做长度分桶

验证:

  • P99 是否改善
  • 短请求体验是否明显变好
  • GPU 利用率是否仍可接受

一个可执行的工程建议组合

如果你现在手上是一个中等规模的在线生成服务,我会建议从下面这套组合开始:

  1. 默认开启 KV Cache
  2. 模型先尝试 8bit 量化
  3. batching window 控制在 5ms~10ms
  4. 按输入长度做 3 桶
    • 0~512
    • 513~2048
    • 2049+
  5. 为超长请求设置单独队列
  6. 监控 prefill / decode 分阶段耗时
  7. 预留 15% 显存安全边界

这套方案不一定“最先进”,但通常足够实用,也容易落地。


总结

大模型推理优化,真正有效的不是某个孤立技巧,而是三件事协同:

  • KV Cache:减少重复计算,重点改善生成阶段效率
  • 量化:降低权重与访存压力,换取更好的显存利用和潜在吞吐
  • 批处理调度:把 GPU 真正用起来,但要小心尾延迟

如果只记住一句话,我希望是:

先建立基线,再分阶段优化;先解决最大瓶颈,再谈复杂技巧。

具体落地时,可以按这个顺序:

  1. 先测基线
  2. 开 KV Cache
  3. 做量化
  4. 上动态 batching
  5. 再做长度分桶和分阶段调度

这样走,基本不会偏。

最后补一个边界条件:
如果你的业务是极低延迟、超短请求、并发也不高,那么复杂的批处理调度不一定值得,甚至可能适得其反。
反过来,如果你面对的是长上下文、高并发、多租户场景,那么 KV Cache + 量化 + 动态调度 几乎就是必修课。

性能优化从来不是“调个参数就结束”,更像是不断逼近系统边界的过程。
但好消息是,只要方法对,收益通常很实在,而且能被稳定复现。


分享到:

上一篇
《从浏览器抓包到参数还原:中级开发者实战 Web 逆向中的接口签名分析与复现》
下一篇
《Java 中使用 CompletableFuture 构建高并发异步流程的实战指南》