Spine-Triton compiler bug: a dual error in constant multiplication

Forum guru bigwolfeman posted an SMT implementation of FlashAttention. I downloaded the code right away and ran it on the board, and it hit the error described in conclusion 1 below. I first asked iFLOW to fix it, but after a round of changes it ended up back at the original version. I then fell back on the most primitive debug method (simplify the code, narrow the scope, swap statements one at a time) and traced the problem to the line in attention_smt.py where smt_qk_kernel writes out the matrix-multiplication result:

    # Step 5: Scale and store output
    out = acc # * sm_scale
    out = out.to(out_ptr.dtype.element_ty)

I then tried the following two workarounds, one after the other:

    # Step 5: Scale and store output
    # Workaround 1: hardcode the scale instead of using the sm_scale argument
    # out = acc * sm_scale
    out = acc * (1.0 / (BLOCK_D ** 0.5))
    out = out.to(out_ptr.dtype.element_ty)

    # Workaround 2: pass sm_scale as a compile-time constant instead
    # sm_scale,  # 1/sqrt(d) scaling factor
    sm_scale: tl.constexpr,  # 1/sqrt(d) scaling factor

Both seemed to make the compilation error go away.

Next I reproduced the problem in test_smt_mm.py, only to find that both workarounds produce wrong results. Hence this post.
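The wrong results are easy to expose with a plain reference check: compare the kernel output against a scaled reference matmul. Below is a minimal pure-Python sketch of that checking logic (the data and the naive matmul are stand-ins; the actual test_smt_mm.py presumably compares the Triton kernel's output against a torch reference):

```python
# Sketch of the result check that exposed the bug: if the compiler silently
# drops the `* sm_scale`, the kernel output equals the unscaled matmul and
# the comparison against the scaled reference fails.

def matmul(a, b):
    """Naive reference matmul for small test matrices (lists of lists)."""
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(m)]
            for i in range(n)]

def scaled_ref(a, b, sm_scale):
    """What the kernel is supposed to compute: (A @ B) * sm_scale."""
    return [[x * sm_scale for x in row] for row in matmul(a, b)]

def max_abs_err(c, ref):
    """Largest elementwise deviation between two matrices."""
    return max(abs(x - y) for rc, rr in zip(c, ref) for x, y in zip(rc, rr))

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
sm_scale = 0.5

buggy_out = matmul(A, B)            # scale silently dropped by the compiler
good_out = scaled_ref(A, B, sm_scale)

print(max_abs_err(buggy_out, good_out))  # nonzero: the missing scale is visible
```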

Conclusions first:

  1. The splat statement in the Triton IR makes the next lowering step to linalg fail, and the error location it reports appears to be wrong:
Error: There are multiple ExtractSliceOp users
/home/bibo/work/git-spine-triton/python/examples/test_smt_mm.py:61:100: error: failed to legalize operation 'xsmt.view' that was explicitly marked illegal
            accumulator_view = smt.view(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
                                                                                                   ^
/home/bibo/work/git-spine-triton/python/examples/test_smt_mm.py:61:100: note: see current operation: %113 = "xsmt.view"(%15, %84, %7) <{micro_size = array<i32: 8, 16>, shape = array<i32: 32, 128>}> : (tensor<128x128xf32>, i64, i64) -> tensor<4x8x8x16xf32>
/home/bibo/work/git-spine-triton/python/examples/test_smt_mm.py:61:100: error: failed to legalize operation 'xsmt.view' that was explicitly marked illegal
            accumulator_view = smt.view(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N))
                                                                                                   ^
/home/bibo/work/git-spine-triton/python/examples/test_smt_mm.py:61:100: note: see current operation: %58 = "xsmt.view"(%15, %56, %7) <{micro_size = array<i32: 8, 16>, shape = array<i32: 32, 128>}> : (tensor<128x128xf32>, i64, i64) -> tensor<4x8x8x16xf32>
Traceback (most recent call last):
  File "/home/bibo/work/git-spine-triton/python/examples/test_smt_mm.py", line 141, in <module>
    C = triton_mm(A, B)
  File "/home/bibo/work/git-spine-triton/python/examples/test_smt_mm.py", line 105, in triton_mm
    mm_kernel[grid](
    ~~~~~~~~~~~~~~~^
        a,
        ^^
    ...<19 lines>...
        EVEN_K=True
        ^^^^^^^^^^^
    )
    ^
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/runtime/jit.py", line 594, in run
    kernel = self.compile(src, target=target, options=options.__dict__)
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/compiler/compiler.py", line 359, in compile
    next_module = compile_ir(module, metadata)
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/compiler.py", line 303, in <lambda>
    _ttir_to_linalgdir(src, metadata)
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/compiler.py", line 32, in _ttir_to_linalgdir
    subprocess.check_call(
    ~~~~~~~~~~~~~~~~~~~~~^
        [
        ^
    ...<5 lines>...
        ]
        ^
    )
    ^
  File "/usr/lib/python3.13/subprocess.py", line 419, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/home/bibo/work/venv/spine-triton/lib/python3.13/site-packages/triton/backends/spine_triton/bin/spine-triton-opt', '/tmp/tmpqi9107ez/tt.mlir', '--triton-to-linalg-experimental', '-o', '/tmp/tmpqi9107ez/linalg.mlir']' returned non-zero exit status 1.
  2. Rewriting the statement as either of the two other constant-multiplication forms compiles without error, but the result is wrong: the constant multiplication has no effect.

Here is iFLOW's analysis:

@perf/sm_mm Please try to tell whether there is a compiler error that failed to multiply the output with a 0.5 constant.

✦ I'll help you investigate if there's a compiler error related to multiplying output with a 0.5 constant in the SMT-based matrix
  multiplication example. Let me first look at the relevant files in the perf/smt_mm directory.

  ✔  read_file perf\smt_mm\test_smt_mm.original.py
     Read all 148 lines from perf/smt_mm/test_smt_mm.original.py

✦ Now let me check the working version to see if there's any 0.5 constant multiplication in it:

  ✔  read_file \test_smt_mm.working.smscale.nopass.py
     Read all 153 lines from perf/smt_mm/test_smt_mm.working.smscale.nopass.py

✦ Now let me check the failing version to see if there's a difference related to the 0.5 constant:

  ✔  read_file perf\smt_mm\test_smt_mm.failing.py
     Read all 153 lines from perf/smt_mm/test_smt_mm.failing.py

✦ Now let me check the directp5 version to compare:

  ✔  read_file \test_smt_mm.working.directp5.nopass.py
     Read all 153 lines from perf/smt_mm/test_smt_mm.working.directp5.nopass.py

✦ Now let me check the original MLIR file to see if there's anything related to the 0.5 constant:

  ✔  read_file ...perf\smt_mm\mm_kernel_tt.original.mlir
     Read all 65 lines from perf/smt_mm/mm_kernel_tt.original.mlir

✦ Now let me check the working MLIR file to see if there's the 0.5 constant:

  ✔  read_file \mm_kernel_tt.working.smscale.nopass.mlir
     Read all 65 lines from perf/smt_mm/mm_kernel_tt.working.smscale.nopass.mlir

✦ Now let me check the directp5 MLIR file:

  ✔  read_file \mm_kernel_tt.working.directp5.nopass.mlir
     Read all 65 lines from perf/smt_mm/mm_kernel_tt.working.directp5.nopass.mlir

✦ Let me check the failing MLIR file to see if there's any difference:

  ✔  read_file ...perf\smt_mm\mm_kernel_tt.failing.mlir
     Read all 69 lines from perf/smt_mm/mm_kernel_tt.failing.mlir

✦ Let me also check the example file that was mentioned in the editor context:

  ✔  read_file ...riton\python\examples\attention_smt.py
     Read all 553 lines from git-spine-triton/python/examples/attention_smt.py

✦ Based on my analysis of the files, I can answer your question about whether there's a compiler error that failed to multiply the output
  with a 0.5 constant:

  Yes, there appears to be a compiler issue with the SMT-based matrix multiplication. Looking at the different versions:

   1. In the original and working versions, there's no multiplication by 0.5 constant in the compiled MLIR code, even though the Python
      code has out = accumulator * sm_scale where sm_scale=0.5.

   2. In the failing version, I can see that the multiplication by the scale factor is properly handled in the MLIR:

          %17 = arith.mulf %arg9, %cst : f32 loc(#loc19)
          %18 = tt.splat %17 : f32 -> tensor<128x128xf32> loc(#loc19)

      This shows the scale factor (0.5) being multiplied with the accumulator and then stored.

   3. In the directp5 version, there's a comment indicating that the scale was hardcoded: out = accumulator * 0.5#sm_scale - showing that
      developers had to hardcode the value instead of using the parameter.

  The issue seems to be that the compiler is not properly handling the multiplication with the scale factor in certain configurations. The   
   "failing" version appears to be the one where the scale multiplication is correctly implemented in the compiled code, while the
  original and working versions fail to apply the scale factor during compilation.

✦ This indicates a potential issue in the compiler's optimization passes or the way the sm_scale parameter is being processed in the SMT     
  kernel implementation.

@perf/smt_mm/mm_kernel_tt.working.directp5.nopass.mlir @perf/smt_mm/test_smt_mm.working.directp5.nopass.py

  ✔  read_many_files Will attempt to read and concatenate files using patterns: `perf/smt_mm/mm_kernel_tt.working.directp5.nopass.mlir`, …
     ### ReadManyFiles Result (Target Dir: `C:\work\triton\spine-triton`)

     Successfully read and concatenated content from **2 file(s)**.

     **Processed Files:**
     - `perf/smt_mm/mm_kernel_tt.working.directp5.nopass.mlir`
     - `perf/smt_mm/test_smt_mm.working.directp5.nopass.py`

✦ Based on my analysis of the files and comparing the directp5 version:

  In the Python file test_smt_mm.working.directp5.nopass.py, I can see that there's a hardcoded multiplication with 0.5:
   out = accumulator * 0.5#sm_scale

  However, in the corresponding MLIR file mm_kernel_tt.working.directp5.nopass.mlir, I don't see the multiplication operation being
  applied to the accumulator. The MLIR code stores the accumulator directly without any scaling:

   tt.store %18, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<128x128xf32>> loc(#loc20)

  Where %cst is arith.constant dense<0.000000e+00> : tensor<128x128xf32> - the zero-initialized accumulator.

  This confirms that there is indeed a compiler error where the multiplication with the 0.5 constant is not being properly compiled
  into the MLIR code. The multiplication is present in the Python source code but is not reflected in the compiled MLIR output, which        
  means the scale factor is not being applied during execution.

✦ The code comment #sm_scale after the hardcoded 0.5 suggests that the developers were aware of the issue with the sm_scale parameter        
  not working properly, so they hardcoded the value but the compiler still failed to process this multiplication correctly.

@perf/smt_mm/mm_kernel_tt.failing.mlir @perf/smt_mm/test_smt_mm.failing.py

  ✔  read_many_files Will attempt to read and concatenate files using patterns: `perf/smt_mm/mm_kernel_tt.failing.mlir`, `perf/smt_mm/tes…
     ### ReadManyFiles Result (Target Dir: `C:\work\triton\spine-triton`)

     Successfully read and concatenated content from **2 file(s)**.

     **Processed Files:**
     - `perf/smt_mm/mm_kernel_tt.failing.mlir`
     - `perf/smt_mm/test_smt_mm.failing.py`

✦ Interestingly, there's something unusual in the "failing" MLIR file. Looking at this line:

   %17 = arith.mulf %arg9, %cst : f32 loc(#loc19)

  This is multiplying the sm_scale parameter (%arg9) with a constant 0.0 (%cst), which would always result in 0.0! This is definitely a
  compiler bug where it's incorrectly processing the multiplication.

  Then in the next line:
   %18 = tt.splat %17 : f32 -> tensor<128x128xf32> loc(#loc19)

  It splats this 0.0 value across the entire tensor, and stores that instead of the properly scaled accumulator.

  The Python code correctly shows out = accumulator * sm_scale, but the compiler is generating incorrect MLIR code that multiplies the       
  scale parameter by 0 instead of multiplying the accumulator by the scale parameter.

✦ This confirms there's definitely a compiler bug in how the SMT kernel handles the multiplication with the scale factor.
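To make the miscompilation concrete: the intended computation versus what the emitted IR actually does can be modeled in plain Python (a sketch of the IR semantics using lists as stand-ins for tensors, not runnable Triton code):

```python
# Intended: out = acc * sm_scale, elementwise on the accumulator.
# Emitted:  %17 = arith.mulf %arg9, %cst  (with %cst = 0.0), then tt.splat,
#           i.e. out = broadcast(sm_scale * 0.0) = an all-zero tensor.

def intended(acc, sm_scale):
    """Elementwise scale of the accumulator, as written in the Python kernel."""
    return [[x * sm_scale for x in row] for row in acc]

def emitted(acc, sm_scale):
    """What the failing MLIR computes: splat of sm_scale * 0.0."""
    zero = 0.0                    # %cst : the zero-init constant
    scalar = sm_scale * zero      # %17 = arith.mulf %arg9, %cst
    rows, cols = len(acc), len(acc[0])
    return [[scalar] * cols for _ in range(rows)]  # %18 = tt.splat %17

acc = [[19.0, 22.0], [43.0, 50.0]]
print(intended(acc, 0.5))   # [[9.5, 11.0], [21.5, 25.0]]
print(emitted(acc, 0.5))    # [[0.0, 0.0], [0.0, 0.0]]
```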

Actually, when the compiler was trying to translate the Triton IR to linalg IR, there was an error "Error: There are multiple ExtractSliceOp users". That's why I modified the way sm_scale was passed in, to bypass that error.

✦ Thank you for that clarification! That's very helpful context. The "multiple ExtractSliceOp users" error is a known issue in the Triton
   compiler when dealing with certain tensor operations and parameter passing in SMT kernels.

  So to summarize what's happening:

   1. The original code had the correct logic: out = accumulator * sm_scale
   2. But during the Triton IR to linalg IR translation, there was an error "There are multiple ExtractSliceOp users" related to how the
      sm_scale parameter was being used
   3. To work around this compiler issue, you changed the parameter from a non-constexpr to a constexpr in one version (the directp5
      version), and hardcoded the value in another version
   4. However, the compiler is still generating incorrect code in the directp5 version where it multiplies the scale parameter by 0

  This indicates that there are multiple issues in the compiler pipeline:
   - The original issue with ExtractSliceOp when using sm_scale as a non-constexpr parameter
   - A secondary issue where the compiler generates incorrect multiplication code even when trying to workaround the original issue

✦ This is a great example of how compiler bugs can manifest in complex ways, requiring workarounds that may introduce other issues. The      
  proper fix would likely involve addressing the root cause in the Triton-to-linalg lowering pass where the ExtractSliceOp error occurs.

Finally, for the SpacemiT engineers' reference, here are the four versions of the Triton-C and Triton-IR code I collected:
smt_mm-compiler_dual_error.zip (12.4 KB)

  • .original. The original smt_mm implementation
  • .failing. The first version that fails to compile; also bigwolfeman's original formulation
  • .working.smscale.nopass. Changes how the parameter is passed in; compiles without error, but the result is wrong
  • .working.directp5.nopass. Multiplies directly by the constant 0.5; compiles without error, but the result is wrong

Wow! What an incredible find. Your post on my thread was just the start. I wish I had a way to execute this code myself for further testing. I think what you found is massive though, and probably the main issue to be found.


Yeah! This is as far as we can go for now, since the Triton compiler source code is not open sourced yet. While waiting for the dual fix, I will be working on your kernel to experience the beauty of the fused online softmax.
What an amazing adventure! Can't wait to see what we will find, and how the SMT compiler will improve.

Wow! The bug you found is critical. The multiply-by-0.5 logic is already eliminated at the ttir stage. The root cause is that Triton assumes every tensor operation produces a fresh copy, so at out = acc * sm_scale Triton still believes acc is an all-zero value, and any operation on an all-zero value is treated as meaningless and optimized away. Performing the scale operation inside smt.parallel solves your problem.
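A sketch of the workaround suggested in this reply, in Triton-style pseudocode. The smt.parallel / smt.view call shapes are copied from the forum snippets above and are assumptions, not the verified Spine-Triton API:

```python
# Pseudocode sketch (not verified against the real Spine-Triton API).
# Idea: apply sm_scale inside the smt.parallel region, on the per-slice
# accumulator view, so Triton can no longer fold the multiply away as an
# operation on an "all-zero" tensor.
for s in smt.parallel(NUM_SLICES):                 # assumed loop form
    accumulator_view = smt.view(accumulator, (s * SUB_BLK_M, 0),
                                (SUB_BLK_M, BLOCK_SIZE_N),
                                (MICRO_M, MICRO_N))
    # ... micro-kernel updates accumulate into accumulator_view ...
    accumulator_view *= sm_scale   # scale here, not after the loop
```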


smt_mm-compiler_fixed.zip (9.4 KB)

The problem above occurred in version 3.3.0. In the just-released version 3.6.0 (https://forum.spacemit.com/t/topic/872/8), it has been fixed.