跳转到内容
123xiao | 无名键客

《大模型推理性能优化实战:从 KV Cache、量化到并发调度的系统化落地指南》

字数: 0 阅读时长: 1 分钟

大模型推理性能优化实战:从 KV Cache、量化到并发调度的系统化落地指南

做大模型推理优化,最怕两件事:
一是只盯着某个点,比如“上量化就快了”,结果整体吞吐并没有明显提升;
二是没有系统视角,线上一加并发就抖,P99 飙升,GPU 利用率却还不高。

我自己在做这类系统时,最大的感受是:推理优化不是单点技术,而是一条完整链路。从模型权重怎么存、KV Cache 怎么放、请求怎么排队,到 batch 怎么拼、长短请求怎么隔离,任何一个环节都可能成为瓶颈。

这篇文章不只讲概念,而是按“能落地”的方式,把一套常见的大模型推理优化路径串起来。重点覆盖:

  • 为什么推理会慢,慢在哪里
  • KV Cache 的收益和代价
  • 量化为什么有时提速,有时反而不明显
  • 并发调度如何影响吞吐和尾延迟
  • 一个可运行的简化版服务示例
  • 线上常见坑和排查思路
  • 安全与性能的最佳实践

背景与问题

大模型推理通常分成两个阶段:

  1. Prefill 阶段:把提示词一次性喂进去,计算历史上下文
  2. Decode 阶段:每次生成一个 token,并基于已有上下文继续算下去

这两个阶段的特征很不一样:

  • Prefill 更像大矩阵计算,吞吐高,但输入长度敏感
  • Decode 是逐 token 生成,依赖前一步结果,天然串行,延迟敏感

很多团队刚上模型服务时,常见现象是:

  • GPU 显存很快打满
  • 单请求延迟高,尤其长上下文请求
  • 并发一上来,P95/P99 迅速恶化
  • 吞吐没有随着 batch 增长线性提升
  • 量化后显存省了,但速度提升有限
  • 不同长度请求混跑时,短请求被长请求“拖死”

这些问题背后,本质上是几类资源争抢:

  • 算力:矩阵乘性能是否吃满
  • 显存:权重、激活、KV Cache 如何分配
  • 带宽:显存读写、PCIe/网络传输是否成为瓶颈
  • 调度:请求如何分组、何时抢占、何时复用

如果只优化某一个点,往往只能得到局部收益。真正可复制的方案,得从系统层面看。


先建立一个全局图:推理优化到底在优化什么

flowchart LR
    A[请求进入] --> B[Tokenizer/预处理]
    B --> C[Prefill]
    C --> D[KV Cache写入]
    D --> E[Decode循环]
    E --> F[采样/后处理]
    F --> G[响应返回]

    H[量化权重] --> C
    H --> E
    I[动态批处理] --> C
    I --> E
    J[并发调度] --> A
    K[显存管理] --> D
    K --> E

可以把常见优化手段归为四层:

层次目标典型手段
模型层降低计算/存储成本量化、蒸馏、结构裁剪
执行层提高单卡效率KV Cache、算子融合、Flash Attention
调度层提高多请求吞吐Continuous Batching、队列调度、优先级
系统层提升稳定性显存池、限流、熔断、观测与回收

本文重点讲后三层,因为它们最接近实际部署。


核心原理

1. KV Cache:把“重复计算”变成“读缓存”

在 Transformer 的自回归生成中,每个新 token 都需要关注前面所有 token。
如果每次 decode 都重新计算历史 token 的 K/V,成本会非常高。

KV Cache 的思路很直接:

  • 首次 prefill 时,算出每层 attention 的 K/V
  • 后续 decode 只计算新 token 的 Q/K/V
  • 历史 K/V 直接从缓存读

于是,原来“每步都重算全部上下文”,变成了“每步只追加一个 token 的状态”。

KV Cache 带来的收益

  • 显著降低 decode 阶段重复计算
  • 长上下文场景收益尤其明显
  • 提高单请求 token/s

KV Cache 的代价

  • 非常吃显存
  • 上下文越长、batch 越大、层数越多,缓存越膨胀
  • 容易形成显存碎片
  • 不合理的 cache 回收会导致抖动甚至 OOM

KV Cache 的容量可以粗估:

KV Cache 显存 ≈ batch × seq_len × num_layers × hidden_size × 2(K/V) × bytes_per_element

更精确一点,常常还会乘上 attention head 的拆分结构,但用于容量规划,这个数量级已经够用了。

2. 量化:不是“必快”,而是“更省、更可能快”

量化的核心是把权重从高精度表示变成低精度表示,例如:

  • FP16 / BF16
  • INT8
  • INT4

量化带来的收益主要有两类:

  1. 减少权重占用
  2. 降低带宽压力

但这里有个很容易踩的坑:
量化不一定带来线性加速。

原因通常有三个:

  • 推理瓶颈可能不在权重读取,而在 decode 串行依赖
  • 某些量化方案需要反量化开销
  • 内核没有针对目标硬件做优化,算子跑不满

所以,量化更准确的理解是:

  • 优先解决“放不下”
  • 其次解决“带宽受限”
  • 最后才是追求“绝对加速”

3. Prefill 与 Decode 要分开看

很多人拿一个“平均延迟”来评估模型服务,这其实很容易误导。

更合理的指标应该拆成:

  • TTFT(Time To First Token):首 token 延迟
  • TPOT(Time Per Output Token):每个输出 token 的平均耗时
  • Throughput:tokens/s 或 req/s
  • P95/P99:尾延迟

因为:

  • Prefill 决定 TTFT
  • Decode 决定持续生成速度

而 KV Cache、batching、量化,对这两个阶段的影响不一样。

sequenceDiagram
    participant U as User
    participant S as Inference Server
    participant G as GPU

    U->>S: 发送请求(prompt)
    S->>G: Prefill计算
    G-->>S: 首次隐藏状态 + KV Cache
    S-->>U: 首token

    loop Decode
        S->>G: 基于KV Cache生成下一个token
        G-->>S: token logits
        S-->>U: 流式返回token
    end

4. 并发调度:真正决定线上体验的关键

单请求跑得快,不代表系统跑得好。
线上最常见的瓶颈,反而是调度不合理

常见几种调度方式:

静态批处理

攒一批请求再一起算。

优点:

  • 实现简单
  • 吞吐较高

缺点:

  • 等待时间增加
  • 不适合流式输出
  • 长短请求混合时容易拖尾

Continuous Batching

请求不是“一批进一批出”,而是在 decode 过程中动态加入/退出 batch。

优点:

  • GPU 利用率更高
  • 更适合流式推理
  • 吞吐和延迟平衡更好

缺点:

  • 实现复杂
  • 对 KV Cache 管理要求高
  • 调度策略不合理时容易抖

长短请求隔离

把请求按 prompt 长度、预估生成长度分桶:

  • short queue
  • medium queue
  • long queue

这样短请求不会被长请求拖垮,P99 会更稳定。


方案对比与取舍分析

1. 几种优化手段的实际侧重点

手段主要收益主要代价适用场景
KV Cache降低 decode 重复计算显存占用高长文本、多轮对话
INT8 量化节省显存,通常较稳精度可能轻微波动通用线上部署
INT4 量化显存压缩更明显精度/兼容性风险更高显存紧张、成本敏感
动态批处理提升吞吐延迟变复杂中高并发服务
Continuous Batching吞吐与流式兼顾系统实现复杂在线推理平台
请求分桶改善尾延迟调度复杂度提升长短请求混杂明显

2. 一个比较务实的落地顺序

如果你是从零开始做推理服务,我建议顺序不要反:

  1. 先做观测:TTFT、TPOT、tokens/s、显存占用、队列长度
  2. 再上 KV Cache:先把 decode 重复计算降下来
  3. 再做量化:优先解决显存压力
  4. 再做动态批处理/连续批处理
  5. 最后做复杂调度:例如多队列、优先级、抢占

很多团队一上来就做花哨调度,结果连瓶颈在 prefill 还是 decode 都没分清,这很容易南辕北辙。


容量估算:部署前至少要算明白这几件事

1. 权重占用估算

如果模型参数量为 N,不同精度下显存大致为:

  • FP16/BF16:2 * N bytes
  • INT8:1 * N bytes
  • INT4:0.5 * N bytes

例如一个 7B 模型:

  • FP16:约 14GB
  • INT8:约 7GB
  • INT4:约 3.5GB

这只是权重,不含 KV Cache、激活、框架开销。

2. KV Cache 占用估算

粗估公式:

KV显存 ≈ batch_size × seq_len × num_layers × hidden_size × 2 × bytes

例如:

  • batch = 8
  • seq_len = 4096
  • layers = 32
  • hidden_size = 4096
  • bytes = 2(FP16)

数量级会非常可观,这也是为什么很多线上服务明明权重能放下,一开长上下文就 OOM。

3. 并发能力估算

可服务并发不是看“机器有几张卡”,而是看:

  • 每个请求的平均输入长度
  • 平均输出长度
  • KV Cache 生命周期
  • 流式连接持续时间
  • 目标 P95/P99

一个经验原则是:

  • 短文本问答 可以更激进地批处理
  • 长上下文生成 必须保守设置最大并发
  • 多轮对话 要重点关注 session 级 cache 保留策略

实战代码(可运行)

下面我用一个简化版推理服务模拟器演示三件事:

  1. KV Cache 如何影响延迟
  2. 动态批处理如何影响吞吐
  3. 长短请求混跑时为什么要分桶

这不是一个真正的 GPU 推理引擎,但它能帮助你把调度思路跑通。代码可直接运行。

1. 请求与推理模拟器

import time
import random
import threading
import queue
from dataclasses import dataclass, field
from typing import List, Dict


@dataclass
class Request:
    req_id: str
    prompt_len: int
    output_len: int
    use_kv_cache: bool = True
    created_at: float = field(default_factory=time.time)
    first_token_at: float = 0.0
    finished_at: float = 0.0
    generated: int = 0


class InferenceSimulator:
    """
    一个简化版大模型推理模拟器:
    - prefill成本 ~ prompt_len
    - decode成本 ~ output_len
    - KV Cache开启后,decode每步成本更低
    """
    def __init__(self, batch_size=4, continuous_batching=True):
        self.batch_size = batch_size
        self.continuous_batching = continuous_batching
        self.pending = queue.Queue()
        self.running: List[Request] = []
        self.finished: List[Request] = []
        self.lock = threading.Lock()
        self.stop_flag = False

    def submit(self, req: Request):
        self.pending.put(req)

    def _simulate_prefill(self, req: Request):
        # 假设每个输入token预填充开销 0.8ms
        time.sleep(req.prompt_len * 0.0008)

    def _simulate_decode_step(self, req: Request):
        # 有KV Cache时,每个token 1.5ms;没有时,近似随上下文增长
        if req.use_kv_cache:
            time.sleep(0.0015)
        else:
            dynamic_cost = 0.0015 + (req.prompt_len + req.generated) * 0.00002
            time.sleep(dynamic_cost)

    def scheduler_loop(self):
        while not self.stop_flag:
            # 补充batch
            while len(self.running) < self.batch_size and not self.pending.empty():
                req = self.pending.get()
                self._simulate_prefill(req)
                self.running.append(req)

            if not self.running:
                time.sleep(0.01)
                continue

            # decode一轮
            current_batch = list(self.running)
            for req in current_batch:
                self._simulate_decode_step(req)
                req.generated += 1
                if req.generated == 1:
                    req.first_token_at = time.time()
                if req.generated >= req.output_len:
                    req.finished_at = time.time()
                    with self.lock:
                        self.finished.append(req)
                    self.running.remove(req)

            if not self.continuous_batching:
                # 静态批处理:直到当前批次全部完成再补新请求
                while self.running:
                    current_batch = list(self.running)
                    for req in current_batch:
                        self._simulate_decode_step(req)
                        req.generated += 1
                        if req.generated == 1:
                            req.first_token_at = time.time()
                        if req.generated >= req.output_len:
                            req.finished_at = time.time()
                            with self.lock:
                                self.finished.append(req)
                            self.running.remove(req)

    def start(self):
        self.thread = threading.Thread(target=self.scheduler_loop, daemon=True)
        self.thread.start()

    def stop(self):
        self.stop_flag = True
        self.thread.join(timeout=1)

    def stats(self) -> Dict[str, float]:
        with self.lock:
            if not self.finished:
                return {}
            ttfts = [r.first_token_at - r.created_at for r in self.finished]
            totals = [r.finished_at - r.created_at for r in self.finished]
            total_tokens = sum(r.output_len for r in self.finished)
            total_time = max(r.finished_at for r in self.finished) - min(r.created_at for r in self.finished)
            return {
                "requests": len(self.finished),
                "avg_ttft": sum(ttfts) / len(ttfts),
                "avg_latency": sum(totals) / len(totals),
                "throughput_tokens_per_sec": total_tokens / total_time if total_time > 0 else 0.0,
            }


def build_requests(n=12, kv_cache=True):
    reqs = []
    for i in range(n):
        prompt_len = random.choice([64, 128, 256, 1024, 2048])
        output_len = random.choice([32, 64, 128])
        reqs.append(
            Request(
                req_id=f"req-{i}",
                prompt_len=prompt_len,
                output_len=output_len,
                use_kv_cache=kv_cache
            )
        )
    return reqs


def run_experiment(title, continuous_batching, kv_cache):
    print(f"\n=== {title} ===")
    sim = InferenceSimulator(batch_size=4, continuous_batching=continuous_batching)
    sim.start()

    for req in build_requests(12, kv_cache=kv_cache):
        sim.submit(req)

    while True:
        stats = sim.stats()
        if stats.get("requests", 0) >= 12:
            break
        time.sleep(0.1)

    sim.stop()
    print(stats)


if __name__ == "__main__":
    random.seed(42)

    run_experiment(
        title="实验1:连续批处理 + 开启KV Cache",
        continuous_batching=True,
        kv_cache=True
    )

    run_experiment(
        title="实验2:连续批处理 + 关闭KV Cache",
        continuous_batching=True,
        kv_cache=False
    )

    run_experiment(
        title="实验3:静态批处理 + 开启KV Cache",
        continuous_batching=False,
        kv_cache=True
    )

2. 运行后你应该观察什么

执行:

python simulator.py

你通常会看到类似趋势:

  • 开启 KV Cache 时,平均延迟和 token 吞吐明显更好
  • 关闭 KV Cache 时,随着生成变长,decode 成本逐渐抬升
  • 连续批处理 通常比静态批处理吞吐更高,且对流式场景更友好

这段代码虽然简化,但非常适合拿来做团队内部认知统一:
别只看单请求速度,要看队列、batch、KV Cache 一起作用后的整体表现。


一个更贴近工程的服务骨架

下面再给一个 FastAPI 的最小服务骨架,用于演示“长度分桶 + 简单调度”的做法。

from fastapi import FastAPI
from pydantic import BaseModel
import threading
import queue
import time
import uuid

app = FastAPI()

short_q = queue.Queue()
long_q = queue.Queue()
results = {}

class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 64

def estimate_prompt_len(text: str) -> int:
    # 简化估算:中文/英文真实token数需要依赖tokenizer
    return max(1, len(text) // 4)

def fake_infer(prompt_len: int, max_new_tokens: int):
    # 模拟 prefill + decode
    time.sleep(prompt_len * 0.0005)
    time.sleep(max_new_tokens * 0.001)
    return "hello world"[:min(11, max_new_tokens)]

def worker(q: queue.Queue, worker_name: str):
    while True:
        item = q.get()
        if item is None:
            break
        req_id, prompt_len, max_new_tokens = item
        try:
            output = fake_infer(prompt_len, max_new_tokens)
            results[req_id] = {
                "worker": worker_name,
                "output": output,
                "status": "done"
            }
        except Exception as e:
            results[req_id] = {
                "worker": worker_name,
                "status": "error",
                "error": str(e)
            }
        finally:
            q.task_done()

threading.Thread(target=worker, args=(short_q, "short-worker"), daemon=True).start()
threading.Thread(target=worker, args=(long_q, "long-worker"), daemon=True).start()

@app.post("/generate")
def generate(req: GenerateRequest):
    req_id = str(uuid.uuid4())
    prompt_len = estimate_prompt_len(req.prompt)

    results[req_id] = {"status": "queued"}

    if prompt_len <= 256:
        short_q.put((req_id, prompt_len, req.max_new_tokens))
        lane = "short"
    else:
        long_q.put((req_id, prompt_len, req.max_new_tokens))
        lane = "long"

    return {"request_id": req_id, "lane": lane}

@app.get("/result/{req_id}")
def get_result(req_id: str):
    return results.get(req_id, {"status": "not_found"})

运行:

uvicorn app:app --reload

这个示例的意义不是“直接上线”,而是说明一个实用原则:

请求分桶是最容易见效的调度优化之一。

尤其是当你的流量里既有几十 token 的轻问答,也有几千 token 的长上下文摘要时,不做隔离,短请求体验通常会很差。


调度设计建议:从“先进先出”进化到“有策略地公平”

真实系统里,我比较推荐以下调度思路:

基础版

  • FIFO 队列
  • 限制最大 batch size
  • 限制最大上下文长度
  • 统一超时与取消

适合最初版本,优点是简单、稳定。

进阶版

  • 按 prompt 长度分桶
  • Continuous Batching
  • 每轮 decode 允许新请求插入
  • 设置 batch token budget,而不是只设 batch request count

这里“token budget”很重要,因为 1 个 4k 输入请求和 1 个 64 输入请求,资源消耗根本不是一个量级。

稳定性优先版

  • 分层优先级队列
  • 长请求隔离到独立池
  • 设置用户级并发上限
  • 对异常长会话做 cache 回收
  • 必要时做 admission control(准入控制)

常见坑与排查

这部分我尽量写得接地气一点,都是实际很容易遇到的问题。

坑 1:量化后显存降了,但速度几乎没提升

原因

  • 当前瓶颈在 decode 串行,不在权重加载
  • 量化 kernel 没有吃满硬件能力
  • 小 batch 场景下,量化收益被调度与框架开销抵消

排查方式

  • 分别统计 prefill 和 decode 的耗时
  • 观察 GPU utilization 与 memory bandwidth
  • 对比 FP16 与 INT8 的 TTFT、TPOT,而不是只看总延迟

建议

  • 先确认你的瓶颈是“算力受限”还是“带宽受限”
  • 线上优先用成熟量化方案,不要只看理论压缩率

坑 2:开了 KV Cache,结果更容易 OOM

原因

  • 长上下文请求太多
  • cache 生命周期过长,没有及时释放
  • session 保活策略激进
  • 显存碎片化严重

排查方式

  • 统计每个请求的 prompt_len / output_len
  • 记录 session 级 KV Cache 保留时长
  • 观察 OOM 前显存曲线是否持续增长
  • 区分“峰值申请过大”还是“碎片不足导致无法分配”

建议

  • 设置最大上下文长度
  • 对空闲 session 的 KV Cache 做 TTL 回收
  • 使用显存池或分页式 cache 管理
  • 长短请求分池,不要混用同一资源区

坑 3:吞吐上去了,但 P99 非常差

原因

  • batch 太激进
  • 长请求拖累短请求
  • 队列堆积造成排队时间增加
  • 没有取消机制,僵尸连接占住资源

排查方式

  • 区分排队耗时、prefill 耗时、decode 耗时
  • 观察队列长度与 P99 是否同步抬升
  • 抽样分析慢请求的 prompt 长度分布

建议

  • 优先控制尾延迟,不要盲目追求吞吐
  • 加长度分桶与优先级
  • 给用户侧设置合理的最大生成长度

坑 4:短文本服务挺稳,一上长上下文就抖

原因

  • 容量估算基于平均输入长度,没考虑长尾
  • prefill 计算骤增
  • KV Cache 膨胀导致并发能力突然下降

建议

  • 容量规划不能只看平均值,要看 P95 输入长度
  • 把长上下文请求单独路由到专用实例
  • 降低长上下文实例的并发上限

安全/性能最佳实践

这里把“安全”和“性能”放一起讲,因为线上系统的稳定性,本身就是一种工程安全。

1. 设置硬性边界,不要相信所有输入

至少要限制:

  • 最大 prompt token 数
  • 最大生成 token 数
  • 最大并发连接数
  • 单用户/单租户速率
  • 单请求最长执行时间

如果这些边界不设,恶意或异常请求会迅速放大 KV Cache 和队列压力。

2. 流式连接必须支持取消

用户断开连接后,如果服务端还继续生成:

  • GPU 资源被白白占用
  • KV Cache 无法及时释放
  • 队列吞吐下降

这是个很常见但经常被忽视的问题。我当时就踩过这个坑:前端页面关了,后端还在“努力生成”,机器看着很忙,业务却没收益。

3. 指标一定要分阶段采集

至少监控这些指标:

  • 请求数、成功率、超时率、取消率
  • TTFT、TPOT、总延迟
  • prefill 耗时、decode 耗时
  • 每秒生成 token 数
  • batch 大小分布
  • prompt 长度分布、输出长度分布
  • GPU 利用率、显存占用、显存碎片
  • KV Cache 命中/占用/回收情况

4. 优先做“稳定提速”,而不是“极限提速”

例如:

  • 先用成熟 INT8,再考虑更激进 INT4
  • 先做长度分桶,再做复杂抢占
  • 先做静态资源隔离,再做动态混部

线上系统最怕的是“实验室里很快,业务峰值下不稳”。

5. 做好降级策略

建议准备至少三档:

  • 正常模式:全功能、流式、较高并发
  • 拥塞模式:降低最大生成长度、减少 batch
  • 保护模式:拒绝长请求,只保留核心业务

这样当显存逼近上限、队列堆积严重时,系统不会一下子雪崩。

stateDiagram-v2
    [*] --> Normal
    Normal --> Busy: 队列长度升高/P95恶化
    Busy --> Protected: 显存紧张/OOM风险
    Protected --> Busy: 压力下降
    Busy --> Normal: 指标恢复

一个可执行的落地清单

如果你下周就要开始优化现网服务,可以按这个顺序推进:

第一步:先量化问题,而不是先量化模型

先回答四个问题:

  1. 现在慢在 TTFT 还是 TPOT?
  2. 瓶颈在 prefill、decode,还是排队?
  3. 显存压力来自权重还是 KV Cache?
  4. P99 是被长请求拖高,还是 batch 太激进?

第二步:做最小收益闭环

推荐组合:

  • 开启 KV Cache
  • 增加长度分桶
  • 限制最大上下文与最大生成长度
  • 加基础监控

这一套通常已经能解决大部分“能跑但不好用”的问题。

第三步:再上量化和连续批处理

前提是你已经有:

  • 稳定的 benchmark
  • 能分阶段观测延迟
  • 能比较不同配置下的 P50/P95/P99

否则优化结果很容易“看起来更快,实际上更不稳”。


总结

大模型推理优化,表面看是在调参数,实际上是在做一套系统工程。

可以把关键结论记成三句话:

  1. KV Cache 解决的是 decode 重复计算,但会把显存管理推到前台
  2. 量化首先解决“放不下”和“带宽贵”,其次才是纯提速
  3. 并发调度决定线上体验上限,尤其决定 P95/P99 是否可控

如果你希望得到一条比较稳妥的落地路径,我建议是:

  • 先分清 prefill 与 decode 的瓶颈
  • 再启用 KV Cache 并做容量估算
  • 然后通过量化释放显存空间
  • 最后用长度分桶与 Continuous Batching 提升整体吞吐

别一开始就追求“最先进”的架构。
在推理系统里,可解释、可观测、可回退,往往比“理论最快”更重要。

真正好的推理平台,不是 benchmark 上最漂亮的那一个,而是业务高峰来临时,仍然能稳稳接住流量的那一个。


分享到:

上一篇
《从 0 到 1:基于开源项目搭建企业级内部知识库的实战指南》
下一篇
《自动化测试中的稳定性治理实战:从脆弱用例定位到 CI 误报率下降策略》