福生无量摸鱼天尊

triton is all you need 之 GEMM

2025/09/07
16
0

代码参考了傅哥,请b站关注我是傅傅猪喵,谢谢喵!

Triton DSL是以BLOCK tile为中心的Python DSL。与CUDA相比,Triton的使用者无法控制所有细节,因为某些优化是自动完成的,但是在Triton编译器的逐层编译优化之下也可以获得与Cuda相近甚至超过的性能。另外,Triton的编写和调试更加简单,而且学习成本更低。

通过Triton配套的编译器,Triton能够将这些高级抽象代码自动转换为高度优化的PTX指令。

CUDA中,我们可以自定义每个block里面thread的Dim,Triton不一样,Triton 不会将块进一步分解为线程(threads),而是以块为基本单位进行向量化操作。如下面Triton版的add:

@triton.jit
def vector_add_kernel(
        x_ptr,  # 第一个向量的指针
        y_ptr,  # 第二个向量的指针
        output_ptr,  # 输出向量的指针
        n_elements,  # 向量的大小
        BLOCK_SIZE: tl.constexpr,  # 每个 block 的大小
):
    # 获取当前程序的全局索引
    pid = tl.program_id(axis=0)

    # 计算当前 block 的起始和结束索引
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 创建一个掩码以防止越界访问
    mask = offsets < n_elements

    # 从全局内存加载数据
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # 执行向量加法
    output = x + y
    # 将结果存储回全局内存
    tl.store(output_ptr + offsets, output, mask=mask)

triton算子大致过程如下:

  • 每个块的唯一标识通过 pid = tl.program_id(axis=0) 来获取,其中 pid 表示当前块在全局范围内的索引。

  • 通过pidBLOCK_SIZE就能确定处理数据的起始位置,然后通过加上一个便宜量offsets就可以开始处理数据了。

  • 后续其实就是pytorch版的cuda程序,但是需要注意的是,这里并不用使用threadIdxdataIdx进行对齐计算,直接调包就行,triton会帮我们编译,这里有很多相似之处,请继续往下看。

矩阵乘法 matmul

请记住,每一个算子都需要进行优化的原因是不同的硬件对这边不同的数据的时候,都会有不同的性能表现,所以请学习通用的算子的优化思路,摒弃固定的思路。

  • 由于在矩阵乘法中,矩阵乘法通常跟输入token长度有关,现在GPT-5 支持高达 400K tokens,所以都要

triton的矩阵乘法比cuda要好写的多,因为不需要对齐,但是还是有多种优化,这里由于有cuda的基础,我们直接看kernel:


@triton.jit
def matmul_kernel(
    # 矩阵指针
    a_ptr, b_ptr, c_ptr,
    # 矩阵维度
    M, N, K,
    # 步长参数
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # 块大小
    BLOCK_SIZE_M: tl.constexpr, 
    BLOCK_SIZE_N: tl.constexpr, 
    BLOCK_SIZE_K: tl.constexpr,
):
    """Triton矩阵乘法内核 C = A @ B"""

    """============================================== PID ===================================================="""
    # 获取当前程序ID
    pid = tl.program_id(axis=0)
    
    # 计算块索引
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)  # M方向的块数量
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)  # N方向的块数量
    
    # 根据program ID计算当前块在M和N方向的索引
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    
    # 创建偏移量范围
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M  # A矩阵行偏移
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N  # B矩阵列偏移
    offs_k = tl.arange(0, BLOCK_SIZE_K)  # K维度偏移

    """============================================== Init ===================================================="""
    # 初始化指针
    # A矩阵指针:行偏移 * 行步长 + 列偏移 * 列步长
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # B矩阵指针:行偏移 * 行步长 + 列偏移 * 列步长
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    
    # 初始化累加器,用于存储部分乘积结果
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    
    """============================================== 计算 ===================================================="""
    # 主计算循环:沿K维度分块处理
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # 加载数据块,使用mask防止越界访问
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        
        # 计算部分结果并累加到accumulator中
        accumulator += tl.dot(a, b)
        
        # 推进指针到下一个K维度的块
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    
    # 将累加器转换为float32格式
    c = accumulator.to(tl.float32)
    
    # 存储结果到输出矩阵C
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)  # C矩阵行偏移
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)  # C矩阵列偏移
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]  # C矩阵指针
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)  # 防止越界的掩码
    tl.store(c_ptrs, c, mask=c_mask)  # 存储结果