Triton 是一个SPMD(单程序多数据)编程模型开发语言+编译基础设施,目前以GPU Kernel开发为主
它本质上有三层:
① Triton Language (Python DSL )
Triton Language 代码位于 python/triton/language 中,可以分为两部分:
1,Triton Language Operation, 可以在kernel内部当作函数来调用,会通过visit_Call 在PythonAST中进行处理。
2,在kernel内部的各种运算符,比如+,-,*,/等,一般通过visit_BinOP在PythonAST中处理。
Triton向量加法内核如下:
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
@triton.jit
def add_kernel(x_ptr, y_ptr,output_ptr,n_elements,BLOCK_SIZE:tl.constexpr):
pid=tl.program_id(axis=0)
block_start=pid*BLOCK_SIZE
offsets=block_start+tl.arange(0,BLOCK_SIZE)
mask=offsets<n_elements
x=tl.load(x_ptr+offsets,mask=mask)
y=tl.load(y_ptr+offsets,mask=mask)
output=x+y
tl.store(output_ptr+offsets,output,mask=mask)
这段代码中:
-
x_ptr, y_ptr, output_ptr这些指针指向输入/输出 GPU 数组的起始位置。调用 Triton 内核时,传递的任何 PyTorch(或 NumPy CuPy)张量都会转换为指向其数据的指针。 -
n_elements是向量的总长度。我们传递这个参数是为了让内核知道边界。 -
BLOCK_SIZE: tl.constexprBLOCK_SIZE 是一个编译时常量,它定义了每个程序实例(代码块)处理的元素数量。我们将选择一个 BLOCK_SIZE 值(例如 1024),使得代码块中的线程能够以向量化的方式一次处理 1024 个元素。 -
在内核内部,
tl.program_id(axis=0)它给出了当前程序实例在网格第 0 维上的唯一索引。我们启动一个一维程序网格来进行向量加法运算。 -
我们计算的
offsets范围是从块的起始索引到block_start + BLOCK_SIZE - 1。tl.arange(0, BLOCK_SIZE)创建一个块局部索引向量,索引范围为 0,1,…,BLOCK_SIZE-1。通过添加block_start,我们得到此内核实例将处理的数组中的绝对索引。 -
我们创建一个
mask布尔向量,用于指示哪些偏移量在边界内offset < n_elements。对于超出数组长度的任何索引,此掩码都将为 false(例如,如果 N 不是 BLOCK_SIZE 的倍数,则最后一个块将有一些超出范围的偏移量)。Triton 使用掩码来安全地处理内存访问,而无需显式分支。 -
tl.load从给定地址(指针)的内存中读取数据。我们执行此操作tl.load(x_ptr + offsets, mask=mask),其底层会针对这些位置发出向量化加载指令,并且对于任何值为maskfalse 的位置,它实际上不会加载数据(或者会替换为一个虚拟值,以避免非法内存访问)。类似地,对于y。 -
然后我们执行逐元素加法
result = x + y。由于 Triton 的向量化运算,此加法操作一次性作用于整个元素块。其概念类似于对数据切片执行 NumPy 数组加法,但这里是在 GPU 块中并行执行的。 -
最后,
tl.store(output_ptr + offsets, result, mask=mask)将所有有效索引对应的结果写回全局内存中的输出数组。掩码确保我们只写入边界内的数据。 -
因为每个程序实例处理 BLOCK_SIZE 个元素,并且它们都是并行运行的,所以整个向量会在一次内核启动中被添加进去。
内核装饰器: 装饰器@triton.jit用于定义 Triton 内核。@triton.jit 装饰器的工作原理是遍历所提供的 Python 函数的抽象语法树 (AST),从而使用常见的 SSA 构造算法动态生成 Triton-IR。
② Triton Intermediate Representation(TTIR)
内核编译首先遍历被装饰的 Python 函数的抽象语法树 (AST),生成 Triton 中间表示 (Triton-IR)。Triton-IR 是一种未经优化的、与机器无关的中间表示。
Triton language到ttir的流程:
1. JIT入口:调用编译器
当用户写的 Triton kernel 第一次被调用时,@triton.jit 装饰器会触发 JIT:
kernel = self.compile(...)
这里的 self.compile 会启动整个编译流程,其关键步骤是调用:
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
ast_to_ttir 将 Python AST 转换为 Triton 的 TTIR(Triton Intermediate Representation)。
2. AST → TTIR:
ast_to_ttir 内部使用 visitor 模式 遍历 Python AST:
ret = super().visit(node)
这是关键入口点,编译器会跳入 Python AST 节点,再返回 Triton 自定义的 IR 节点。
3. 处理二元运算
对于向量加法:
output = x + y
Python AST 会产生一个 BinOp 节点,于是 visitor 会进入:
visit_BinOp(self, node)
此函数在:
python/triton/language/core.py
接着,
+ 会被wrapper包装器映射到 Triton 的 builtin 运算 add。
然后调用 semantic 层做语义分析
核心逻辑在:
python/triton/language/semantic.py
例如 float + float 会进入:
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
-
Triton tensor = builder 创建的 fadd 节点
-
fadd是浮点加法运算
最后,builder.create_fadd → C++ IR 构建
create_fadd 对应到 C++ 侧:
return self.create<arith::AddFOp>(lhs, rhs);
也就是 MLIR 中的 arith.addf 操作。
到这里,Python 代码 x + y 已经变成了 TTIR 的 fadd 指令。
对于 Triton 向量加法内核,示例 Triton IR 代码片段如下:
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
%0 = tt.get_program_id x : i32 loc(#loc1)
%c1024_i32 = arith.constant 1024 : i32 loc(#loc2)
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc2)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc3)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc4)
%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc4)
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc5)
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc5)
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc7)
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc8)
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc8)
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc9)
%13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc10)
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc11)
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc11)
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc12)
tt.return loc(#loc13)
} loc(#loc)
} loc(#loc)
③ 后端(以运行在Nvidia gpu为例)
Triton 通过 MLIR 构建的 Pass pipeline,把 Python Kernel 自动优化并映射为 GPU 专用的 TTGIR,再降级到 LLVM IR,生成 PTX,最后由 ptxas 编译为可运行的 cubin,并通过 CUDA runtime 执行。架构(Turing/Ampere/Hopper)会影响 Pass 的选择,使得 Triton 能针对不同的 NVIDIA GPU 自动生成最佳代码。
SpacemiT Triton:进迭时空为什么要做 Triton?
虽然 Triton 有众多优势,但是:
-
Triton 大多数优化与语义是 GPU-oriented
-
基于x86的TritonCPU项目性能表现一般
-
计算分块、访存层级、访存逻辑均以GPU风格体现,线程优化模型并不适合CPU
-
缺乏真正适配CPU及统一内存的Triton Kernel调度实践
因此,几乎所有 Triton 生态项目最终都只能跑 GPU。
所以进迭要实现一件更大的事情:
让 Triton 写的算子能在 RISC-V AI CPU 上高效跑起来。
致力于推动 RISC-V + AI + Trtion 计算生态的优化与普及。
下一篇预告
第二篇:RISC-V AI CPU Triton软件栈