GEMM Performance Shootout: An smt-gemm Triton Kernel Written with Spine Triton, Benchmarked Against the Official smt-mm Triton Kernel

GEMM is the cornerstone of modern computing, so this time I set my sights on it. While studying the official Spine Triton example [spine-triton/python/examples/test_smt_mm.py] and drawing on my earlier machine-learning experience, some questions came to mind:

In the official example, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, SUB_BLK_M, MICRO_M, MICRO_N, MICRO_K, and EVEN_K are all constants, and only a single GEMM shape is tested: 256×512 @ 512×256.

Are these constants really optimal for this GEMM?
Are they also optimal for other GEMM shapes?

With these questions in mind, I set out to write my own Triton kernel to find the answers. My idea was this:

Use the new `smt.*` primitives to bring a GEMM that was originally GPU-only onto the CPU. The core idea in one sentence:

Prefetch the entire K dimension into L1/LLC at once (via `smt.descriptor_load`), then apply two-level parallelism along the M dimension:

– the outer level uses `tl.program_id(0)` to assign BLOCK_SIZE_M rows to each CTA;

– the inner level uses `smt.parallel` to split those BLOCK_SIZE_M rows into sub-blocks of SUB_BLK_M rows, computed in parallel by CPU threads;

– inside each thread, an 8×8×16 micro-tile accumulates in registers.
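The scheme above can be sketched in plain Python. This is a toy model, not Spine Triton code: ordinary loops stand in for `tl.program_id(0)` and `smt.parallel`, and the micro-tile accumulation is collapsed into a scalar inner loop.

```python
# Toy model of the two-level M-dimension split: the outer loop plays the
# role of tl.program_id(0) (one iteration per CTA), the inner loop plays
# smt.parallel (one iteration per SUB_BLK_M-row sub-block / CPU thread).
def blocked_matmul(A, B, M, N, K, BLOCK_M=4, SUB_BLK_M=2):
    C = [[0.0] * N for _ in range(M)]
    for pid_m in range((M + BLOCK_M - 1) // BLOCK_M):         # outer: per-CTA block
        m0 = pid_m * BLOCK_M
        rows = min(BLOCK_M, M - m0)
        for s in range((rows + SUB_BLK_M - 1) // SUB_BLK_M):  # inner: per-thread sub-block
            for m in range(m0 + s * SUB_BLK_M,
                           min(m0 + (s + 1) * SUB_BLK_M, M)):
                for k in range(K):                            # K swept in one pass (EVEN_K case)
                    a = A[m][k]
                    for n in range(N):
                        C[m][n] += a * B[k][n]                # register-resident accumulation
    return C

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(blocked_matmul(A, B, 2, 2, 2))  # [[19.0, 22.0], [43.0, 50.0]]
```

Every row is visited exactly once regardless of how M divides into blocks, which is why the real kernel can clamp `sub_num` with `min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m)`.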

I assumed the L1/L2/LLC capacities of the SpacemiT X60 and designed a cache-aware blocking module around a greedy algorithm, so the block sizes adapt automatically to each GEMM shape:

While prefetching the full K dimension into L1/LLC, also prefetch as much of the N dimension as possible, cutting down the number of trips to main memory. Concretely:

First, a BLOCK_SIZE_K × BLOCK_SIZE_N fit against the L3 (LLC) capacity determines the optimal BLOCK_SIZE_N.

Next, a BLOCK_SIZE_M × BLOCK_SIZE_K fit against the L2 capacity determines the optimal BLOCK_SIZE_M, making the most of L2.

Finally, a SUB_BLK_M × BLOCK_SIZE_K fit against the L1 capacity determines the optimal SUB_BLK_M, making the most of L1.

Benchmarking the new algorithm against the official example shows that it shortens the compute time and scales better across shapes. Results below:
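The three greedy fits above can be sketched as a standalone function. This is a minimal sketch mirroring the derivation in the attached compare_smt_mm.py, assuming float32 elements and my assumed X60 cache sizes; `next_pow2` stands in for `triton.next_power_of_2`, and the floors of 64 and 8 match the script's minimum block and micro-tile sizes.

```python
def next_pow2(n: int) -> int:
    # Smallest power of two >= n.
    return 1 << (n - 1).bit_length()

def derive_blocks(M, N, K, elem=4, L1=32_768, L2=262_144, L3=2_097_152):
    """Greedy cache-aware blocking: halve a dimension until the working
    set of the corresponding tile fits the target cache level."""
    BLOCK_K = next_pow2(K)

    tmp = N                                  # B tile (BLOCK_K x N) must fit L3
    while tmp * BLOCK_K * elem > L3:
        tmp //= 2
    BLOCK_N = max(next_pow2(tmp) if tmp > 0 else 1, 64)

    tmp = M                                  # A tile (M x BLOCK_K) must fit L2
    while tmp * BLOCK_K * elem > L2:
        tmp //= 2
    BLOCK_M = max(next_pow2(tmp) if tmp > 0 else 1, 64)

    tmp = BLOCK_M                            # sub-block (SUB_BLK_M x BLOCK_K) must fit L1
    while tmp * BLOCK_K * elem > L1:
        tmp //= 2
    SUB_BLK_M = max(next_pow2(tmp) if tmp > 0 else 1, 8)

    return BLOCK_M, BLOCK_N, BLOCK_K, SUB_BLK_M

# For 256x512 @ 512x256 this reproduces the first row of Table 1:
print(derive_blocks(256, 256, 512))  # (128, 256, 512, 16)
```

Note the greedy halving only ever shrinks a dimension, so a shape that already fits its cache level is kept whole; that is why BLOCK_SIZE_N grows with N in Table 1 while BLOCK_SIZE_M stays pinned at 128 by the L2 constraint.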

[Table 1] Parameters auto-derived by the new algorithm

| M×K@K×N | BLOCK_SIZE_M | BLOCK_SIZE_N | BLOCK_SIZE_K | SUB_BLK_M | MICRO_M | MICRO_N | MICRO_K | EVEN_K |
|---------|--------------|--------------|--------------|-----------|---------|---------|---------|--------|
| 256×512 @ 512×256 | 128 | 256 | 512 | 16 | 8 | 16 | 8 | True |
| 512×512 @ 512×512 | 128 | 512 | 512 | 16 | 8 | 16 | 8 | True |
| 1024×512 @ 512×1024 | 128 | 1024 | 512 | 16 | 8 | 16 | 8 | True |

[Table 2] Old vs. new runtime (ms)

| M×K@K×N | old / ms | new / ms | speedup |
|---------|----------|----------|---------|
| 256×512 @ 512×256 | 15.280 | 9.659 | 1.58 |
| 512×512 @ 512×512 | 57.426 | 30.855 | 1.86 |
| 1024×512 @ 512×1024 | 241.026 | 156.532 | 1.54 |
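The speedup column is simply the old runtime divided by the new one; as a quick check:

```python
# Speedup = old runtime / new runtime for each row of Table 2.
rows = [(15.280, 9.659), (57.426, 30.855), (241.026, 156.532)]
speedups = [round(old / new, 2) for old, new in rows]
print(speedups)  # [1.58, 1.86, 1.54]
```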

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
compare_smt_mm.py
Runtime comparison of the new vs. old triton-smt matrix multiply.
Prints two tables:

  1. the parameter set auto-derived by the new algorithm for each shape
  2. old vs. new runtime

conda install pytorch triton -c pytorch -c nvidia  # install dependencies
python compare_smt_mm.py
"""
import time
import torch
import triton
import triton.language as tl
import triton.language.extra.smt as smt
from triton.backends.spine_triton.driver import CPUDriver

# ---------- 0. Common setup ----------

driver = CPUDriver()
driver.set_current_arch_id("0xA03C")
triton.runtime.driver.set_active(driver)

# ===================================================================
# 1. Old algorithm (test_smt_mm.py, fixed 128×128×K)
# ===================================================================

@triton.jit
def mm_kernel_old(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr, SUB_BLK_M: tl.constexpr,
    MICRO_M: tl.constexpr, MICRO_K: tl.constexpr, MICRO_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=[M, K], strides=[stride_am, stride_ak],
        offsets=[pid_m * BLOCK_SIZE_M, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], order=[1, 0])
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=[K, N], strides=[stride_bk, stride_bn],
        offsets=[0, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], order=[1, 0])

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty)

    if EVEN_K:
        b = smt.descriptor_load(b_block_ptr, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
        sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
        for s in smt.parallel(0, sub_num):
            a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
            acc_view = smt.view(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
            acc_view = smt.dot(a, b, acc_view)
    else:
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            b = smt.descriptor_load(b_block_ptr, (k, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
            sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
            for s in smt.parallel(0, sub_num):
                a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
                acc_view = smt.view(accumulator, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
                acc_view = smt.dot(a, b, acc_view)

    c_block_ptr = tl.make_block_ptr(
        base=c_ptr, shape=[M, N], strides=[stride_cm, stride_cn],
        offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], order=[1, 0])
    tl.store(c_block_ptr, accumulator.to(c_ptr.dtype.element_ty), boundary_check=(0, 1))


def triton_mm_old(a: torch.Tensor, b: torch.Tensor):
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
                         triton.cdiv(N, META["BLOCK_SIZE_N"]))
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    mm_kernel_old[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_M=32,
        MICRO_M=8, MICRO_N=16, MICRO_K=8, EVEN_K=True
    )
    return c

# ===================================================================
# 2. New algorithm (gemm_kernelcopy.py, automatic blocking)
# ===================================================================

@triton.jit
def mm_kernel_new(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr, SUB_BLK_M: tl.constexpr,
    MICRO_M: tl.constexpr, MICRO_K: tl.constexpr, MICRO_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=[M, K], strides=[stride_am, stride_ak],
        offsets=[pid_m * BLOCK_SIZE_M, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], order=[1, 0])
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=[K, N], strides=[stride_bk, stride_bn],
        offsets=[0, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], order=[1, 0])

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty)

    if EVEN_K:
        b = smt.descriptor_load(b_block_ptr, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
        sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
        for s in smt.parallel(0, sub_num):
            a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
            acc_view = smt.view(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
            acc_view = smt.dot(a, b, acc_view)
    else:
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            b = smt.descriptor_load(b_block_ptr, (k, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
            sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
            for s in smt.parallel(0, sub_num):
                a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
                acc_view = smt.view(accumulator, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
                acc_view = smt.dot(a, b, acc_view)

    c_block_ptr = tl.make_block_ptr(
        base=c_ptr, shape=[M, N], strides=[stride_cm, stride_cn],
        offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], order=[1, 0])
    tl.store(c_block_ptr, accumulator.to(c_ptr.dtype.element_ty), boundary_check=(0, 1))


def triton_mm_new(a: torch.Tensor, b: torch.Tensor):
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # ---------- Cache-hierarchy constants ----------
    L1 = 32_768
    L2 = 262_144
    L3 = 2_097_152

    SCALE = 1    # user-tunable; 1.0 keeps the original algorithm
    L1 = int(L1 * SCALE)
    L2 = int(L2 * SCALE)
    L3 = int(L3 * SCALE)
    L1 = min(L1, 64_000)    # hard cap: at most 2× the physical L1
    L2 = min(L2, 512_000)   # hard cap: at most 2× the physical L2
    L3 = min(L3, 4_000_000) # hard cap: at most 2× the physical L3

    BLOCK_SIZE_K = triton.next_power_of_2(K)

    # ---- BLOCK_SIZE_N ----
    tmp = N
    while tmp * BLOCK_SIZE_K * 4 > L3:
        tmp //= 2
    BLOCK_SIZE_N = triton.next_power_of_2(tmp) if tmp > 0 else 1
    BLOCK_SIZE_N = max(BLOCK_SIZE_N, 64)

    # ---- BLOCK_SIZE_M ----
    tmp = M
    while tmp * BLOCK_SIZE_K * 4 > L2:
        tmp //= 2
    BLOCK_SIZE_M = triton.next_power_of_2(tmp) if tmp > 0 else 1
    BLOCK_SIZE_M = max(BLOCK_SIZE_M, 64)

    # ---- SUB_BLK_M ----
    MICRO_M = 8
    tmp = BLOCK_SIZE_M
    while tmp * BLOCK_SIZE_K * 4 > L1:
        tmp //= 2
    SUB_BLK_M = triton.next_power_of_2(tmp) if tmp > 0 else 1
    SUB_BLK_M = max(SUB_BLK_M, MICRO_M)

    MICRO_N = 16
    MICRO_K = 8
    EVEN_K = True

    # ---------- Package the debug info ----------
    info = {
        'BLOCK_SIZE_M': BLOCK_SIZE_M,
        'BLOCK_SIZE_N': BLOCK_SIZE_N,
        'BLOCK_SIZE_K': BLOCK_SIZE_K,
        'SUB_BLK_M'   : SUB_BLK_M,
        'MICRO_M'     : MICRO_M,
        'MICRO_N'     : MICRO_N,
        'MICRO_K'     : MICRO_K,
        'EVEN_K'      : EVEN_K
    }

    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]),
                         triton.cdiv(N, META["BLOCK_SIZE_N"]))
    mm_kernel_new[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_M=SUB_BLK_M,
        MICRO_M=MICRO_M, MICRO_N=MICRO_N, MICRO_K=MICRO_K, EVEN_K=EVEN_K
    )
    return c, info

# ===================================================================
# 3. Shared utilities: timing + correctness
# ===================================================================

def benchmark(fn, a, b, ref, n_warm=3, n_iter=10):
    for _ in range(n_warm):
        fn(a, b)
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    t0 = time.perf_counter()
    for _ in range(n_iter):
        c = fn(a, b)
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    t1 = time.perf_counter()
    torch.testing.assert_close(c, ref, atol=1e-2, rtol=0)
    return (t1 - t0) / n_iter * 1e3  # ms

# ===================================================================
# 4. Main flow: compare across several shapes
# ===================================================================

def main():
    shapes = [
        (256, 256, 512),
        (512, 512, 512),
        (1024, 1024, 512),
    ]

    # ---------- Table 1: new-algorithm parameter sets ----------
    print("\n[Table 1] Parameters auto-derived by the new algorithm")
    print("|    M×K@K×N    | BLOCK_SIZE_M | BLOCK_SIZE_N | BLOCK_SIZE_K | SUB_BLK_M | MICRO_M | MICRO_N | MICRO_K | EVEN_K |")
    print("|---------------|--------------|--------------|--------------|-----------|---------|---------|---------|--------|")
    param_records = []
    for M, N, K in shapes:
        A = torch.randn(M, K, dtype=torch.float32, device="cpu")
        B = torch.randn(K, N, dtype=torch.float32, device="cpu")
        _, info = triton_mm_new(A, B)          # parameters only
        print(f"| {M:4}×{K:4}@{K:4}×{N:4} "
              f"| {info['BLOCK_SIZE_M']:12} "
              f"| {info['BLOCK_SIZE_N']:12} "
              f"| {info['BLOCK_SIZE_K']:12} "
              f"| {info['SUB_BLK_M']:9} "
              f"| {info['MICRO_M']:7} "
              f"| {info['MICRO_N']:7} "
              f"| {info['MICRO_K']:7} "
              f"| {str(info['EVEN_K']):6} |")
        param_records.append(info)

    # ---------- Table 2: runtime comparison ----------
    print("\n[Table 2] Old vs. new runtime (ms)")
    print("|    M×K@K×N    |  old / ms  |  new / ms  | speedup |")
    print("|---------------|------------|------------|---------|")
    for idx, (M, N, K) in enumerate(shapes):
        A = torch.randn(M, K, dtype=torch.float32, device="cpu")
        B = torch.randn(K, N, dtype=torch.float32, device="cpu")
        ref = torch.mm(A, B)

        t_old = benchmark(triton_mm_old, A, B, ref)
        t_new = benchmark(lambda a, b: triton_mm_new(a, b)[0], A, B, ref)
        print(f"| {M:4}×{K:4}@{K:4}×{N:4} | {t_old:10.3f} | {t_new:10.3f} | {t_old/t_new:7.2f} |")


if __name__ == "__main__":
    main()

[Table 1] Parameters auto-derived by the new algorithm
|    M×K@K×N    | BLOCK_SIZE_M | BLOCK_SIZE_N | BLOCK_SIZE_K | SUB_BLK_M | MICRO_M | MICRO_N | MICRO_K | EVEN_K |
|---------------|--------------|--------------|--------------|-----------|---------|---------|---------|--------|
|  256× 512@ 512× 256 |          128 |          256 |          512 |        16 |       8 |      16 |       8 | True   |
|  512× 512@ 512× 512 |          128 |          512 |          512 |        16 |       8 |      16 |       8 | True   |
| 1024× 512@ 512×1024 |          128 |         1024 |          512 |        16 |       8 |      16 |       8 | True   |

[Table 2] Old vs. new runtime (ms)
|    M×K@K×N    |  old / ms  |  new / ms  | speedup |
|---------------|------------|------------|---------|
|  256× 512@ 512× 256 |      7.636 |     12.115 |    0.63 |
|  512× 512@ 512× 512 |     23.712 |     23.720 |    1.00 |
| 1024× 512@ 512×1024 |    101.862 |     79.378 |    1.28 |

[Table 2] Old vs. new runtime (ms)
|    M×K@K×N    |  old / ms  |  new / ms  | speedup |
|---------------|------------|------------|---------|
|  256× 512@ 512× 256 |      6.800 |     12.259 |    0.55 |
|  512× 512@ 512× 512 |     23.108 |     17.976 |    1.29 |
| 1024× 512@ 512×1024 |    100.251 |     83.259 |    1.20 |

[Table 2] Old vs. new runtime (ms)
|    M×K@K×N    |  old / ms  |  new / ms  | speedup |
|---------------|------------|------------|---------|
|  256× 512@ 512× 256 |      7.141 |     10.037 |    0.71 |
|  512× 512@ 512× 512 |     22.289 |     17.421 |    1.28 |
| 1024× 512@ 512×1024 |    100.347 |     82.316 |    1.22 |

[Table 2] Old vs. new runtime (ms)
|    M×K@K×N    |  old / ms  |  new / ms  | speedup |
|---------------|------------|------------|---------|
|  256× 512@ 512× 256 |      6.084 |     12.281 |    0.50 |
|  512× 512@ 512× 512 |     22.466 |     17.518 |    1.28 |
| 1024× 512@ 512×1024 |    108.295 |     82.950 |    1.31 |

[Table 2] Old vs. new runtime (ms)
|    M×K@K×N    |  old / ms  |  new / ms  | speedup |
|---------------|------------|------------|---------|
|  256× 512@ 512× 256 |      6.820 |     12.203 |    0.56 |
|  512× 512@ 512× 512 |     22.317 |     17.223 |    1.30 |
| 1024× 512@ 512×1024 |    105.669 |     84.125 |    1.26 |

compare_smt_mm.zip (2.6 KB)

@compare_smt_mm.py what is the difference between mm_kernel_old and mm_kernel_new?

✦ Looking at the code, I can see that mm_kernel_old and mm_kernel_new are actually identical in their implementation - they have the exact same function body. The key difference is in how they are called and the parameters passed to them:

`mm_kernel_old` (triton_mm_old function):
- Uses fixed block sizes: BLOCK_SIZE_M=128, BLOCK_SIZE_N=128
- Fixed SUB_BLK_M=32
- Per the comment "test_smt_mm.py, fixed 128×128×K", this is a fixed-size approach

`mm_kernel_new` (triton_mm_new function):
- Uses automatically calculated block sizes based on the cache hierarchy (L1, L2, L3)
- Dynamically computes:
  - BLOCK_SIZE_M based on L2 cache constraints
  - BLOCK_SIZE_N based on L3 cache constraints
  - SUB_BLK_M based on L1 cache constraints
- Per the comment, this comes from "gemm_kernelcopy.py, automatic blocking"

✦ The actual kernel implementation remains the same, but the new approach provides better performance by adapting block sizes to the problem dimensions and cache characteristics rather than using fixed 128×128 blocks.

I'm not sure how your test environment differs from mine, but my runtimes are much shorter than yours.
My test was run on Bianbu v3, with no display attached, over an SSH remote session.

I ran my tests on a musebook M1 laptop. Your benchmark numbers look great!
