福生无量摸鱼天尊

算子进阶 —— 通算融合

2025/09/19
17
0

(施工ing)

概述

我们知道,算子的作用是计算,那在整个体系中,我们的核心目标是拉满GPU的利用率。

在现代分布式体系中,多GPU之间同时存在着计算、内存访问和通信这三种基本活动,为了服务于我们的核心目标,我们需要尽可能的将通信时间和访存时间放在计算时间内,使得GPU不存在运算时间的泡泡。

大模型分布式系统执行的核心组件是并行的GPU。提升GPU运算的最佳方法,是通过计算与通信的重叠来实现。这种重叠可以通过两种方式达成:算子分解或通信内核融合。虽然算子分解实现起来简单,但往往导致性能欠佳。而将通信内核与计算内核融合,则需要更多的思考和计算。

故此最原始的想法就是用计算时间掩盖通信时间,通过大批量的数据直接发送到GPU,减少kernel开启和关闭、通信的开销。

Triton-distributed

Triton-distributed是字节seed团队开发的Triton 编译器的扩展。对于 LLM 来说,分布式优化的关键要求是计算通信重叠。以前,在小规模分布式训练/推理中,通信开销并不是一个关键的成本问题。然而,随着集群数量呈指数级增长,计算与通信重叠变得至关重要。

Triton-Distributed编译器技术栈

Triton-Distributed 编程模型遵循 MPMD(multiple programs multiple data)模型。其模型核心就是异步的处理信号和数据,关键技术有三个:symmetric memory 对称内存,signal exchange 信号交换,和async-task 异步任务。

  • symmetric memory 对称内存:每个 rank 在全局范围内分配一个相同大小的内存缓冲区。每个内存缓冲区具有独立的地址空间,没有通一的虚拟地址空间,访问只靠通信原语。

  • signal exchange 信号交换:每个 rank 上的操作以一致的方式通过信号相互通信。有一组固定的信号操作,包括设置信号值、增加信号值、检查信号值以及对给定信号进行自旋锁。

  • async-task 异步任务:所有操作都是异步的,即使在同一 rank 上,操作也是异步的。

Triton-distributed 提供了一套易于使用的原语,用于支持开发计算-通信融合的分布式kernel。这些原语分为低层次原语和高层次原语。目前,seed已经发布了低层次原语,并计划在未来发布高层次原语。代码位于Triton-distributed/python/triton_dist/language/

用于通信的低级原语有:

1、Context Querying Primitives 上下文查询原语

rank(axis=-1, _builder=None)
num_ranks(axis=-1, _builder=None)
symm_at(ptr, rank, _builder=None)

2、Signal Control Primitives 信号控制原语

wait(barrierPtrs, numBarriers, scope: str, semantic: str, _builder=None)
consume_token(value, token, _builder=None)
notify(ptr, rank, signal=1, sig_op="set", comm_scope="inter_node", _builder=None)

3、NVSHMEM-related Primitives NVSHMEM 相关原语

除了原语之外,Triton-distributed 还将所有 nvidia shared memory 通信原语公开给 Python,允许用户纯粹用 Python 编写通信内核。位于triton.language.extra.libshmem_device

my_pe()
n_pes()
int_p(dest, value, pe)
remote_ptr(local_ptr, pe)
barrier_all()
barrier_all_block()
barrier_all_warp()
sync_all()
sync_all_block()
sync_all_warp()
quiet()
fence()
getmem_nbi_block(dest, source, bytes, pe)
getmem_block(dest, source, bytes, pe)
getmem_nbi_warp(dest, source, bytes, pe)
getmem_warp(dest, source, bytes, pe)
getmem_nbi(dest, source, bytes, pe)
getmem(dest, source, bytes, pe)
putmem_block(dest, source, bytes, pe)
putmem_nbi_block(dest, source, bytes, pe)
putmem_warp(dest, source, bytes, pe)
putmem_nbi_warp(dest, source, bytes, pe)
putmem(dest, source, bytes, pe)
putmem_nbi(dest, source, bytes, pe)
putmem_signal_nbi(dest, source, bytes, sig_addr, signal, sig_op, pe)
putmem_signal(dest, source, bytes, sig_addr, signal, sig_op, pe)
putmem_signal_nbi_block(dest, source, bytes, sig_addr, signal, sig_op, pe)
putmem_signal_block(dest, source, bytes, sig_addr, signal, sig_op, pe)
putmem_signal_nbi_warp(dest, source, bytes, sig_addr, signal, sig_op, pe)
putmem_signal_warp(dest, source, bytes, sig_addr, signal, sig_op, pe)
signal_op(sig_addr, signal, sig_op, pe)
signal_wait_until(sig_addr, cmp_, cmp_val)

参考文献:

想飞的石头​——Triton-distributed学习笔记

Triton-distributed github link