# This benchmark multiplies square matrices using a single, maximum-size block
# and applies a softmax to the product, to check the performance of tl.dot
# when it is fused with a softmax.

import torch
import triton
import triton.language as tl
import benchmark

@triton.jit
def bare_fused_mm_softmax(X, Y, Z, M, N, K, BLOCK_SIZE: tl.constexpr):
    """Compute softmax(dot(x_tile, y_tile)) for one square tile and store it in Z.

    Each program loads one BLOCK_SIZE x BLOCK_SIZE tile from X (row stride K)
    and one from Y (row stride N) at the same row/column offsets, multiplies
    them with tl.dot, applies bare_softmax and stores the result to Z
    (row stride N) at those offsets.

    There is no masking and no loop over the K dimension. M is accepted but
    not used.
    NOTE(review): because both tiles use the same offsets, the result equals
    softmax(X @ Y) only when one block covers the whole matrices, which is
    how the operator wrapper in this file launches it -- confirm before
    launching with a larger grid.
    """
    block_row = tl.program_id(0)
    block_col = tl.program_id(1)

    # Row and column indices covered by this program's tile.
    rows = block_row * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    cols = block_col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Flattened element offsets: X is indexed with row stride K, Y and Z with N.
    x_offsets = rows[:, None] * K + cols[None, :]
    yz_offsets = rows[:, None] * N + cols[None, :]

    x_tile = tl.load(X + x_offsets)
    y_tile = tl.load(Y + yz_offsets)

    product = tl.dot(x_tile, y_tile)
    result = bare_softmax(product)

    tl.store(Z + yz_offsets, result)

@triton.jit
def bare_softmax(row):
    """Return the softmax of `row`, reducing along axis 0."""
    # Shift by the axis-0 maximum first, for numerical stability.
    shifted = row - tl.max(row, axis=0)
    # Exponentiation in Triton is fast but approximate (think __expf in CUDA).
    exps = tl.exp(shifted)
    total = tl.sum(exps, axis=0)
    return exps / total

# Wrap the Triton kernel as a PyTorch operator
from torch.library import triton_op, wrap_triton
@triton_op("mylib::triton_fused_mm_softmax", mutates_args={})
def triton_fused_mm_softmax(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Run bare_fused_mm_softmax on `a` and `b` as a single program.

    The output is allocated with the shape and dtype of `a`. The length of
    the first row of `a` is passed as M, N, K and as BLOCK_SIZE, and the
    kernel is launched on a grid of one program.
    NOTE(review): this assumes `a` and `b` are square matrices of the same
    size, as the benchmarks in this file create them -- confirm for other
    callers.

    Returns:
        The tensor written by the kernel.
    """
    out = torch.empty_like(a)
    # Number of elements in one row of `a`; used for every size argument.
    side = a[0].numel()
    wrap_triton(bare_fused_mm_softmax)[(1,)](
        a, b, out,
        side, side, side,
        BLOCK_SIZE=side,
    )
    return out

@benchmark.measure()
def bench_fused_matmul_softmax(N, provider):
    """Run the fused matmul + softmax on random N x N float32 CPU matrices.

    Args:
        N: side length of the square input matrices.
        provider: 'torch' runs torch.softmax(torch.matmul(a, b)) along dim 0;
            'triton' runs the mylib.triton_fused_mm_softmax operator;
            'test' runs both and asserts the results agree within an
            absolute tolerance of 1e-2. Any other value only creates the
            inputs.

    Raises:
        AssertionError: in 'test' mode, when the two results differ by more
            than the tolerance.
    """
    device = 'cpu'
    dtype = torch.float32
    a = torch.randn((N, N), device=device, dtype=dtype)
    b = torch.randn((N, N), device=device, dtype=dtype)
    # No output tensor is pre-allocated here: both providers allocate their
    # own result, so an extra torch.empty would only add to the measured
    # time and memory.
    if provider in ('torch', 'test'):
        c_ref = torch.softmax(torch.matmul(a, b), dim=0)
    if provider in ('triton', 'test'):
        c = torch.ops.mylib.triton_fused_mm_softmax.default(a, b)
        if provider == 'test':
            torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=0)

@benchmark.measure()
def bench_fused_matmul_softmax_triton(N):
    """Run the mylib.triton_fused_mm_softmax operator on random inputs.

    Args:
        N: side length of the square float32 CPU input matrices.
    """
    device = 'cpu'
    dtype = torch.float32
    a = torch.randn((N, N), device=device, dtype=dtype)
    b = torch.randn((N, N), device=device, dtype=dtype)
    # The operator allocates its own output, so nothing is pre-allocated
    # here; the result is not needed after the call.
    torch.ops.mylib.triton_fused_mm_softmax.default(a, b)

@benchmark.measure()
def bench_fused_matmul_softmax_torch(N):
    """Run the PyTorch reference softmax(matmul(a, b)) along dim 0 on random inputs.

    Args:
        N: side length of the square float32 CPU input matrices.
    """
    device = 'cpu'
    dtype = torch.float32
    a = torch.randn((N, N), device=device, dtype=dtype)
    b = torch.randn((N, N), device=device, dtype=dtype)
    # torch.softmax allocates its own output, so nothing is pre-allocated
    # here; the result is not needed after the call.
    torch.softmax(torch.matmul(a, b), dim=0)

# @torch.compile(fullgraph=True)
# def wrapper_triton(a, b, c, N):
#     bare_fused_mm_softmax[(1,)](a, b, c, N, N, N, N)

if __name__ == "__main__":
    benchmark.select_cpu_backend()

    # PyTorch profiler: run every provider at sizes 128, 256 and 512, each
    # under its own labelled record_function range.
    from torch.profiler import profile, record_function, ProfilerActivity
    with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
        for X in [2**i for i in range(7, 10, 1)]:
            for provider in ['test', 'torch', 'triton']:
                with record_function("operator_test: " + provider + "_" + str(X)):
                    bench_fused_matmul_softmax(X, provider)

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
    prof.export_chrome_trace("profiler_export_trace.json")

    # Python profiler: one single-size run per provider, each dumped to its
    # own stats file.
    import cProfile

    single_runs = (
        ("triton", bench_fused_matmul_softmax_triton),
        ("torch", bench_fused_matmul_softmax_torch),
    )
    for label, bench in single_runs:
        # Use a fresh Profile per run: a Profile object keeps accumulating
        # across enable()/disable() cycles, so reusing one would leak the
        # first run's data into the second stats file.
        profiler = cProfile.Profile()
        profiler.enable()
        bench(512)
        profiler.disable()
        profiler.dump_stats("profiler_export_trace.single." + label + ".prof")
