sglang 源码学习笔记(二)- backend & forward 过程



作者丨进击的Bruce
来源丨https://zhuanlan.zhihu.com/p/18285771025
编辑丨GiantPandaLLM


前言

书接上回,Bruce 仗剑走天涯:sglang 源码学习笔记(一)- Cache、Req与Scheduler (https://zhuanlan.zhihu.com/p/17186885141)在上一篇文章中,我们介绍了sglang forward前的行为。本次我们详细解析forward 这个核心实现的全流程。

但首先我们回顾一下forward 的传递过程,也就是下面这张图。

从这里,我们可以看到关键的推理过程,batch 是怎么传导进backend的,kvcache 是在哪里分配的以及如何被使用,这里我们需要记住req_pool_indice和out_cache_loc 这两个关键参数,他们是backend访问kvcache(memory_pool)的关键。

结合社区画的全流程图,我想读者对于sglang 推理的大体流程栈,已经有了一定的认识。

一句话,scheduler 为每个到来的请求,分配req_to_token_pool的slot和token_to_kv_pool的slot,随后由于不同请求的prefix cache匹配情况和推理类型,设置cache的相关字段(prefix tokens,extend tokens,等等),这些信息跟随batch 流入attentionbackend进行真正的forward。

这里为了方便理解,我们选择最常用的组合(MHA,decode-only,autogressive decoding),backend 选择flashinfer。

这一部分的代码主要位于python/sglang/srt/layers(https://github.com/sgl-project/sglang/tree/b5fb4ef58a6bbe6c105d533b69e8e8bc2bf4fc3c/python/sglang/srt/layers)和flahsinfer(https://github.com/flashinfer-ai/flashinfer/tree/9f5fbee3230136b0ccf4a88938d0e244dcaf4b26)(没错,我要补上一篇留下的坑了)。

为了解释逻辑更加清晰,我会先讲attentionBackend的部分,再去讲cudaGraphRunner的部分,cudaGraph 本身是个异步overlap cpu的策略,与主体推理逻辑基本正交,而且确实比较复杂。

接下来进入正文。

AttentionBackend

forward 的推理堆栈上文已经提过,穿过ModelRunner 后,基本是如下的过程,其中attentionBackend是其中的关键角色。

ModelRunner->Model->layer->attentionBackend

通用数据结构

首先,我们仔细看看attentionBackend的数据结构。我们选择flashinferBackend 作为example。

# AttentionBackend 是所有具体backend 实现的基类, 这里可以观察到sglang 的attention对cuda graph 有比较强的依赖
# 这也是sglang overlap cpu & gpu excution的策略之一
class AttentionBackend(ABC):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
     def init_cuda_graph_state(self, max_bs: int):
         """Init the global shared states for cuda graph."""
     def init_forward_metadata_capture_cuda_graph(....):
         """Init the metadata for a forward pass for capturing a cuda graph."""
     def init_forward_metadata_replay_cuda_graph(....):
         """Init the metadata for a forward pass for replying a cuda graph."""
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" 
     def forward(...):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
        else:
            return self.forward_extend()

cuda graph的部分,我们下个部分再看,去掉cuda graph的部分,一个attentionBackend 需要至少两个接口:

init_forward_metadata与 forward。而这俩接口就是实现一个forward 最基本的接口,一个设置控制信息,一个进行推理。上篇我们介绍过一个decode forward函数,再看一遍。

def forward_decode(self, forward_batch: ForwardBatch):
        self.attn_backend.init_forward_metadata(forward_batch)
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

一目了然。介绍完了基类,具体类,比如flashinferBackend 是怎么样呢?

flashinfer 是什么

下图是flashinfer paper中给的系统框架图,实际上实现比这个复杂,但是我们依旧可以获得一些总体的组件。

左半部分主要是sglang中的实现,有关pagtable和radix tree读过上文有关cache和scheduler的部分,相信大家有所理解了,这里的pagetable 和 pageattention 比较像,但是实际上sglang的实现是token-table和token attention的方式。radix tree 就是用于prefix match以方便共享cache,减少计算量。最后是block-sparse,将kvcache 表示成稀疏的格式表达。这就是进推理前,flashinfer & sglang 对kvcache的操作。右边,我们主要关心runtime scheduler,其实主要是tiling 过程,将当前batch tile到最佳CTA 配置上执行。JIT compiler 用于客制化,不用JIT 也有prebuilt的ops。

好,接下来让我们深入实现细节,看看flashinfer 内部实现。

Wrapper 是什么

让我们先介绍wrapper。

看过flashinfer 代码的朋友都会留意到wrapper 这个结构,我们可以简单认为是子任务的调度器/执行器,是flashinfer 对底层的封装。先来两张图,这里简单说明了一个forward 中flashbackend的内部结构和互动关系。整体来说,wrapper 是flashinfer 对底层执行的封装,与forward mode 对应,分为decode和prefill wrapper 两大类。与传统的执行流程不同,flashinfer 将配置和执行分成了两个步骤,先通过init_forward_metadata 配置硬件需要的参数(k,v cache的indice,query的indice),而后调用forward 接口进行正式推理。注意这两个步骤目前没有明显的overlap,所以应该是工程上拆成的两份,目前唯有一些很轻的memcpy(host to device)在这里overlap 了起来。

flashinfer 的完整初始化流程

好,现在我们看看flashinferBackend 其他的成员变量。

        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
            num_attention_heads=model_runner.model_config.num_attention_heads
            // get_attention_tp_size(),
            num_kv_heads=model_runner.model_config.get_num_kv_heads(
                get_attention_tp_size()
            ),
        )

首先是判断decode 过程需不需要使用tensor core,这个判断基于环境变量,数据类型和qga group。判断如下:

  1. 如果用户设置使用tensor core,则返回true

  2. 如果是float8_e4m3fn或者float8_e5m2,返回true

  3. 如果是float16,half或者bfloat16,且gqa group size 超过4,返回true

  4. 其他情况,不使用tensor core

    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
    if env_override is not None:
        return env_override.lower() == "true"
    gqa_group_size = num_attention_heads // num_kv_heads

    # Determine based on dtype and GQA group size
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
        return gqa_group_size > 4
    else:
        return False

也就是说,默认情况下,对于一个fp16或者bf16的 MHA,sglang 只会用cudacore 进行计算。我猜测原因可能是因为MHA下decode 的token 实际只有1,每个head的计算量其实是相对偏小的。但是为什么fp8 的decode 适合tensor core算呢?从官方说明上看,e4m3以及e5m2 tensor core 明确支持,但cuda core 支持并不好,所以是可以理解的。BTW,从 zihao 那里得知,其实sglang 一开始cuda core 是按fp16/bf16 调的,后面模板没有维护,所以fp8 目前只支持tensorcore,另外从当前attention 层的计算规模看,后面主要会是tensor core 计算,所以也没有后面花精力调cuda core decoding的打算。

接下来,设置最大上下文长度,和所需wrapper 个数。

self.max_context_len = model_runner.model_config.context_len
self.num_wrappers = 1
self.dispatch_reason = None   # 在sliding window或者cross attention下,wrappers不同,但我们先不关注

接下来,初始化内部资源,首先是workspace,注意这里的dtype 加剧用于分配buffer,没有其他用处,实际上最后这里最后会用于存储一些计算的中间结果。默认的workspace buffer 是384M,在Qwen2ForCausalLM下是512M。

        # Allocate buffers
        self.workspace_buffer = torch.empty(
            global_config.flashinfer_workspace_size,
            dtype=torch.uint8,
            device=model_runner.device,
        )

########## in global config######################################
        self.flashinfer_workspace_size = os.environ.get(
            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
        )

随后是一些内部关键数组。其中kv_indptr与 qo_indptr是一个wrapper 一个,kv_last_page_len 是一个backend 一个。这些数据结构都是flashinfer 内不需要用的。

        max_bs = model_runner.req_to_token_pool.size
        self.kv_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
            for _ in range(self.num_wrappers)
        ]
        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )
        self.qo_indptr = [
            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
            for _ in range(self.num_wrappers)
        ]

其实从这里已经可以嗅到稀疏化的味道,眼疾的朋友,应该会联想到稀疏化的格式表达。indptr 通常数组用于指示每一行(或列)中非零元素的起始和结束位置。对于kv cache 来说,向gpu 刻画一个 稀疏矩阵,需要kv_indptr,kv_indices以及kv buffer。而对于query/output来说,由于query 是一个ragged tensor,所以就只需要indptr。至于这里为什么是max_bs+1,是由于indptr[0]是辅助用的,真正存储数据的是indptr [1:max_bs]。

kv_last_page_len 这个数组我们也需要解释一下。这里的page len的单位是token num,也就是一个page 存几个token 对应cache的意思,不过实际使用上,sglang的page size都是1,也就是一个page 一个token的cache。这里的page 和 实际物理页没有关系,只是一个内存的管理单元,仅仅在碎片管理的角度有意义。

接下来,创建核心子执行器wrapper, 由于我们这里只考虑decode-only selfattention的情况,所以如下的num_wrapper 为1。

        # Create wrappers
        # NOTE: we do not use ragged attention when there are multiple wrappers
        self.prefill_wrapper_ragged = (
            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
            if self.num_wrappers == 1
            else None
        )

        # Two wrappers: one for sliding window attention and one for full attention.
        # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
        self.prefill_wrappers_paged = []
        self.prefill_wrappers_verify = []
        self.decode_wrappers = []
        for _ in range(self.num_wrappers):
            self.prefill_wrappers_paged.append(
                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
            )
            self.prefill_wrappers_verify.append(
                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
            )
            self.decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
                    "NHD",
                    use_tensor_cores=self.decode_use_tensor_cores,
                )
            )

这里实际上包括四种wrapper,具体wrapper 内部的数据结构和初始化,我们放到最后看。

wrapper类别 wrapper 实现 场景
prefill_wrapper_ragged
BatchPrefillWithRaggedKVCacheWrapper
Prefill使用变长tensor(定长不够用)
prefill_wrappers_paged
BatchPrefillWithPagedKVCacheWrapper
定长tensor 的Prefill
prefill_wrappers_verify
BatchPrefillWithPagedKVCacheWrapper
投机推理用
decode_wrappers
BatchDecodeWithPagedKVCacheWrapper
decode 用

随后就是对wrapper 输入的更新器,这里只有两类,Prefill和Decode。

        # Create indices updater
        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
            model_runner, self
        )

这里没有什么可以特别说明的,updater的初始化,基本上就是把attentionbackend和model_runner 里的结构赋值过去。读者可以通过下面的例子感受一下,dispatch部分我们只看update_single_wrapper 即可。

        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

        # Dispatch the update function
        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self.update = self.update_sliding_window
        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
            self.update = self.update_cross_attention
        else:
            assert self.attn_backend.num_wrappers == 1
            self.update = self.update_single_wrapper

另外需要说明的是forward_metadata,forward metadata 和 updater 对应,包括两种,由于sglang scheduler 本身同时只会有一个forward batch,所以只需要一份forward metadata 即可。

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}

这里的forward_metadata 其实主要是Wrapper 和周围参数,每次forward 都会在metadata里指定wrapper。

@dataclassclassDecodeMetadata:
    decode_wrappers:List[BatchDecodeWithPagedKVCacheWrapper]@dataclassclassPrefillMetadata:
    prefill_wrappers:List[BatchPrefillWithPagedKVCacheWrapper]
    use_ragged:bool
    extend_no_prefix:bool

wrapper 的数据结构和初始化过程

首先是检查入参,kv layout 只支持NHD和HND。

        _check_kv_layout(kv_layout)

有关jit 的功能, 我们先不管。所以直接看内部新引入的资源,以下以BatchPrefillWithPagedKVCacheWrapper 为例。

        ## 前三行就是入参赋值,没啥好说的
        self._kv_layout = kv_layout
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        ## 这里的backend 更多程度上其实是底层实现用flashattention2还是flashattention3
        ## auto 即自动识别,如果当前硬件支持fa3则用fa3,否则用fa2,而对于fa3
        ## vector_sparse 可以理解为一种中间形式的稀疏表达,之所以fa3 需要存储这个中间数组
        ## 是由于考虑到fa3下将vector_sparse 存放于GPU寄存器的话,寄存器不够
        if backend in ["fa3", "auto"]:
            # NOTE(Zihao): assume maximum accumulate kv length is 16M
            self._vector_sparse_indices_buffer = torch.empty(
                (16 * 1024 * 1024,), dtype=torch.int32, device=self.device
            )
            # NOTE(Zihao): assume maximum batch size is 32768
            self._vector_sparse_indptr_buffer = torch.empty(
                (32768,), dtype=torch.int32, device=self.device
            )

        ## kv_lens_buffer 实际是请求对应的kv cache len的长度,单位也是tokens num
        self._kv_lens_buffer = torch.empty(
            (32768,), dtype=torch.int32, device=self.device
        )
        
        ## 如下两个都是存储控制信息的buffer,区别在于_int_workspace_buffer是device 侧的buffer
        ## _pin_memory_int_workspace_buffer是host 侧的buffer,二者通过cudaMemcpyAsync 互相交互
        ## 后面我们会看到这个结构就是init_forward_metadata的核心。
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            self._int_workspace_buffer.shape,
            dtype=self._int_workspace_buffer.dtype,
            device="cpu",
            pin_memory=True,
        )

以上的资源属于wrapper 内部的核心资源,最后就是将attentionbackend 引用进来方便访问。

        self._qo_indptr_buf = qo_indptr_buf
        self._paged_kv_indptr_buf = paged_kv_indptr_buf
        self._paged_kv_indices_buf = paged_kv_indices_buf
        self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
        self._backend = backend

        ## 以下主要是cuda graph使用参数,我们放到cuda graph 模式讲。
        self._custom_mask_buf = custom_mask_buf
        self._mask_indptr_buf = mask_indptr_buf
        self._max_total_num_rows = None

而 decode wrapper 也比较类似,但是没有了其中一部分,注意我们这里也暂时屏蔽了cuda graph相关的实现。

可以看到,decode 下,query 相关的结构不见了(_qo_indptr_buf与_kv_lens_buffer),对于decode,都是one-by-one 的输出,query相关的内容本身也已经在gpu cache上,不需要额外传入(但是cuda graph 模式下也有query 相关结构,我们再解析)。另外,由于fa3 主要是在fa2 基础上加了seq parallel,只影响prefill,所以decode 这边不需要vector_sparse 这个中间层的buffer。

        self._kv_layout = kv_layout
        self._float_workspace_buffer = float_workspace_buffer
        self.device = float_workspace_buffer.device
        self._int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device
        )
        self._pin_memory_int_workspace_buffer = torch.empty(
            (8 * 1024 * 1024,),
            dtype=torch.uint8,
            pin_memory=True,
            device="cpu",
        )
        self._fixed_batch_size = 0
        self._paged_kv_indptr_buf = paged_kv_indptr_buffer
        self._paged_kv_indices_buf = paged_kv_indices_buffer
        self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer
        self._use_tensor_cores = use_tensor_cores
        self._use_cuda_graph = use_cuda_graph

本节最后,简单说明一下wrapper 的两个主要接口的功能。

1.plan forward 过程的第一步,将控制信息写入gpu

2. run forward过程第二步,从model 里forward会获得更新后的kvcache(有些forward 会在model 层进行kvcache的更新,比如deepseek),此时调用wrapper的run进行low-level 的run。

下图就是forward 过程中各结构接口的调用关系。

Plan Info 是什么

另外,我们还需要介绍一下plan info。这是wrapper的核心数据结构之一,属于运行的配置信息,plan info即flashinfer 计划为本次forward 提供的配置信息。数据结构如下:

struct PrefillPlanInfo {
  int64_t padded_batch_size;                    # batch size, 和forward batchsize有一丝区别
                                                # forwardbatch的batchsize 是当前batch里请求的个数
                                                # padded_batch_size 可能比forward size大,它面向GPU CTA
                                                # 每个CTA 需要计算的tile根据请求情况获得,padding_batchsize 根据tile 计算
  int64_t total_num_rows;                       # 当前batch 处理总输入token长度, 对应sum(qo_indptr)
  int64_t total_num_rows_offset;                # 对应qo_indptr的数组指针
  int64_t cta_tile_q;                           # 一个CTA 负责处理的query 长度,即tile 后的query 长度, 下面假设query 0 input 被tile成三个tile
  int64_t request_indices_offset;               # tile 后的request index 数组指针,like [0, 0, 0]
  int64_t qo_tile_indices_offset;               # tile 后的query index 数组指针,like 请求0 被tile 成三份,like[0, 1, 2]
  int64_t kv_tile_indices_offset;               # tile 后的kv index 数组指针,如果kv chunksize > need_kv_len, 则为[0, 0, 0]
  int64_t merge_indptr_offset;                  # merge indptr 与tile 无关,是与请求和gqa有关的
                                                # 如果模型的group size为4,则一个请求对应四个merge_indptr项, 比如[100,200,300,400]
  int64_t o_indptr_offset;                      # 一个请求一个,值为对齐到tile_kv_len * group_size(mha下为1)
  int64_t kv_chunk_size_ptr_offset;             # 这里也是个数组指针,但是数组size 为1,内容就是kv_chunk_size
  int64_t v_offset;                             # attention中间态计算结果,s_ = q*k, v_ = softmax(s_)*v
  int64_t s_offset;                             
  int64_t block_valid_mask_offset;              # 数组指针,数组内容是根据tile 分片后的block是不是有效的
                                                # 一般都是有效,但在cudagraph的使用下,会有不对齐的情况,以后再说
  bool enable_cuda_graph;                       # 使用cuda graph
  bool split_kv;                                # 是否进行了分片
}

plan info 实在是面向gpu的核心数据结构,这里才有了我们以往耳熟能详的tiling 过程。现在我们可以继续补充forward batch的流程图,forward batch 走进wrapper 里就是plan info了。

init_forward_meta 与 plan/prefill 举例

理解了整个初始化和主要的数据结构,接下来我们可以看看init_forward_meta的过程了。这里的核心就是wrapper的plan 接口。如下我列出了其中prefill 的branch的case(不考虑encoder-decoder 和 sliding window的实现)。

        def init_forward_metadata(self, forward_batch: ForwardBatch):
            prefix_lens = forward_batch.extend_prefix_lens
            # Some heuristics to check whether to use ragged forward
            # 如果有prefill token 太长的情况,使用ragged tensor
            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                use_ragged = True
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
            else:
                use_ragged = False
                extend_no_prefix = False 
            # 通过updater 更新 prefill wrapper
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrappers=self.prefill_wrappers_paged,
                use_ragged=use_ragged,
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
            )
            ## 更新forward_metadata
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
            )

这里ForwardBatch的字段内容可以参考Bruce 仗剑走天涯:sglang 源码学习笔记(一)- Cache、Req与Scheduler(https://zhuanlan.zhihu.com/p/17186885141)里的说明。其中最重要的是update接口, 这里最终会调用到wrapper的plan,栈如下。实际上begin_forward 就是 plan,指针是同一个。

FlashInferAttnBackend.init_forward_metadata->
    FlashInferIndicesUpdaterPrefill.update_single_wrapper->
        FlashInferIndicesUpdaterPrefill.call_begin_forward->
            BatchPrefillWithPagedKVCacheWrapper.begin_forward->
                BatchPrefillWithPagedKVCacheWrapper.plan

我们主要讲两个函数的实现,call_begin_forward与plan。

call_begin_forward

def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,  # ragged wrapper, 用于输入较长的情况
        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,    # 主要使用的wrapper
        req_pool_indices: torch.Tensor,                        # batch中包含的请求index
        paged_kernel_lens: torch.Tensor,                       # 请求的长度(对ragged 情况下,对应extend_prefix_lens)
        paged_kernel_lens_sum: int,                            # paged_kernel_lens 之和
        seq_lens: torch.Tensor,                                # 请求的完整长度
        prefix_lens: torch.Tensor,                             # extend_prefix_lens
        kv_start_idx: torch.Tensor,                            # 传参为一般None,实际上指各请求kv cache的起始index 
        kv_indptr: torch.Tensor,                               # attention backend 的kv 数组
        qo_indptr: torch.Tensor,                               # attention backend 的qo 数组
        use_ragged: bool,                                      # 是否使用ragged tensor用于query
        spec_info: Optional[SpecInfo],                         # 是否是投机推理
    ):

如上是call_begin_forward的传参,方便大家理解一些上下文。接下来是具体的实现。

        ## 获得当前batch的batchsize
        bs = len(req_pool_indices)
        if spec_info is None:
            # Normal extend
            # indptr 意为矩阵中每行非零值的起始位置,以下说明了每个请求的输出token存kv cache的位置
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            # 这里将分配一个数组,具体赋值在create_flashinfer_kv_indices_triton中
            kv_indices = torch.empty(
                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
            )
            ## 注意,这是一个并行函数,并行度是bs,如下调用说明同时起了bs个trtion 内核执行,入参都一样
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                kv_start_idx,
                kv_indices,
                self.req_to_token.shape[1],
            )
            ## 以下说明了每个请求的输出token的起始位置
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        ## end_forward 接口已被废弃,可以忽略
        wrapper_paged.end_forward()
        ## begin_forward == plan接口
        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            self.kv_last_page_len[:bs],
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            q_data_type=self.q_data_type,
            custom_mask=custom_mask,
        )

begin forward 稍后分析,我们先看看create_flashinfer_kv_indices_triton,首先这是一个并行执行的函数,依赖triton的jit 。上述我们看到了调用方的调用代码,注意[bs,] 这个部分,这表明在0轴上起了bs 个triton 内核执行该函数,bs的并行度 最终体现在函数内,就是tl.program_id(axis=0)的返回值,该返回值为[0, bs-1],所以以下其实是对入参数组的并行访问。

@triton.jitdefcreate_flashinfer_kv_indices_triton(
    req_to_token_ptr,# [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride:tl.constexpr,):
    BLOCK_SIZE:tl.constexpr=512
    pid=tl.program_id(axis=0)

    # batch 中第N个请求的req_pool_index和kv_indices_offset
    req_pool_index=tl.load(req_pool_indices_ptr+pid)
    kv_indices_offset=tl.load(kv_indptr+pid)

    kv_start=0
    kv_end=0
    ifkv_start_idx:
        kv_start=tl.load(kv_start_idx+pid).to(tl.int32)
        kv_end=kv_start
    ## 获得请求的kvcache 长度
    kv_end+=tl.load(page_kernel_lens_ptr+pid).to(tl.int32)

    num_loop=tl.cdiv(kv_end-kv_start,BLOCK_SIZE)
    foriinrange(num_loop):
        # block_size 又是一个并行度,意在加速load,store的并行效率
        # offset 返回的是一个BLOCK_SIZE 维度的array
        offset=tl.arange(0,BLOCK_SIZE)+i*BLOCK_SIZE
        mask=offset<kv_end-kv_start
        # 并行读取req_to_token pool 中req 对应的token 索引
        data=tl.load(
            req_to_token_ptr
            +req_pool_index*req_to_token_ptr_stride
            +kv_start
            +offset,
            mask=mask,
        )
        ## 并行写入kv indices数组
        ## 注意,这里kv_indices是临时结构,和token_to_kv_pool 没有关系,但最终会作为wrapper的入参
        tl.store(kv_indices_ptr+kv_indices_offset+offset,data,mask=mask)

backend CacheModule 是什么

在理解wrapper->plan 的调用链之前,我们先看看wrapper 里的核心结构————cache module。

具体上说,cache module 就是wrapper 真正的核心,是cpp 入口结构。它是被延后初始化的,因为主要是接口抽象类,不是实际资源,所以延后初始化也可以接受。cache module 被构建的时机是plan 接口调用时,根据backend的值和硬件当前情况,再次进行一次backend的判定,并根据backend的判定情况获取相应的cache module和挂载相应的具体接口。

首先是判断要不要使用flashattention3.

            if self._backend == "auto":
                self._backend = determine_attention_backend(
                    self.device,
                    PosEncodingMode[pos_encoding_mode].value,
                    use_fp16_qk_reduction,
                    self._custom_mask_buf is not None,  # use_custom_mask
                    q_data_type,
                    kv_data_type,
                )

当前的flahinfer 对fa3的支持不太好,比如fp8就不支持。当前flashinfer 仅在如下条件都满足的情况下才会使用fa3.

H系列的硬件,cuda 版本大于12.3,不使用fp8(q和kv),不使用query/key的reduction

随后获得真正执行的cache module。

           get_module_args = (
                q_data_type,
                kv_data_type,
                q_data_type,
                kv_indptr.dtype,
                head_dim_qk,
                head_dim_vo,
                PosEncodingMode[pos_encoding_mode].value,
                window_left >= 0,  # use_sliding_window
                logits_soft_cap > 0,  # use_logits_soft_cap
                use_fp16_qk_reduction,
            )
            self._cached_module = get_batch_prefill_module(self._backend)(
                *get_module_args
            )

由于flashinfer 还允许用户自定义实现接口和JIT,所以get_batch_prefill_module比较复杂,为了简化分析,我们只看AOT的情况(ahead-of-time compiling,也就是非JIT的情况),AOT下用的是默认已经编译好的接口,AOT的相关框架逻辑也比较复杂,如果感兴趣,可以自行阅读setup.py(https://github.com/flashinfer-ai/flashinfer/blob/95691060444252bc48c7338e9fe629e8c2fd4343/setup.py),这里可以看到AOT的prebuilt ops 相关的逻辑。但简而言之,我们只需要拿来用就可以了。_kernels 是一个内部生成的torch extention,具体逻辑我们先忽略,如下我们可以看到根据不同的backend,选择了不同的function 接口。

####def get_batch_prefill_module(backend)
               if backend == "fa2":
                    from . import _kernels

                    plan_func = _kernels.batch_prefill_with_kv_cache_plan
                    ragged_run_func = _kernels.batch_prefill_with_ragged_kv_cache_run
                    paged_run_func = _kernels.batch_prefill_with_paged_kv_cache_run
                else:
                    from . import _kernels_sm90

                    plan_func = _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan
                    ragged_run_func = (
                        _kernels_sm90.batch_prefill_with_ragged_kv_cache_sm90_run
                    )
                    paged_run_func = (
                        _kernels_sm90.batch_prefill_with_paged_kv_cache_sm90_run
                    )

随后这些接口会被注入到模块里,并返回给wrapper。如果大家真的阅读这部分代码,还会观察到一些中间接口的注册逻辑,但是由于和主功能无关,我们这里就不说明了,为了篇幅考虑。

##def get_batch_prefill_module(backend)
        modules_dict[args] = SimpleNamespace(
                plan=plan_func,
                ragged_run=ragged_run,
                paged_run=paged_run,
            )
        return modules_dict[args]

此时,我们真正获得了prefill的模块接口函数。

wrapper->plan的调用链分析

一切数据准备好了,现在可以进行真正的plan了。plan 接口的参数很多,但实际用到的不多,我们以实际用到的进行分析。

def plan(
        self,
        qo_indptr: torch.Tensor,                               # 每一个请求输出的起始位置的数组
        paged_kv_indptr: torch.Tensor,                         # 下面三个就是每个请求对应token kv cache的稀疏表达
        paged_kv_indices: torch.Tensor,
        paged_kv_last_page_len: torch.Tensor,
        num_qo_heads: int,                                     # attention 里有几个query head,kv head
        num_kv_heads: int,
        head_dim_qk: int,                                      # 每个head的dim
        page_size: int,                                        # 每个page 包含几个token,实际上sglang 都只传过1
        causal: bool = False,                                  # 是否使用causal mask,默认不使用
        use_fp16_qk_reduction: bool = False,                   # 是否裁剪qk 维度,默认不使用
        window_left: int = -1,                                 # 如果是window attention,则需要设置,-1 说明是full attention
        logits_soft_cap,                                       # 控制注意力机制中的 logits 值的上限,部分模型会用比如gemni,grok, 默认设置0
        q_data_type: Union[str, torch.dtype] = "float16",      # query的数据类型,默认是fp16
        non_blocking: bool = False):                           # 是否采用非阻塞形式,默认不使用

主要几个参数,是从上游传下来的,可能有读者会疑惑,比如head_dim 为什么只有qk,vo的呢?以及data_type,这里其实接口也允许传,但是如果接口不提供,会对齐到qk的值,如下。

        q_data_type = canonicalize_torch_dtype(q_data_type)
        if kv_data_type is None:
            kv_data_type = q_data_type
        kv_data_type = canonicalize_torch_dtype(kv_data_type)
        if head_dim_vo is None:
            head_dim_vo = head_dim_qk

随后,需要准备一些中间数据结构,主要作用是赋值,从接口参数到wrapper内部成员参数。

        batch_size = len(qo_indptr) - 1
        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
        qo_indptr_host = qo_indptr.to("cpu")
        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
        paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu")
        kv_lens_arr_host = get_seq_lens(
            paged_kv_indptr_host, paged_kv_last_page_len_host, page_size
        )
        self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_(
            kv_lens_arr_host, non_blocking=non_blocking
        )

        total_num_rows = qo_indptr_host[-1]
        self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
        self._paged_kv_indptr_buf = paged_kv_indptr.to(
                self.device, non_blocking=non_blocking
            )
        self._paged_kv_indices_buf = paged_kv_indices.to(
                self.device, non_blocking=non_blocking
            )
        self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
                self.device, non_blocking=non_blocking
            )
        self._cached_q_data_type = q_data_type
        self._cached_kv_data_type = kv_data_type

再接下来是获取当前具体应该调用的后端模块,也就是cache_module,其初始化逻辑可看上文对cache_module的介绍。

获得cache module后,也就获得了核心接口(比如plan,run等等),现在我们可以调用plan 接口了。调用链如下(我们选择flashattention2), 具体接口函数实现大家可以自己看,主要是功能逻辑。

_kernels.batch_prefill_with_kv_cache_plan(py)->
     BatchPrefillWithKVCachePlan(cu)->
         PrefillPlan

这里的核心在于PrefillPlan 这个函数,BatchPrefillWithKVCachePlan只是外部包了一层接口。大体逻辑如下, 所以核心在于PrefillPlan 这个函数,核心数据结构是plan_info。

  PrefillPlanInfo plan_info;
  cudaError_t status = PrefillPlan<IdType>(...., plan_info, ...)
  return plan_info.ToVector()

planinfo 是将forward batch的配置信息进行tiling 化,从一个请求维度的数组,变出CTA 维度的数组的过程,这个数据结构本身我们上面已经说明过了。所以我们直接看PrefillPlan 即可。为了简化逻辑,以下只展示最核心的代码。

### 1. 获得硬件配置
  int num_sm = 0;
  int dev_id = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
  int num_blocks_per_sm = 2;
  int max_grid_size = num_blocks_per_sm * num_sm;
  ## 注意这里的max_batch_size_if_split 即指如果进行tile,可以有多少batch
  ## 比如hopper 有144个sm,则 max_grid_size 为288,kv heads 如果有64个,则prefill 在CTA的batch 最多为4
  ## 依旧注意这是prefill 的max batch,decode 并不相同
  ## 从此我们可以看到prefill的batch-size 支持在flashinfer 这里也是比较小的
  uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;

### 2. 进行tiling,获得当前batch 要求下,推理侧CTA 最优的配置
### 这里的new_batch_size 就是tiling 化后的batchsize大小,new_batch_size >= batch_size
  auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
        qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
                             num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split,
                             enable_cuda_graph);

### 3. 把配置写入plan info,或者通过plan info 可访问,以及copy到paged_lock_int_buffer(host buffer)中
    // 这里一堆赋值

### 4. 把host buffer里的配置信息,copy到device buffer使gpu 可见
  size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
  FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
                                       cudaMemcpyHostToDevice, stream));

由于tiling 是一个相对重要的过程,所以我们这里也展示一下PrefillSplitQOKVIndptr的核心逻辑。

## 获得一些配置参数,比如gqa_group_size,用于展开以对齐计算量
const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
## 由于sglang 使用 pagesize 恒为1,所以可以理解为就是min_kv_chunk_size = 128
const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U);

## 对齐计算量, 注意这里packed_qo_len_arr的维度是请求维度
## 之所以乘gqa_group_size,是由于每个kv header 面对的query 长度是原本长度的gqa_group_size倍
std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
for (uint32_t i = 0; i < batch_size; ++i) {
    packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
    kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]);
}

#### 根据query len 进行tile(注意这里是非cuda graph的实现)
# 获取需要计算的总tokens 数量
int64_t sum_packed_qo_len = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
    sum_packed_qo_len += packed_qo_len_arr[i];
}
# 获得平均每请求计算token 计算量
const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
# 获得一个tile 合适处理的计算量,这里FA2DetermineCtaTileQ里的策略感觉是根据经验选的tile
cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);

# 计算总共需要多少tile
total_num_tiles_q = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
   total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
}

## 寻找合适kv chunk size
## 其实也是tiling 过程,综合考虑kv len的tiling 过程
# 举例1:如果我们只有一个请求,batchsize 为1,假设请求长度1000,kv_head 64,head_dim 128
# 此时cta_tile_q = 128,min_kv_chunk_size=128,max_batch_size_if_split=4
# 不考虑cuda graph,要求最终batchsize(tiles_q * tiles_kv) < max_batch_size_if_split, 
# 由于请求长度较长,tiles_q 都已经超过了4(1000/128),所以不能对kv做chunk,split_kv为false,kv_chunk_size为1000
# 举例2: 如果其他不变,请求只有256,则kv_chunk_size 可以为128。
# 另外,以上都是不开cudagraph的例子,如果打开了cudagraph,则可能在例子1 下依旧做kv cache的tiling
auto [split_kv, kv_chunk_size] =
      PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr,
                                     kv_len_arr, cta_tile_q, min_kv_chunk_size);

## 根据tiling的参数,重构配置信息的相关数组
## 以下核心数据结构的含义在planinfo 有过说明,可以自行脑补,核心逻辑是将请求维度的数组变出CTA维度的数组
## 最后这些核心数组都会返回,成为planinfo的一部分
uint32_t new_batch_size = 0;
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    const int64_t packed_qo_len = packed_qo_len_arr[request_idx];
    const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1);
    const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q);
    const int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size);

    for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) {
      for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) {
        new_batch_size += 1;
        request_indices.push_back(request_idx);
        qo_tile_indices.push_back(q_tile_idx);
        kv_tile_indices.push_back(kv_tile_idx);
      }
    }

    int64_t qo_len = packed_qo_len / gqa_group_size;
    for (uint32_t row = 0; row < qo_len; ++row) {
      merge_indptr.push_back(merge_indptr.back() + num_tiles_kv);
    }
    o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv);
}

forward 整体过程说明

ok,我们终于来到了forward 过程(不容易啊不容易),现在只需要等model forward 过程中根据不同模型的情况更新完kvcache的信息,我们就可以拿着kvcache 去算logits和token了。

让我们再回头看看如下实现.

def forward_extend(self, forward_batch: ForwardBatch):
     self.attn_backend.init_forward_metadata(forward_batch)
     if self.is_generation:
         if forward_batch.input_embeds is None:
             return self.model.forward(
                    forward_batch.input_ids, forward_batch.positions, forward_batch
             )

先前说过,forward的过程是先init_forward_metadata ,然后后forward。但我们还需注意到init_forward_metadata 对一次完整的模型推理而言是一次性的,model forward 也是一次性的,一次model forward包括所有layer的forward,但是只有调用model 中attention layer的forward 才会真正进入到attentionbackend 的forward。这也是为什么叫做attentionbackend的原因,换言之,flashinfer 只实现了对attention的计算,不包括mlp等layer的情况。可以通过下图理解。

图中的planinfo 是由init_forward_metadata 过程中,由wrapper plan 获得的配置信息,作为当前batch执行的plan。

不过,有同学会问,难道sglang 只加速attention么?回答是:否。

细心的同学都会注意到sglang repo里有一个显眼的sgl-kernel的目录,这里通过torch extention的方式扩展了一些算子实现,比如RMSNorm,fused_moe 等等,而sglang的MLP 层其实默认是从这些extention的实现执行的。

https://github.com/sgl-project/sglang/blob/642ab418f31ead20c7ebd53516864bc9549d7217/sgl-kernel/csrc/torch_extension.ccl/csrc/torch_extension.cc

有兴趣的同学可以看看这里的算子,并追踪调用栈。

forward 实施细节-prefill 举例

一言不合,直接看代码。flashinfer_attentionbackend 里有外层骨架,比较清晰,简单过一下这部分。

    def forward_extend(
        self,                 
        q: torch.Tensor,                # query,key,value
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,          # 必须是attention,radixAttention 是官方attention的基类
        forward_batch: ForwardBatch,    # forward batch 相关信息,具体字段请回顾上篇文章
        save_kv_cache=True,             # 默认save kv cache
    ):
        # 首先要获得本次forward 对应处理的wrapper
        # _get_wrapper_idx仅仅在encoder-decoder/sliding window架构下有用,self-attention 下直接返回0
        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
            self._get_wrapper_idx(layer)
        ]
        # 本次forward的输出token所在数组索引, if 过程可以忽略,selfattention下==forward_batch.out_cache_loc
        cache_loc = (
            forward_batch.out_cache_loc
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        # logits 限制
        logits_soft_cap = layer.logit_cap

        if not self.forward_metadata.use_ragged:
            # 看有没有传kvcache, 如果model 传了kv cache,则在此设置进token_to_kv_pool
            # k_scale 和 v_scale 默认为None,对cache 内容进行整除,一般不用
            if k is not None:
                assert v is not None
                if save_kv_cache:
                    forward_batch.token_to_kv_pool.set_kv_buffer(
                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                    )

            # 核心环节,数据到位,准备forward,开炫, 以下传参是参考gpt2
            o = prefill_wrapper_paged.forward(
                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),  # query/tp 分割后
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),  # kv cache
                causal=not layer.is_cross_attention,                     # True
                sm_scale=layer.scaling,                                  # softmax 用的scaling
                window_left=layer.sliding_window_size,                   # -1
                logits_soft_cap=logits_soft_cap,                         # 0
                k_scale=layer.k_scale,                                   # None
                v_scale=layer.v_scale,                                   # None
            )

这里的forward 调用流程栈如下,挨个分析。

prefill_wrapper_paged.forward->
    prefill_wrapper_paged.run->
        cache_module.paged_run->
            _kernels.batch_prefill_with_paged_kv_cache_run->
                BatchPrefillWithPagedKVCacheRun->
                     BatchPrefillWithPagedKVCacheDispatched

prefill_wrapper_paged.run->BatchPrefillWithPagedKVCacheRun

这里主要是设置运行参数和检查。

### 1. 检查kvcache的数据类型和layout
        k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout)
        _check_cached_qkv_data_type(
            q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
        )
        stride_block = k_cache.stride(0)
        if self._kv_layout == "NHD":
            page_size = k_cache.shape[1]
            stride_n = k_cache.stride(1)
        else:
            page_size = k_cache.shape[2]
            stride_n = k_cache.stride(2)
### 2. 设置默认运行参数,这部分比较冗余,我们仅写一些gpt2 下的参数
      # window_left = -1
      # logits_soft_cap = 0.0
      # kv layout type
      # self._kv_layout.value
      # sm_scale = head_dim ** -0.5
      # rope_scale = 1.0
      # rope_theta = 1e4
      # return_lse = False
      ## 使用casual mask
      # mask_mode = MaskMode.CAUSAL.value
      ## kvcache的稀疏表示,fa2 运行中不会修改该表示,但是fa3 会需要修改
      # sparse_indices = self._paged_kv_indices_buf
      # sparse_indptr = self._paged_kv_indptr_buf
      # self._qo_indptr_buf
      ## 输出,这里的shape 计算,其实就是softmax(q*k)*v 的结果的维度,只是简化了过程
      # out = torch.empty(q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device)
      ## 以下是wrapper 内部配置
      # self._float_workspace_buffer,
      # self._int_workspace_buffer,
      # self._plan_info,
      ## 以下是qkv
      # q,
      # k_cache,
      # v_cache,
### 3. 进paged_run 跑kernel,返回输出即可
      self._cached_module.paged_run(*run_args)
      return out

接下来是时候cpp 上场了。BatchPrefillWithPagedKVCacheRun 这个函数其实逻辑也比较接近配参数,不过这个函数里使用了很多pytorch dispatch的宏进行逻辑分派,看上去有点不太易懂(如果不太熟悉pytorch的dispatch 宏的话,比如我)。但没有关系,我们抓住核心,设置PagedParams。基本数据结构内容都在run_args和plainfo 出现过,我们只介绍未出现的字段。

struct PagedParams {
  using DTypeQ = DTypeQ;
  using DTypeKV = DTypeKV;
  using DTypeO = DTypeO;
  using IdType = IdType;

  DTypeQ* q;
  paged_kv_t<DTypeKV, IdType> paged_kv; #封装了kcache和vcache
  IdType* q_indptr;
  DTypeO* o;
  float* lse;
  uint_fastdiv group_size;     # 即gqa group size

  {{ additional_params_decl }}
  uint32_t num_qo_heads;       # 有多少head (h)
  IdType q_stride_n;           # 一个batch 包含多少size nbytes(h*d)
  IdType q_stride_h;           # query 一个head 包含多少dim(d)
  int32_t window_left;

  IdType* request_indices;
  IdType* qo_tile_indices;
  IdType* kv_tile_indices;
  IdType* merge_indptr;
  IdType* o_indptr;
  bool* block_valid_mask;
  IdType* kv_chunk_size_ptr;
  uint32_t max_total_num_rows;
  uint32_t* total_num_rows;
  uint32_t padded_batch_size;
  bool partition_kv;           # 即split kv
}

BatchPrefillWithPagedKVCacheDispatched

进入这一部分,forward 真正进入面向硬件的部分。

首先计算需要的硬件资源,包括warps和mma , 寄存器和mem。

  ## CTA_TILE_Q > 64 , 为2,否则为1;MMA 矩阵运算能力是16,这里64 考虑4个warps的共同处理能力
  ## flashinfer fa2 模板考虑4个warps
  constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
  ## NUM_WARPS_Q * NUM_WARPS_KV == 4
  constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
  constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);

  # 注意dim的分割,后续算子内的计算和这个相关,block 维度按batch和head 区分
  # threads 内 y和z 分别指NUM_WARPS_Q和NUM_WARPS_KV的粒度
  dim3 nblks(padded_batch_size, 1, num_kv_heads);
  dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV);

  # 这里计算逻辑是NUM_MMA_x 是 HEAD_DIM_x 的1/16, 这里我理解就是mma的计算尺度一般是16*16*16
  constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
  constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;

  # 获得max_smem_per_sm
  int dev_id = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  int max_smem_per_sm = 0;
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
                                              cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
  // we expect each sm execute two threadblocks
  // TODO(Zihao): fix the following computation
  # 计算一个sm 可以用多少CTA(threadblocks),根据sharemem 计算
  # (num_cta_per_sm * (num_warps_q + nums_warps_kv)) * HEAD_DIM_QK * sizeof(DTypeQ) * MMA_SIZE)
  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ) * 16) ? 2 : 1;
  # 一个threadblock 需要多少sharemem
  const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
  # 根据reg 数量,计算最多可以有多少mma 用于kv
  const uint32_t max_num_mma_kv_reg =
      (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
       !USE_FP16_QK_REDUCTION)
          ? 2
          : (8 / NUM_MMA_Q);
  // TODO(Zihao): fix the following computation
  # 根据sharemem,计算最多可以有多少mma 用于kv
  const uint32_t max_num_mma_kv_smem =
      (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) /
      (2 * NUM_WARPS_KV);

  ## 双重约束 获得最终用于kv的mma数量
  NUM_MMA_KV = min(max_num_mma_kv_smem, max_num_mma_kv_reg)

硬件参数获得后,抽象成一个数据结构(使用cpp 模板生产)KernelTraits,这名字很直接(cuda kernel 需要用的硬件traits)。并根据计算结果,设置cudakernel 运行过程中可用空间。

  using KTraits = KernelTraits<MASK_MODE, CTA_TILE_Q, NUM_MMA_Q, NUM_MMA_KV, NUM_MMA_D_QK, NUM_MMA_D_VO,
                     NUM_WARPS_Q, NUM_WARPS_KV, POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeO,
                     DTypeQKAccum, typename Params::IdType, AttentionVariant>;

  # 根据计算得到需要sharemem,设置kernel 可用mem 空间
  size_t smem_size = sizeof(typename KTraits::SharedStorage);
  auto kernel = BatchPrefillWithPagedKVCacheKernel<KTraits, Params>;
  FLASHINFER_CUDA_CALL(
          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

kernel launch ,为了简化篇幅,我们采用最简单的逻辑,不考虑spliting kv的情况。

  # 直接launch 就好,参数先前我都介绍过了
  if (tmp_v == nullptr) {
      // do not partition kv
      params.partition_kv = false;
      void* args[] = {(void*)&params};
      FLASHINFER_CUDA_CALL(
        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  }

最后稍微介绍一下kernelTraits 这个结构,这其实是一个cpp 模版,根据传参生成结构。大部分都是模版参数直接赋值,但其中有个子结构SharedStorage,我们需要稍微说明一下,会涉及后续的一些sharemem计算。

//shareStorage 本身也是模版函数过来的结构,调用如下
//其中CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16 ,该过程在KernelTraits的模版函数中
using SharedStorage = SharedStorageQKVO<NUM_WARPS_KV, CTA_TILE_Q, CTA_TILE_KV, HEAD_DIM_QK,
                                          HEAD_DIM_VO, DTypeQ, DTypeKV, DTypeO>;

=====================================模版函数如下========================
// 这里是个union,不过我们的场景只需要考虑第一个struct。
template <uint32_t NUM_WARPS_KV, uint32_t CTA_TILE_Q, uint32_t CTA_TILE_KV, uint32_t HEAD_DIM_QK,
          uint32_t HEAD_DIM_VO, typename DTypeQ, typename DTypeKV, typename DTypeO>
struct SharedStorageQKVO {
  union {
    struct {
      alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK];
      alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK];
      alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO];
    };
    struct {  // NOTE(Zihao): synchronize attention states across warps
      alignas(
          16) std::conditional_t<NUM_WARPS_KV == 1, float[1],
                                 float[NUM_WARPS_KV * CTA_TILE_Q * HEAD_DIM_VO]> cta_sync_o_smem;
      alignas(16) std::conditional_t<NUM_WARPS_KV == 1, float2[1],
                                     float2[NUM_WARPS_KV * CTA_TILE_Q]> cta_sync_md_smem;
    };
    alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO];
  };
};

=============== 还有一个相关结构是smem_t============================================
该结构我们无需贴代码,大家只需要知道这个结构存在的意义是为了抽象对sharemem的读写行为,并进行加速
上述结构关心的是资源分配的问题,smem_t 结构关心的是读写效率问题
=============== 最后是有关layout的管理==========================================
static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B;
static constexpr SwizzleMode SWIZZLE_MODE_KV =
      (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B;
static constexpr uint32_t KV_THR_LAYOUT_ROW = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 8;
static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4;

BatchPrefillWithPagedKVCacheKernel

最后,我们一起看看kernel的实现吧。相关源码比较多(1000行+),为了简化,我们依旧抠核心逻辑。

在此前我推荐一下 DefTruth:[Attention优化][2w字] 原理篇: 从Online-Softmax到FlashAttention V1/V2/V3(https://zhuanlan.zhihu.com/p/668888063) 和 zihao 的 http

https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf


/courses/cse599m/23sp/notes/flashattn.pdf

    auto block = cg::this_thread_block();
    const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);
    // 理解了上面block 和 thread dim的划分,我们应该可以比较容易理解这一部分,bx即batchidx
    const uint32_t bx = blockIdx.x, lane_idx = threadIdx.x, warp_idx = get_warp_idx<KTraits>(),
                   kv_head_idx = blockIdx.z;
    // group_size 即 gqa 参数
    const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads;
    const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx],
                   kv_tile_idx = kv_tile_indices[bx];
    // 上述可以理解为获取入参
    extern __shared__ uint8_t smem[];
    auto& smem_storage = reinterpret_cast<typename KTraits::SharedStorage&>(smem);
    // 可以忽略这个变量,基本就是params的copy,只是方便管理罢了
    AttentionVariant variant(params, /*batch_idx=*/request_idx, smem);
    // 下面是几个关键的中间变量
    // s_frag存储softmax(q*k) 的结果,o_frag 存储 s*v的结果,m与d见图,属于online-softmax 计算的中间结果
    DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; 
    alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8];
    DTypeQKAccum m[NUM_MMA_Q][2];
    float d[NUM_MMA_Q][2];
    float rope_freq[NUM_MMA_D_QK / 2][4];
    # 上述数组的前一维度比较好理解,后一维度与mma layout 相关(不过我还没理解完这部分意思)
    init_states<KTraits>(variant, o_frag, m, d);
    //加载query 数据从global mem to sharemem
    smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
    load_q_global_smem<KTraits>(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h,
                                group_size, &qo_smem);
    // load_q_global_smem 内load 操作也是调用cp 命令操作的,所以这里可以调用commit_group进行管理
    cp_async::commit_group();
    // 实例化k v sharemem
    smem_t<SWIZZLE_MODE_KV> k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem);
    // 加载kv 数据从global mem到sharemem
    page_produce_kv<false, KTraits>(k_smem, &k_smem_offset_w, paged_kv, 0, thr_local_kv_offset,
                                    chunk_size);
    cp_async::commit_group();
    page_produce_kv<true, KTraits>(v_smem, &v_smem_offset_w, paged_kv, 0, thr_local_kv_offset,
                                   chunk_size);
    cp_async::commit_group();
    // 进行计算
    #pragma unroll 1
    for (uint32_t iter = 0; iter < num_iterations; ++iter) {
      packed_page_iter_base += CTA_TILE_KV;
      // 进行计算前,确保qk数据到位,v 可以不加载完(qk 不需要v)
      // wait_group 1 指可以有一个挂起,最近挂起的v可以不完成
      cp_async::wait_group<1>();
      block.sync();

      // compute attention score
      compute_qk<KTraits>(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);

      // logits transform,比如先前的logits soft cap 限制,应用一下
      logits_transform<KTraits>(
          params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base,
          chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv<KTraits>()) * NUM_MMA_KV * 16,
          qo_len, kv_len, group_size, s_frag);

      // compute m,d states in online softmax
      update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);

      // 块内同步
      block.sync();
      // 这一批的k 用完了,提前加载下一批的k,注意要在qk 计算确认完成后执行
      page_produce_kv<false, KTraits>(k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV,
                                      thr_local_kv_offset, chunk_size);
      cp_async::commit_group();
      // 这里其实等的是v,因为wait group 1允许有一个挂起,刚刚挂起k,所以v 被等待完成
      cp_async::wait_group<1>();
      block.sync();

      // compute sfm*v,计算output
      compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);

      block.sync();
      // 块内同步完成,为下一个iterate 加载v 数据
      page_produce_kv<true, KTraits>(v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV,
                                     thr_local_kv_offset, chunk_size);
      cp_async::commit_group();
    }

    // 确保所有挂起逻辑完成数据同步
    cp_async::wait_group<0>();
    // 块内同步
    block.sync();

    finalize_m<KTraits>(variant, m);

    // threadblock synchronization
    threadblock_sync_mdo_states<KTraits>(o_frag, &smem_storage, m, d, warp_idx, lane_idx);

    // normalize d,这一步使用d 进行o的归一处理
    normalize_d<KTraits>(o_frag, m, d);

    const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size;

    // write_back,将计算结果写回global mem
    write_o_reg_gmem<KTraits>(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
                              /*o_stride_n=*/
                              partition_kv ? num_kv_chunks * o_stride_n : o_stride_n,
                              /*o_stride_h=*/o_stride_h, group_size);

上面这部分逻辑我们只抠了主体代码,其实省略了许多index 的计算细节,这些细节本身也是加速的关键,但奈何笔者理解能力有限,以及flashinfer的注释有些不足,尚不能完全理解,后面理解了再单独完整讲一篇kernel并行计算优化吧。

decode 的 forward 行为-与prefill 行为的辨析

先前我们已经完整过完了flashinfer Prefill 在fa2下的主体代码。 接下来我们比较一下decode 与prefill的差异。

decode 也有自己decodeWrapper,自然也有初始化,plan 和run的接口。

初始化逻辑基本是一致的,只不过少了q相关的字段,以及decode 会考虑不用tensor-core的可能性(should_use_tensor_core)。在sglang侧和flashinfer 的python层都没有特别值得注意的行为。但当我们去看cpp kernel 接口的实现,我们会看到一些不同。

nums_stage_smem

首先我们会经常看到一个变量名 nums_stage_smem。

work estimation

decode plan 的过程与其他几乎一样,也只是少了query 相关的信息。但多了一步work estimation,以评估即将运行的kernel 函数所占用的资源情况。

FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages,
                                            new_batch_size, gdy, batch_size, indptr_h, num_qo_heads,
                                            page_size, enable_cuda_graph, stream));

这里的评估的结果主要影响split_kv, max_grid_size, kv_chunk_size_in_pages, new_batch_size,对decode batch的编排影响也比较大。其实这里的部分逻辑类似于prefill的PrefillSplitQOKVIndptr,decode 虽然也有DecodeSplitQOKVIndptr ,但实现不同,寻找最佳kvchunk size和batch_size的动作放在了work_estimation_func里。

我们简单过一下代码,首先是选择合适的tiling 依据。prefill的tiling 过程中query len是一个重要考虑,但是decode 这边query len为1,此时由于tiling的维度就仅仅包括kv len,另外也混入了head dim层切割。

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t bdx = HEAD_DIM / vec_size;    // 从这里可以看到tile 的依据来自对head_dim的切割
static_assert(bdx <= 32);
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;

然后是计算可能需要的sharemem 数量

    gdy = num_kv_heads;
    const uint32_t smem_size =
        2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
        std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));

    auto kernel =
        BatchDecodeWithPagedKVCacheKernel<POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx,
                                          vec_size, bdx, bdy, bdz, AttentionVariant, Params>;
    int num_blocks_per_sm = 0;
    int num_sm = 0;
    int dev_id = 0;
    FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
    FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                       num_threads, smem_size));
    max_grid_size = num_blocks_per_sm * num_sm;

这里cudaOccupancyMaxActiveBlocksPerMultiprocessor这个接口可用于给定threads和sharemem以及kernel的情况下(也就是给定一个block需要的资源),当前在获得sm 内可以拿到几个这种block。

从flashinfer的同学那边得知, decode的模板后续可能也不再维护了,所以我们就浅尝辄止,接下来回过头来看看cudaGraph 下的工作方式。

cudaGraphRunner 下的行为

游凯超:一文读懂cudagraph(https://zhuanlan.zhihu.com/p/700224642),首先推荐一篇游凯超对于cudagraph的文章,对于cudagraph的原理说明是相当透彻的。这里我们详细介绍一下sglang的cudaGraphRunner 里的基本实现框架。这一部分其实和vllm 可能会比较类似。

初始化

其初始化结构很简单,这里最重要的就是graphs和outputbuffers,graph是cudaGraph 结构的字典(key 是batchsize),outputbuffers 是graph 字典一一对应的输出buffer 字典。

def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        # 这部分是distributed 相关的配置,等下一篇讲分布式再来看吧)
        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        
        # Batch sizes to capture,这里的bs 和 graph相关,基本是硬编码和配置
        # 不过这里有一个点,capture_bs 可以大于Max_running_request,为了尽量攒大一点graph的batchsize
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        # 这里有意思的是,sglang的cudaGraph 似乎只有decode 模式(投机推理情况下+verify)
        self.capture_forward_mode = ForwardMode.DECODE
        self.num_tokens_per_bs = 1
        
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        ## init_cuda_graph_state 其实没啥,单独为cuda graph 准备了一个kv indice的数组资源
        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
        ## 这个是padding的填充值,sglang 固定用0
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )
        # 为graph 输入开辟空间
        # Graph inputs
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

        # Capture
        try:
            with self.model_capture_mode():
                self.capture()

以上这段代码是cudaGraphRunner的初始化部分主体代码,不过我们稍微关注4个函数的实现,为了更好理解graphrunner的工作原理。分别如下

# 用于获取capture 的graph 总数,且每种graph的尺寸
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    capture_bs = server_args.cuda_graph_bs

    if capture_bs is None:
        if server_args.speculative_algorithm is None:
            if server_args.disable_cuda_graph_padding:
                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
            else:
                # 默认走这个branch,也就是capture 最多160的bs,[1,2,4,8,16,24,32,40,48,,,160]
                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
        else:
            capture_bs = list(range(1, 33))

    # 大部分情况下,max_running_queue 大于capture_bs,但是在一些显存紧张的场景下
    # (比如一些小gpu 或者加载权重完,没有什么显存可以留给推理kvcache的情况下)
    # max_running_queue较小,此时为了更好的capture,capture_bs 最大控制到max_running_queue
    if max(capture_bs) > model_runner.req_to_token_pool.size:
        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs = list(
            sorted(
                set(
                    capture_bs
                    + [model_runner.req_to_token_pool.size - 1] #增加max_running_queue同宽size
                    + [model_runner.req_to_token_pool.size]
                )
            )
        )
    
    capture_bs = [
        bs
        for bs in capture_bs
        if bs <= model_runner.req_to_token_pool.size # 削去max_running_queue 以上的bs
        and bs <= server_args.cuda_graph_max_bs      # 又一重限制,不过对于显存富余的场景(大于24GB),默认160
    ]

    # torch_compile_max_bs 默认32, 且仅在enable_torch_compile 开启时生效
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs

如上是一个大体的cudagraph runner的初始化。接下来我们看看capture的细节。

Capture() 过程

首先我们看看graph capture的context 构造过程。

# 对外capture_graph的调用函数,注意到这里context 加了tp和pp 两层通信上组上的限制
@contextmanager
def graph_capture():
    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(
        context
    ):
        yield context

@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream

# 从上述可见capture_graph 与通信组相关,我们来看看原因
class GroupCoordinator:
    ...
    @contextmanager
    def graph_capture(
        self, graph_capture_context: Optional[GraphCaptureContext] = None
    ):
        # 默认context 可传None,在上述调用中tp的context 就是创建的,pp的context 来自于tp 创建的
        if graph_capture_context is None:
            # 获得一个Stream
            stream = torch.cuda.Stream()
            # 基于Stream 构造一个GraphCaptureContext对象(如上,很简单)
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # 通信组的comm 对象
        ca_comm = self.ca_comm
        maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        # 等待当前默认stream的操作完成后,才能进入context capture的初始化
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        # 设置新stream 作为当前线程cuda的默认stream
        # 如果使用graph,那么只能支持pynccl(默认nccl实现的py 接口)或者sglang custom_allreduce
        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph  |
            # --------------------------------------------
            # custom allreduce       | enabled | enabled |
            # PyNccl                 | disabled| enabled |
            # torch.distributed      | enabled | disabled|
            #
            # Note that custom allreduce will have a runtime check, if the
            #  tensor size is too large, it will fallback to the next
            #  available option.
            # In summary: When using CUDA graph, we use
            #  either custom all-reduce kernel or pynccl. When not using
            #  CUDA graph, we use either custom all-reduce kernel or
            #  PyTorch NCCL. We always prioritize using custom all-reduce
            #  kernel but fall back to PyTorch or pynccl if it is
            #  disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream()
                )
            with maybe_pynccl_context:
                yield graph_capture_context

获得了graph capture的stream和context,我们可以看看capture的实现了。

   def capture(self):
        # 第一步获取graph的上下文
        with graph_capture() as graph_capture_context:
            # 获取stream和可用显存
            self.stream = graph_capture_context.stream
            avail_mem = get_available_gpu_memory(
                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
            )
            
            # 反向遍历bs,以希望获得更好的显存共享
            # Reverse the order to enable better memory sharing across cuda graphs.
            capture_range = (
                tqdm.tqdm(list(reversed(self.capture_bs)))
                if get_tensor_model_parallel_rank() == 0
                else reversed(self.capture_bs)
            )
            for bs in capture_range:
                # 只在tp 0 上打印
                if get_tensor_model_parallel_rank() == 0:
                    avail_mem = get_available_gpu_memory(
                        self.model_runner.device,
                        self.model_runner.gpu_id,
                        empty_cache=False,
                    )
                    capture_range.set_description(
                        f"Capturing batches ({avail_mem=:.2f} GB)"
                    )

                # 在默认情况下,patch_model 直接返回model.forward,
                # 在enable_torch_compile的情况下,返回一个torch 编译优化的版本,具体我们后续说
                with patch_model(
                    self.model_runner.model,
                    bs in self.compile_bs,
                    num_tokens=bs * self.num_tokens_per_bs,
                    tp_group=self.model_runner.tp_group,
                ) as forward:
                    # 基于bs 和 forward,进行capture
                    (
                        graph,
                        output_buffers,
                    ) = self.capture_one_batch_size(bs, forward)
                    self.graphs[bs] = graph
                    self.output_buffers[bs] = output_buffers

                # Save gemlite cache after each capture
                # 每次cuda graph capture 完,保存gemlite cache 配置,便于后续加速
                save_gemlite_cache()

可以看到,在初始化时,sglang 就调用过capture,所以实际上graph的构建绝大部分情况下,是初始化时期就完成了。但是在新版本的代码中,replay 过程会进行检查看是否需要重新capture,这一点我们后续看replay代码再说。继续看上述代码中的两块关键函数。

首先是patch_model,该函数返回一个model 层面的forward 接口,默认情况下返回Model 本身的forward 接口,但是sglang 支持torch compile的情况,所以也可以发挥torch 编译优化后的接口。如下是代码分析。

# 该函数是一个设置model 每一层forward的调用函数的功能,大家可以看到几个点:
# reverse True表明采用sglang保留的实现(custom kernel,forward_cuda看代码都是调用sglang自己的算子)
# 注意这里CustomOp 说明是一个sglang 内部自定义的layer,但是不一定会调用custom kernel,只是支持调用customKernel
# forward_native 就是使用torch 原生接口实现customOp, 
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):      # 以下操作仅仅对customOp执行
            if reverse:
                sub._forward_method = sub.forward_cuda
                setattr(sub, "is_torch_compile", False)   #表明编译未完成,即如果torch.compile,可以编译
            else:
                # NOTE: Temporarily workaround MoE
                if "FusedMoE" in sub.__class__.__name__:
                    if num_tokens == 1:
                        # The performance of torch.compile on this layer is not always good when bs > 1,
                        # so we decide to only use torch.compile when bs =1
                        sub._forward_method = fused_moe_forward_native
                else:
                    sub._forward_method = sub.forward_native
                setattr(sub, "is_torch_compile", True)   #表明编译完成,即便torch.compile,也不会执行编译操作
        if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens) # 递归嵌套处理子模块


# 默认下就是else 分支,原地返回model.forward, 我们主要介绍开启torch 编译优化下的执行逻辑
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with with torch.compile"""
    backup_ca_comm = None

    try:
        if enable_compile:
            # compile 前reverse=False,我理解是为了方便torch compile 优化
            _to_torch(model, reverse=False, num_tokens=num_tokens)
            backup_ca_comm = tp_group.ca_comm
            # Use custom-allreduce here.
            # We found the custom allreduce is much faster than the built-in allreduce in torch,
            # even with ENABLE_INTRA_NODE_COMM=1.
            # tp_group.ca_comm = None
            yield torch.compile(
                torch.no_grad()(model.forward),   # torch.no_grad 即编译优化过程中不考虑梯度更新,因为梯度更新对推理没有用
                mode="max-autotune-no-cudagraphs", #优化过程中最大程度使用自动调优,且不使用cudagraph(因为sglang 有自己的cudagraph 优化) 
                dynamic=False, #说明输入形状是静态的,torch 可以进行更深入的优化
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            # capture 完,model 重新恢复custom kernel 实现
            _to_torch(model, reverse=True, num_tokens=num_tokens)
            tp_group.ca_comm = backup_ca_comm

# 如下则是torch compile的设置
def set_torch_compile_config():
    import torch._dynamo.config
    import torch._inductor.config

    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

    # FIXME: tmp workaround
    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024

现在我们搞懂了capture 过程中调用的forward 接口的内容,接下来我们看看capture 过程中forwardbatch的内容组成,这一部分主要在capture_one_batch_size里。为了可读性,我们简化一下代码。

def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()            #构建cudagraph 实例, 获取stream
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs  #预备所需要的tokens

        # Graph inputs
        # 注意cudagraph 中也是类似的forwardbatch和数组,只不过都是静态的batch,无需动态生成
        # 为了突出重点,我扔掉了投机推理、dp 相关的逻辑,仅仅讨论最简单情况下的流程
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        encoder_lens = None
        mrope_positions = self.mrope_positions[:, :bs]

        # forward batch 传参,和disable-cuda-graph 的区别是cuda graph下这里数组都是固定的
        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=global_num_tokens,
            gathered_buffer=gathered_buffer,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
        )

        # Attention backend,在attention backend 中初始化cuda graph 信息
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
        )

        # Run and capture, forward 参考上述patch_model的返回
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None

            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
            return logits_output.next_token_logits, logits_output.hidden_states

        # 这两次run_once我没有完全理解,第一次run_once 应该是提前获取相关数据(也可以理解为warmup的一部分)
        # 有些动作,比如cudaMalloc 分配显存是不允许的,而run_once 过程中会分配一些中间buffer
        # 如果buffer 已经分配过了则不会再次分配
        # 但第二次,我没看懂,可能是为了确保没有问题?
        for _ in range(2):
            torch.cuda.synchronize()
            self.model_runner.tp_group.barrier()

            run_once()

        # 构建静态图
        global global_graph_memory_pool
        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
            out = run_once()

        global_graph_memory_pool = graph.pool()
        return graph, out

这里的init_forward_metadata_capture_cuda_graph 逻辑上比较简单,也是赋值metadata而已。

在最新的代码中,在replay 过程中也按需capture,不过这和投机推理有关,让我先忽略一下。

以上我们就整理完了cudaGraph runner 初始化的代码。

Replay() 过程

cudaGraph 的replay 其实就是正式forward的过程。我们观察到cudaGraphRunner 的forward 过程如下,如果使能cudaGraph,则通过cudaGraphRunner的replay 执行。

 def forward(
        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
    ) -> LogitsProcessorOutput:
        if (
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        ):
            return self.cuda_graph_runner.replay(
                forward_batch, skip_attn_backend_init=skip_attn_backend_init
            )
        ........

现在我们看看replay 接口的实现,从上述接口中传递了forward_batch。

   def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
        self.recapture_if_needed(forward_batch) # 非投机推理,基本忽略

        raw_bs = forward_batch.batch_size       
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad,获取raw_bs 的index
        index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        # 如果是之前不支持的bs 类型,对系统值进行初始化,seq_len 默认1,out_cache_loc 默认0
        if bs != raw_bs:
            self.seq_lens.fill_(1)
            self.out_cache_loc.zero_()

        # Common inputs,赋值入flashinfer的数组
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)
        if forward_batch.decode_seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(1)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)

        if hasattr(forward_batch.spec_info, "hidden_states"):
            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

        # Attention backend,这里也仅仅是再次设置metadata,主要是设置pagetable
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            self.req_pool_indices,
            self.seq_lens,
            forward_batch.seq_lens_sum + (bs - raw_bs),
            self.encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
            seq_lens_cpu=self.seq_lens_cpu,
        )

        # Replay,replay 时真正launch kernel,返回output
        self.graphs[bs].replay()
        next_token_logits, hidden_states = self.output_buffers[bs]

        # 注意这里主要做split,outputbuffer 里有完整一个bs的logits和hidden_stats
        # 但是capture的bs和当前forward bs 可能不一样,所以要过滤一下
        logits_output = LogitsProcessorOutput(
            next_token_logits=next_token_logits[:raw_num_token],
            hidden_states=(
                hidden_states[:raw_num_token] if hidden_states is not None else None
            ),
        )
        return logits_output

好,这里为止,我们基本介绍完了forward和backend 流程里的主要环节,希望可以帮助大家更好理解代码


– The End –


GiantPandaLLM

长按二维码关注我们

本公众号专注:

1. 技术分享;

2. 学术交流

3. 资料共享

欢迎关注我们,一起成长!



(文:GiantPandaCV)

发表评论

×

下载每时AI手机APP

 

和大家一起交流AI最新资讯!

立即前往