Spine-Triton compiler bug: tt.mlir to linalg.mlir

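Today I continued studying the FlashAttention implementation by the expert bigwolfeman. I had already found and worked around two syntax errors (one of them was confirmed to be a Triton compiler bug), and now a new compilation error has appeared.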

SMT Attention Demo
============================================================

Input: Q=torch.Size([128, 64]), K=torch.Size([128, 64]), V=torch.Size([128, 64])

2. Testing smt_attention_fused_kernel (full attention)...
/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/language/standard.py:290:36: error: failed to legalize operation 'linalg.fill' that was explicitly marked illegal
    return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
                                   ^
/home/bibo/work/git-spine-triton/python/examples/attention_smt.py:269:37: note: called from
        l_new = alpha * l_i + tl.sum(p_ij, axis=1)
                                    ^
/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/language/standard.py:290:36: note: see current operation: 
%19 = "linalg.fill"(%0, %18) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg73: f32, %arg74: f32):
  "linalg.yield"(%arg73) : (f32) -> ()
}) : (f32, tensor<64xf32>) -> tensor<64xf32>
    return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
                                   ^
<unknown>:0: error: failed to legalize operation 'linalg.fill' that was explicitly marked illegal
<unknown>:0: note: see current operation: 
%11 = "linalg.fill"(%0, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg75: f32, %arg76: f32):
  "linalg.yield"(%arg75) : (f32) -> ()
}) : (f32, tensor<64x1xf32>) -> tensor<64x1xf32>
<unknown>:0: error: failed to legalize operation 'linalg.fill' that was explicitly marked illegal
<unknown>:0: note: see current operation: 
%11 = "linalg.fill"(%0, %10) <{operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg75: f32, %arg76: f32):
  "linalg.yield"(%arg75) : (f32) -> ()
}) : (f32, tensor<64x1xf32>) -> tensor<64x1xf32>
Traceback (most recent call last):
  File "/home/bibo/work/git-spine-triton/python/examples/attention_smt.py", line 539, in <module>
    out = run_smt_attention(Q, K, V)
  File "/home/bibo/work/git-spine-triton/python/examples/attention_smt.py", line 493, in run_smt_attention
    smt_attention_fused_kernel[grid](
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        Q, K, V, out,                     # Input and output tensors
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<14 lines>...
        MICRO_K=MICRO_K,      # SMT microkernel K
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/runtime/jit.py", line 594, in run
    kernel = self.compile(src, target=target, options=options.__dict__)
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/compiler/compiler.py", line 359, in compile
    next_module = compile_ir(module, metadata)
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/compiler.py", line 303, in <lambda>
    _ttir_to_linalgdir(src, metadata)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/compiler.py", line 32, in _ttir_to_linalgdir
    subprocess.check_call(
    ~~~~~~~~~~~~~~~~~~~~~^
        [
        ^
    ...<5 lines>...
        ]
        ^
    )
    ^
  File "/usr/lib/python3.13/subprocess.py", line 419, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/bin/spine-triton-opt', '/tmp/tmps7yxhndw/tt.mlir', '--triton-to-linalg-experimental', '-o', '/tmp/tmps7yxhndw/linalg.mlir']' returned non-zero exit status 1.
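The statement flagged in the log, l_new = alpha * l_i + tl.sum(p_ij, axis=1), is the running softmax-denominator update of FlashAttention's online softmax. tl.sum(..., axis=1) collapses each row of the current score block to a single value, which matches the 1-D tensor<64xf32> in the first failing linalg.fill (presumably the zero-initialised accumulator of that row reduction). As a plain reference for what this line should compute, here is a pure-Python sketch assuming the usual formulation alpha = exp(m_i - m_new) and p_ij = exp(s_ij - m_new). The function online_row_stats and the names scores, m and m_new are illustrative assumptions, not taken from the attached source.

```python
import math

def online_row_stats(scores, block):
    """Per row of `scores`, compute the running max m and the running softmax
    denominator l block by block, in the style of an online-softmax inner loop.

    Assumed formulation (not copied from attention_smt.py):
      m_new = max(m_i, max(s_ij))
      alpha = exp(m_i - m_new)
      p_ij  = exp(s_ij - m_new)
      l_new = alpha * l_i + sum(p_ij)   # row-wise, like tl.sum(p_ij, axis=1)
    """
    n_rows = len(scores)
    n_cols = len(scores[0])
    m = [-math.inf] * n_rows   # running row max, starts at -inf
    l = [0.0] * n_rows         # running row sum, starts at 0 (the reduction identity)
    for start in range(0, n_cols, block):
        for r, row in enumerate(scores):
            s_ij = row[start:start + block]           # current block of this row
            m_new = max(m[r], max(s_ij))
            alpha = math.exp(m[r] - m_new)            # exp(-inf) == 0.0 on the first block
            p_ij = [math.exp(s - m_new) for s in s_ij]
            l[r] = alpha * l[r] + sum(p_ij)           # rescale old sum, add new block
            m[r] = m_new
    return m, l
```

After the last block, l[r] equals the ordinary softmax denominator sum(exp(s - max(row))) for each row, so it can serve as a numerical reference once the kernel compiles.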

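The earlier failure was in the Triton-C to Triton-IR pass; this time the failing pass is the Triton-IR to Linalg-IR lowering (spine-triton-opt --triton-to-linalg-experimental). The diagnostics point at tl.sum(p_ij, axis=1) on line 269 of attention_smt.py: the pass cannot legalize the linalg.fill operations generated for that reduction. I can't make sense of it yet, so I'll have to ask the SpacemiT engineers to take a look.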
The source code and the Triton-IR are attached here:
smt_attention_fused_kernel_tt.zip (7.8 KB)
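To reproduce only the failing lowering step, without relaunching the whole kernel, the command from the traceback can be re-run directly on the attached tt.mlir (the temporary directory /tmp/tmps7yxhndw may no longer exist after the run). A minimal sketch in Python, assuming spine-triton-opt accepts exactly the arguments shown in the CalledProcessError; the helper names build_lowering_cmd and run_lowering are hypothetical.

```python
import subprocess

def build_lowering_cmd(opt_bin: str, ttir_path: str, out_path: str) -> list[str]:
    # Same argument order as the failing command in the traceback:
    # <spine-triton-opt> <tt.mlir> --triton-to-linalg-experimental -o <linalg.mlir>
    return [opt_bin, ttir_path, "--triton-to-linalg-experimental", "-o", out_path]

def run_lowering(opt_bin: str, ttir_path: str, out_path: str) -> int:
    """Run only the Triton-IR -> Linalg-IR step and return its exit status.

    MLIR tools normally print diagnostics such as
    "failed to legalize operation 'linalg.fill'" to stderr,
    so stderr is echoed when the command fails.
    """
    cmd = build_lowering_cmd(opt_bin, ttir_path, out_path)
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(result.stderr)
    return result.returncode
```

For example, run_lowering(opt, "tt.mlir", "linalg.mlir"), where opt is the spine-triton-opt path from the traceback and "tt.mlir" stands for the Triton-IR file unpacked from the zip (its exact file name is an assumption), should return 1 and print the same linalg.fill errors if the bug reproduces outside the Python runtime.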