Today I continue working through bigwolfeman's FlashAttention implementation. After finding and working around two syntax errors (one of which was confirmed to be a Triton compiler bug), a new compilation error has appeared.
SMT Attention Demo
============================================================
Input: Q=torch.Size([128, 64]), K=torch.Size([128, 64]), V=torch.Size([128, 64])
2. Testing smt_attention_fused_kernel (full attention)...
/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/language/standard.py:290:36: error: failed to legalize operation 'linalg.fill' that was explicitly marked illegal
return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
^
/home/bibo/work/git-spine-triton/python/examples/attention_smt.py:269:37: note: called from
l_new = alpha * l_i + tl.sum(p_ij, axis=1)
^
/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/language/standard.py:290:36: note: see current operation:
%19 = "linalg.fill"(%0, %18) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg73: f32, %arg74: f32):
"linalg.yield"(%arg73) : (f32) -> ()
}) : (f32, tensor<64xf32>) -> tensor<64xf32>
return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
^
<unknown>:0: error: failed to legalize operation 'linalg.fill' that was explicitly marked illegal
<unknown>:0: note: see current operation:
%11 = "linalg.fill"(%0, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg75: f32, %arg76: f32):
"linalg.yield"(%arg75) : (f32) -> ()
}) : (f32, tensor<64x1xf32>) -> tensor<64x1xf32>
<unknown>:0: error: failed to legalize operation 'linalg.fill' that was explicitly marked illegal
<unknown>:0: note: see current operation:
%11 = "linalg.fill"(%0, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg75: f32, %arg76: f32):
"linalg.yield"(%arg75) : (f32) -> ()
}) : (f32, tensor<64x1xf32>) -> tensor<64x1xf32>
Traceback (most recent call last):
File "/home/bibo/work/git-spine-triton/python/examples/attention_smt.py", line 539, in <module>
out = run_smt_attention(Q, K, V)
File "/home/bibo/work/git-spine-triton/python/examples/attention_smt.py", line 493, in run_smt_attention
smt_attention_fused_kernel[grid](
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
Q, K, V, out, # Input and output tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<14 lines>...
MICRO_K=MICRO_K, # SMT microkernel K
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/runtime/jit.py", line 390, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/runtime/jit.py", line 594, in run
kernel = self.compile(src, target=target, options=options.__dict__)
File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/compiler/compiler.py", line 359, in compile
next_module = compile_ir(module, metadata)
File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/compiler.py", line 303, in <lambda>
_ttir_to_linalgdir(src, metadata)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/compiler.py", line 32, in _ttir_to_linalgdir
subprocess.check_call(
~~~~~~~~~~~~~~~~~~~~~^
[
^
...<5 lines>...
]
^
)
^
File "/usr/lib/python3.13/subprocess.py", line 419, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/bin/spine-triton-opt', '/tmp/tmps7yxhndw/tt.mlir', '--triton-to-linalg-experimental', '-o', '/tmp/tmps7yxhndw/linalg.mlir']' returned non-zero exit status 1.
The previous failure occurred in the Triton-C to Triton-IR pass; this one occurs in the Triton-IR to Linalg-IR pass. I can't make sense of it yet, so I'll have to ask the SpacemiT engineers to take a look.
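For context, the line the error points at, `l_new = alpha * l_i + tl.sum(p_ij, axis=1)`, is the running-denominator update of FlashAttention's online softmax. A minimal NumPy sketch of the math that `tl.sum` is lowering here (hypothetical tile shapes, not the actual kernel):

```python
import numpy as np

# Hypothetical tile shapes: BLOCK_M query rows, BLOCK_N key columns per tile.
BLOCK_M, BLOCK_N = 4, 8
rng = np.random.default_rng(0)

scores = rng.standard_normal((BLOCK_M, BLOCK_N))  # one Q @ K^T tile
m_i = np.full(BLOCK_M, -np.inf)                   # running row maximum
l_i = np.zeros(BLOCK_M)                           # running softmax denominator

# Online-softmax update for one K/V tile:
m_new = np.maximum(m_i, scores.max(axis=1))       # updated running max
alpha = np.exp(m_i - m_new)                       # rescale factor for the old sum
p_ij = np.exp(scores - m_new[:, None])            # unnormalized probabilities
l_new = alpha * l_i + p_ij.sum(axis=1)            # the line that fails to compile

# For a single tile this must equal the direct softmax denominator.
direct = np.exp(scores - scores.max(axis=1, keepdims=True)).sum(axis=1)
assert np.allclose(l_new, direct)
```

The row-wise reduction `p_ij.sum(axis=1)` is what Triton implements via `core.reduce`, whose accumulator initialization appears to be what lowers to the `linalg.fill` op that the spine-triton backend fails to legalize.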
The source and the Triton-IR are attached here:
smt_attention_fused_kernel_tt.zip (7.8 KB)