#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
compare_smt_mm.py
新 / 旧 triton-smt 矩阵乘耗时对比
同时输出两张表:
- 各规模下新算法自动推导的参数组
- 新旧算法耗时对比
conda install pytorch triton -c pytorch -c nvidia  # 安装依赖
python compare_smt_mm.py
"""
import time
import torch
import triton
import triton.language as tl
import triton.language.extra.smt as smt
from triton.backends.spine_triton.driver import CPUDriver
# ---------- 0. Shared environment setup ----------
# Register the CPU driver as Triton's active backend so both kernels run on
# the SMT CPU target.  NOTE: the original line used curly "smart" quotes
# around the arch id, which is a SyntaxError in Python — fixed to plain ASCII.
driver = CPUDriver()
driver.set_current_arch_id("0xA03C")  # target architecture id for the SMT backend
triton.runtime.driver.set_active(driver)
# ===================================================================
# 1. 旧算法(test_smt_mm.py 固定 128×128×K)
# ===================================================================
# Fixed-tiling SMT matmul kernel ("old" algorithm): each program instance
# computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ B, splitting the
# tile's rows into SUB_BLK_M sub-blocks that are processed via smt.parallel.
@triton.jit
def mm_kernel_old(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr, SUB_BLK_M: tl.constexpr,
    MICRO_M: tl.constexpr, MICRO_K: tl.constexpr, MICRO_N: tl.constexpr,
):
    # Coordinates of this program's output tile in the 2-D launch grid.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers to this tile's slices of A (M x K) and B (K x N).
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=[M, K], strides=[stride_am, stride_ak],
        offsets=[pid_m * BLOCK_SIZE_M, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], order=[1, 0])
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=[K, N], strides=[stride_bk, stride_bn],
        offsets=[0, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], order=[1, 0])
    # Accumulator in the input element type (callers pass float32 tensors).
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty)
    if EVEN_K:
        # Single reduction step: callers pass BLOCK_SIZE_K = next_power_of_2(K)
        # together with EVEN_K=True, so one load of B covers the whole K range.
        b = smt.descriptor_load(b_block_ptr, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
        # Number of SUB_BLK_M-row sub-blocks, clipped at the bottom edge of A.
        sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
        for s in smt.parallel(0, sub_num):
            a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
            acc_view = smt.view(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
            acc_view = smt.dot(a, b, acc_view)
    else:
        # Multi-step reduction for K > BLOCK_SIZE_K.  Currently dead code:
        # every caller in this file passes EVEN_K=True.
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            b = smt.descriptor_load(b_block_ptr, (k, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
            sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
            for s in smt.parallel(0, sub_num):
                a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
                # NOTE(review): the accumulator view uses k as its second
                # offset; for an M x N accumulator an offset of 0 looks
                # intended — confirm before ever enabling this branch.
                acc_view = smt.view(accumulator, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
                acc_view = smt.dot(a, b, acc_view)
    # Write the finished tile back to C, boundary-checked on both dimensions.
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr, shape=[M, N], strides=[stride_cm, stride_cn],
        offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], order=[1, 0])
    tl.store(c_block_ptr, accumulator.to(c_ptr.dtype.element_ty), boundary_check=(0, 1))
def triton_mm_old(a: torch.Tensor, b: torch.Tensor):
    """Multiply two 2-D tensors with the fixed 128x128 tiling kernel."""
    # Normalize either operand to a contiguous layout when both of its
    # strides exceed 1, matching what the kernel's block pointers expect.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0]
    m_rows, k_inner = a.shape
    n_cols = b.shape[1]
    out = torch.empty((m_rows, n_cols), device=a.device, dtype=a.dtype)
    # One reduction step covers all of K: round it up to a power of two.
    k_block = triton.next_power_of_2(k_inner)

    def grid(META):
        return (triton.cdiv(m_rows, META["BLOCK_SIZE_M"]),
                triton.cdiv(n_cols, META["BLOCK_SIZE_N"]))

    mm_kernel_old[grid](
        a, b, out,
        m_rows, n_cols, k_inner,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=k_block,
        SUB_BLK_M=32,
        MICRO_M=8, MICRO_N=16, MICRO_K=8, EVEN_K=True,
    )
    return out
# ===================================================================
# 2. 新算法(gemm_kernelcopy.py 自动分块)
# ===================================================================
# SMT matmul kernel used by the "new" algorithm.  The body is currently
# byte-for-byte identical to mm_kernel_old; only the launch parameters
# (block/sub-block sizes) derived in triton_mm_new differ.
@triton.jit
def mm_kernel_new(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr, SUB_BLK_M: tl.constexpr,
    MICRO_M: tl.constexpr, MICRO_K: tl.constexpr, MICRO_N: tl.constexpr,
):
    # Coordinates of this program's output tile in the 2-D launch grid.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers to this tile's slices of A (M x K) and B (K x N).
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=[M, K], strides=[stride_am, stride_ak],
        offsets=[pid_m * BLOCK_SIZE_M, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], order=[1, 0])
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=[K, N], strides=[stride_bk, stride_bn],
        offsets=[0, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N], order=[1, 0])
    # Accumulator in the input element type (callers pass float32 tensors).
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty)
    if EVEN_K:
        # Single reduction step: callers pass BLOCK_SIZE_K = next_power_of_2(K)
        # together with EVEN_K=True, so one load of B covers the whole K range.
        b = smt.descriptor_load(b_block_ptr, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
        # Number of SUB_BLK_M-row sub-blocks, clipped at the bottom edge of A.
        sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
        for s in smt.parallel(0, sub_num):
            a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
            acc_view = smt.view(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
            acc_view = smt.dot(a, b, acc_view)
    else:
        # Multi-step reduction for K > BLOCK_SIZE_K.  Currently dead code:
        # every caller in this file passes EVEN_K=True.
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            b = smt.descriptor_load(b_block_ptr, (k, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N))
            sub_num = tl.cdiv(min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m), SUB_BLK_M)
            for s in smt.parallel(0, sub_num):
                a = smt.descriptor_load(a_block_ptr, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))
                # NOTE(review): the accumulator view uses k as its second
                # offset; for an M x N accumulator an offset of 0 looks
                # intended — confirm before ever enabling this branch.
                acc_view = smt.view(accumulator, (s * SUB_BLK_M, k), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
                acc_view = smt.dot(a, b, acc_view)
    # Write the finished tile back to C, boundary-checked on both dimensions.
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr, shape=[M, N], strides=[stride_cm, stride_cn],
        offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], order=[1, 0])
    tl.store(c_block_ptr, accumulator.to(c_ptr.dtype.element_ty), boundary_check=(0, 1))
def triton_mm_new(a: torch.Tensor, b: torch.Tensor):
    """Multiply a @ b with cache-aware, automatically derived tile sizes.

    Returns the product tensor together with a dict of the chosen kernel
    parameters (used by the report tables in main()).
    """
    # Normalize either operand to a contiguous layout when both of its
    # strides exceed 1, matching what the kernel's block pointers expect.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # ---------- Cache-hierarchy byte budgets ----------
    SCALE = 1  # user-tunable; 1.0 reproduces the original algorithm
    # Each level is scaled, then hard-capped at twice the physical cache size.
    l1_budget = min(int(32_768 * SCALE), 64_000)
    l2_budget = min(int(262_144 * SCALE), 512_000)
    l3_budget = min(int(2_097_152 * SCALE), 4_000_000)

    BLOCK_SIZE_K = triton.next_power_of_2(K)

    def fit_pow2(extent, budget, floor):
        # Halve `extent` until extent * BLOCK_SIZE_K float32 elements fit in
        # `budget` bytes, round up to a power of two, and never drop below
        # `floor`.
        while extent * BLOCK_SIZE_K * 4 > budget:
            extent //= 2
        size = triton.next_power_of_2(extent) if extent > 0 else 1
        return max(size, floor)

    MICRO_M = 8
    BLOCK_SIZE_N = fit_pow2(N, l3_budget, 64)          # N tile sized for L3
    BLOCK_SIZE_M = fit_pow2(M, l2_budget, 64)          # M tile sized for L2
    SUB_BLK_M = fit_pow2(BLOCK_SIZE_M, l1_budget, MICRO_M)  # sub-block for L1
    MICRO_N = 16
    MICRO_K = 8
    EVEN_K = True

    # ---------- Bundle the derived parameters for the caller ----------
    info = {
        'BLOCK_SIZE_M': BLOCK_SIZE_M,
        'BLOCK_SIZE_N': BLOCK_SIZE_N,
        'BLOCK_SIZE_K': BLOCK_SIZE_K,
        'SUB_BLK_M': SUB_BLK_M,
        'MICRO_M': MICRO_M,
        'MICRO_N': MICRO_N,
        'MICRO_K': MICRO_K,
        'EVEN_K': EVEN_K,
    }

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_SIZE_M"]),
                triton.cdiv(N, META["BLOCK_SIZE_N"]))

    mm_kernel_new[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_M=SUB_BLK_M,
        MICRO_M=MICRO_M, MICRO_N=MICRO_N, MICRO_K=MICRO_K, EVEN_K=EVEN_K,
    )
    return c, info
# ===================================================================
# 3. 公共工具:计时 + 正确性
# ===================================================================
def benchmark(fn, a, b, ref, n_warm=3, n_iter=10):
    """Time fn(a, b) and check its last result against ref.

    Runs n_warm untimed warm-up calls, then n_iter timed calls, and verifies
    the final result matches ref within atol=1e-2 (rtol=0).

    Returns:
        Mean wall-clock time per call in milliseconds.
    """

    def _sync():
        # No-op on CPU; on CUDA wait for queued kernels so timings are honest.
        # (Replaces the un-idiomatic `expr if cond else None` statement.)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    for _ in range(n_warm):
        fn(a, b)
    _sync()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        c = fn(a, b)
    _sync()
    t1 = time.perf_counter()
    # Correctness gate: a fast-but-wrong kernel must not win the benchmark.
    torch.testing.assert_close(c, ref, atol=1e-2, rtol=0)
    return (t1 - t0) / n_iter * 1e3  # ms
# ===================================================================
# 4. 主流程:多规模对比
# ===================================================================
def main():
    """Run the size sweep and print both comparison tables."""
    shapes = [
        (256, 256, 512),
        (512, 512, 512),
        (1024, 1024, 512),
    ]
    # ---------- Table 1: parameters auto-derived by the new algorithm ----------
    print("\n【表1】新算法自动推导参数组")
    print("| M×K@K×N | BLOCK_SIZE_M | BLOCK_SIZE_N | BLOCK_SIZE_K | SUB_BLK_M | MICRO_M | MICRO_N | MICRO_K | EVEN_K |")
    print("|---------------|--------------|--------------|--------------|-----------|---------|---------|---------|--------|")
    for M, N, K in shapes:
        A = torch.randn(M, K, dtype=torch.float32, device="cpu")
        B = torch.randn(K, N, dtype=torch.float32, device="cpu")
        # Run once just to obtain the derived parameters; the product is unused.
        _, info = triton_mm_new(A, B)
        print(f"| {M:4}×{K:4}@{K:4}×{N:4} "
              f"| {info['BLOCK_SIZE_M']:12} "
              f"| {info['BLOCK_SIZE_N']:12} "
              f"| {info['BLOCK_SIZE_K']:12} "
              f"| {info['SUB_BLK_M']:9} "
              f"| {info['MICRO_M']:7} "
              f"| {info['MICRO_N']:7} "
              f"| {info['MICRO_K']:7} "
              f"| {str(info['EVEN_K']):6} |")
    # ---------- Table 2: timing comparison ----------
    # (Dropped the unused `param_records` accumulator and `enumerate` index.)
    print("\n【表2】新旧算法耗时对比 (ms)")
    print("| M×K@K×N | old / ms | new / ms | speedup |")
    print("|---------------|------------|------------|---------|")
    for M, N, K in shapes:
        A = torch.randn(M, K, dtype=torch.float32, device="cpu")
        B = torch.randn(K, N, dtype=torch.float32, device="cpu")
        ref = torch.mm(A, B)
        t_old = benchmark(triton_mm_old, A, B, ref)
        # triton_mm_new returns (tensor, info); benchmark only wants the tensor.
        t_new = benchmark(lambda a, b: triton_mm_new(a, b)[0], A, B, ref)
        print(f"| {M:4}×{K:4}@{K:4}×{N:4} | {t_old:10.3f} | {t_new:10.3f} | {t_old/t_new:7.2f} |")
# Standard script entry guard.  The original `if name == "main":` (with smart
# quotes) was both a SyntaxError and the wrong identifiers — use the dunders.
if __name__ == "__main__":
    main()