# mypy: allow-untyped-defs
import functools
import itertools
import logging
from typing import cast, List, Tuple

import sympy

import torch
from torch._inductor.select_algorithm import realize_inputs
from torch._inductor.virtualized import V

from .. import config as inductor_config
from ..runtime.runtime_utils import next_power_of_2
from ..utils import ceildiv as cdiv


log = logging.getLogger(__name__)


def triton_config(num_stages, num_warps, **kwargs):
    from triton import Config

    return Config(kwargs, num_stages=num_stages, num_warps=num_warps)


def filtered_configs(
    m: int,
    n: int,
    k: int,
    configs: List[Tuple[int, int, int, int, int]],
    has_int8_tensor=False,
):
    """Heuristic to shrink configs when they are bigger than the input size"""

    min_block_size = 16
    # block_k=16 seems to be causing issues
    # see: https://github.com/triton-lang/triton/issues/2156#issuecomment-1695897424
    min_block_size_k = 32 if has_int8_tensor else 16
    m = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    n = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size,
    )
    k = max(
        next_power_of_2(
            V.graph.sizevars.size_hint(
                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
            )
        ),
        min_block_size_k,
    )
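    # dedupe configs after clamping, keyed on the full tuple
    # (including matrix_instr_nonkdim, which is only nonzero on ROCm)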
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # shrink configs for small sizes
        block_m = max(min(block_m, m), min_block_size)
        block_n = max(min(block_n, n), min_block_size)
        block_k = max(min(block_k, k), min_block_size_k)
        # each warp computes a 16x16 tile = 256 elements
        num_warps = min(num_warps, block_m * block_n // 256)
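        # e.g. a 32x32 block caps num_warps at 32 * 32 // 256 = 4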
        if torch.version.hip:
            for matrix_instr_nonkdim in [0, 16]:
                if matrix_instr_nonkdim != 0 and (
                    block_m % matrix_instr_nonkdim != 0
                    or block_n % matrix_instr_nonkdim != 0
                ):
                    # block_m and block_n must be multiples of matrix_instr_nonkdim
                    continue
                if (
                    block_m,
                    block_n,
                    block_k,
                    num_stages,
                    num_warps,
                    matrix_instr_nonkdim,
                ) not in used:
                    used.add(
                        (
                            block_m,
                            block_n,
                            block_k,
                            num_stages,
                            num_warps,
                            matrix_instr_nonkdim,
                        )
                    )
                    yield triton_config(
                        BLOCK_M=block_m,
                        BLOCK_N=block_n,
                        BLOCK_K=block_k,
                        num_stages=num_stages,
                        num_warps=num_warps,
                        matrix_instr_nonkdim=matrix_instr_nonkdim,
                    )
        else:
            if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
                used.add((block_m, block_n, block_k, num_stages, num_warps, 0))
                yield triton_config(
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    BLOCK_K=block_k,
                    num_stages=num_stages,
                    num_warps=num_warps,
                )


# List of dictionaries storing the kernel configs. Configs whose "cond"
# evaluates to true will be used on the target platform. Each config is a tuple:
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
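# e.g. (64, 64, 32, 2, 4) = a 64x64 output tile, K step 32, 2 pipeline stages, 4 warps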
mm_kernel_configs = (
    [
        {"config": (32, 32, 16, 1, 2), "cond": True},
        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
        {"config": (32, 64, 32, 5, 8), "cond": True},
        {"config": (64, 32, 32, 5, 8), "cond": True},
        {"config": (64, 32, 128, 5, 4), "cond": True},
        {"config": (64, 64, 16, 2, 4), "cond": True},
        {"config": (64, 64, 32, 2, 4), "cond": True},
        {"config": (64, 64, 64, 3, 8), "cond": True},
        {"config": (64, 64, 128, 5, 4), "cond": True},
        {"config": (64, 128, 32, 3, 4), "cond": True},
        {"config": (64, 128, 32, 4, 8), "cond": True},
        {"config": (64, 128, 64, 3, 4), "cond": True},
        {"config": (64, 128, 128, 4, 4), "cond": True},
        {"config": (128, 64, 32, 3, 4), "cond": True},
        {"config": (128, 64, 32, 4, 8), "cond": True},
        {"config": (128, 128, 32, 2, 8), "cond": True},
        {"config": (128, 128, 32, 3, 4), "cond": True},
        {"config": (128, 128, 64, 3, 4), "cond": True},
        {"config": (128, 128, 64, 5, 8), "cond": True},
    ]
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else [
        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
            [16, 32, 64, 128, 256], repeat=3
        )
        for num_stages in [1, 2, 3, 4, 5]
        for num_warps in [2, 4, 8]
    ]
)
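
# Note: the EXHAUSTIVE branch above enumerates 5**3 * 5 * 3 = 1875 candidate configs.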

# These are only used in tuned_mm when AutoHeuristic is enabled.
# The idea is that, while AutoHeuristic collects data to learn a heuristic, more configs are autotuned.
# When the learned heuristic is used, it reduces the number of configs down to 10,
# which saves compilation time (since fewer configs are autotuned) and potentially increases performance,
# because the learned heuristic might predict a config that is not part of mm_configs.
extra_mm_kernel_configs = [
    {"config": (16, 32, 16, 3, 2), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (64, 64, 128, 3, 4), "cond": True},
    {"config": (128, 64, 32, 2, 2), "cond": True},
    {"config": (128, 64, 64, 3, 8), "cond": True},
    {"config": (128, 64, 128, 4, 8), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 128, 64, 3, 8), "cond": True},
    {"config": (128, 128, 64, 5, 4), "cond": True},
]

int8_mm_kernel_configs = [
    {"config": (64, 64, 32, 2, 4), "cond": True},
    {"config": (64, 128, 32, 3, 4), "cond": True},
    {"config": (128, 64, 32, 3, 4), "cond": True},
    {"config": (64, 128, 32, 4, 8), "cond": True},
    {"config": (128, 64, 32, 4, 8), "cond": True},
    {"config": (64, 32, 32, 5, 8), "cond": True},
    {"config": (32, 64, 32, 5, 8), "cond": True},
    {"config": (128, 128, 32, 2, 8), "cond": True},
    {"config": (64, 64, 64, 3, 8), "cond": True},
    # {"config": (32, 32, 128, 2, 4), "cond": True},
    # {"config": (64, 64, 16, 2, 4), "cond": True},
    # {"config": (32, 32, 16, 1, 2), "cond": True},
    {"config": (128, 256, 128, 3, 8), "cond": torch.version.hip is None},
    {"config": (256, 128, 128, 3, 8), "cond": torch.version.hip is None},
]

# Mixed precision kernel configs for small sizes of m for mm's like (16, 8192) x (8192, 8192).
mixed_mm_kernel_configs_small_m = [
    {"config": (16, 128, 256, 3, 4), "cond": True},
    {"config": (16, 128, 256, 5, 8), "cond": True},
]

mixed_mm_kernel_configs = (
    mm_kernel_configs + mixed_mm_kernel_configs_small_m
    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
    else mm_kernel_configs
)

scaled_mm_kernel_configs = [
    {"config": (128, 256, 32, 3, 8), "cond": True},
    {"config": (256, 128, 32, 3, 8), "cond": True},
    {"config": (256, 64, 32, 4, 4), "cond": True},
    {"config": (64, 256, 32, 4, 4), "cond": True},
    {"config": (128, 128, 32, 4, 4), "cond": True},
    {"config": (128, 64, 32, 4, 4), "cond": True},
    {"config": (64, 128, 32, 4, 4), "cond": True},
    {"config": (128, 32, 32, 4, 4), "cond": True},
    {"config": (64, 32, 32, 5, 2), "cond": True},
    {"config": (256, 128, 128, 3, 8), "cond": True},
    {"config": (256, 64, 128, 4, 4), "cond": True},
    {"config": (64, 256, 128, 4, 4), "cond": True},
    {"config": (128, 128, 128, 4, 4), "cond": True},
    {"config": (128, 64, 64, 4, 4), "cond": True},
    {"config": (64, 128, 64, 4, 4), "cond": True},
    {"config": (128, 32, 64, 4, 4), "cond": True},
    {"config": (64, 32, 64, 5, 2), "cond": True},
    {"config": (16, 32, 32, 2, 2), "cond": True},
    {"config": (16, 64, 32, 2, 2), "cond": True},
    {"config": (16, 128, 32, 2, 4), "cond": True},
    {"config": (16, 256, 32, 2, 4), "cond": True},
    {"config": (16, 32, 64, 2, 2), "cond": True},
    {"config": (16, 64, 64, 2, 2), "cond": True},
    {"config": (16, 128, 64, 2, 4), "cond": True},
    {"config": (16, 256, 64, 2, 4), "cond": True},
    {"config": (32, 32, 32, 2, 2), "cond": True},
    {"config": (32, 64, 32, 2, 2), "cond": True},
    {"config": (32, 128, 32, 2, 4), "cond": True},
    {"config": (32, 256, 32, 2, 4), "cond": True},
    {"config": (32, 32, 64, 2, 2), "cond": True},
    {"config": (32, 64, 64, 2, 2), "cond": True},
    {"config": (32, 128, 64, 2, 4), "cond": True},
    {"config": (32, 256, 64, 2, 4), "cond": True},
    {"config": (16, 32, 32, 3, 2), "cond": True},
    {"config": (16, 64, 32, 3, 2), "cond": True},
    {"config": (16, 128, 32, 3, 4), "cond": True},
    {"config": (16, 256, 32, 3, 4), "cond": True},
    {"config": (16, 32, 64, 3, 2), "cond": True},
    {"config": (16, 64, 64, 3, 2), "cond": True},
    {"config": (16, 128, 64, 3, 4), "cond": True},
    {"config": (16, 256, 64, 3, 4), "cond": True},
    {"config": (32, 32, 32, 3, 2), "cond": True},
    {"config": (32, 64, 32, 3, 2), "cond": True},
    {"config": (32, 128, 32, 3, 4), "cond": True},
    {"config": (32, 256, 32, 3, 4), "cond": True},
    {"config": (32, 32, 64, 3, 2), "cond": True},
    {"config": (32, 64, 64, 3, 2), "cond": True},
    {"config": (32, 128, 64, 3, 4), "cond": True},
    {"config": (32, 256, 64, 3, 4), "cond": True},
    {"config": (16, 32, 32, 4, 2), "cond": True},
    {"config": (16, 64, 32, 4, 2), "cond": True},
    {"config": (16, 128, 32, 4, 4), "cond": True},
    {"config": (16, 256, 32, 4, 4), "cond": True},
    {"config": (16, 32, 64, 4, 2), "cond": True},
    {"config": (16, 64, 64, 4, 2), "cond": True},
    {"config": (16, 128, 64, 4, 4), "cond": True},
    {"config": (16, 256, 64, 4, 4), "cond": True},
    {"config": (32, 32, 32, 4, 2), "cond": True},
    {"config": (32, 64, 32, 4, 2), "cond": True},
    {"config": (32, 128, 32, 4, 4), "cond": True},
    {"config": (32, 256, 32, 4, 4), "cond": True},
    {"config": (32, 32, 64, 4, 2), "cond": True},
    {"config": (32, 64, 64, 4, 2), "cond": True},
    {"config": (32, 128, 64, 4, 4), "cond": True},
    {"config": (32, 256, 64, 4, 4), "cond": True},
    {"config": (16, 32, 32, 5, 2), "cond": True},
    {"config": (16, 64, 32, 5, 2), "cond": True},
    {"config": (16, 128, 32, 5, 4), "cond": True},
    {"config": (16, 256, 32, 5, 4), "cond": True},
    {"config": (16, 32, 64, 5, 2), "cond": True},
    {"config": (16, 64, 64, 5, 2), "cond": True},
    {"config": (16, 128, 64, 5, 4), "cond": True},
    {"config": (16, 256, 64, 5, 4), "cond": True},
    {"config": (32, 32, 32, 5, 2), "cond": True},
    {"config": (32, 64, 32, 5, 2), "cond": True},
    {"config": (32, 128, 32, 5, 4), "cond": True},
    {"config": (32, 256, 32, 5, 4), "cond": True},
    {"config": (32, 32, 64, 5, 2), "cond": True},
    {"config": (32, 64, 64, 5, 2), "cond": True},
    {"config": (32, 128, 64, 5, 4), "cond": True},
    {"config": (32, 256, 64, 5, 4), "cond": True},
    {"config": (16, 32, 32, 6, 2), "cond": True},
    {"config": (16, 64, 32, 6, 2), "cond": True},
    {"config": (16, 128, 32, 6, 4), "cond": True},
    {"config": (16, 256, 32, 6, 4), "cond": True},
    {"config": (16, 32, 64, 6, 2), "cond": True},
    {"config": (16, 64, 64, 6, 2), "cond": True},
    {"config": (16, 128, 64, 6, 4), "cond": True},
    {"config": (16, 256, 64, 6, 4), "cond": True},
    {"config": (32, 32, 32, 6, 2), "cond": True},
    {"config": (32, 64, 32, 6, 2), "cond": True},
    {"config": (32, 128, 32, 6, 4), "cond": True},
    {"config": (32, 256, 32, 6, 4), "cond": True},
    {"config": (32, 32, 64, 6, 2), "cond": True},
    {"config": (32, 64, 64, 6, 2), "cond": True},
    {"config": (32, 128, 64, 6, 4), "cond": True},
    {"config": (32, 256, 64, 6, 4), "cond": True},
]


# Create filtered lists of configs based on cond evaluation
mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mm_kernel_configs
    if config["cond"]
)
extra_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in extra_mm_kernel_configs
    if config["cond"]
)
int8_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in int8_mm_kernel_configs
    if config["cond"]
)
mixed_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in mixed_mm_kernel_configs
    if config["cond"]
)
scaled_mm_platform_configs = tuple(
    cast(Tuple[int, int, int, int, int], config["config"])
    for config in scaled_mm_kernel_configs
    if config["cond"]
)

# On ROCm convert num_stages to 0 to enable software pipelining
if torch.version.hip:
    mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mm_platform_configs
    )
    extra_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in extra_mm_platform_configs
    )
    int8_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in int8_platform_configs
    )
    mixed_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in mixed_mm_platform_configs
    )
    scaled_mm_platform_configs = tuple(
        (config[0], config[1], config[2], 0, config[4])
        for config in scaled_mm_platform_configs
    )

mm_configs = functools.partial(
    filtered_configs,
    configs=mm_platform_configs,
)

extra_mm_configs = functools.partial(
    filtered_configs,
    configs=extra_mm_platform_configs,
)

int8_mm_configs = functools.partial(
    filtered_configs,
    configs=int8_platform_configs,
)

mixed_mm_configs = functools.partial(
    filtered_configs,
    configs=mixed_mm_platform_configs,
)

scaled_mm_configs = functools.partial(
    filtered_configs,
    configs=scaled_mm_platform_configs,
)


def mm_grid(m, n, meta):
    """
    The CUDA grid size for matmul triton templates.
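
    Example (hypothetical sizes): with m = n = 1024 and
    meta = {"BLOCK_M": 64, "BLOCK_N": 64}, this returns (256, 1, 1),
    i.e. cdiv(1024, 64) * cdiv(1024, 64) programs on the first axis.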
    """
    return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def acc_type(dtype):
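    """
    Return the Triton accumulator dtype string for an output dtype: fp16 and
    bf16 accumulate in fp32; everything else accumulates in its own dtype,
    e.g. acc_type(torch.float16) == "tl.float32" and
    acc_type(torch.int32) == "tl.int32".
    """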
    if dtype in (torch.float16, torch.bfloat16):
        return "tl.float32"
    return f"tl.{dtype}".replace("torch.", "")


def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
    """
    Common options to matmul triton templates.
    """
    even_k_symbolic = (
        # it isn't worth guarding on this
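        # gcd(sym_k, BLOCK_K) == BLOCK_K exactly when BLOCK_K divides sym_k,
        # i.e. the K loop never needs a masked tail iteration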
        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
        == config.kwargs["BLOCK_K"]
    )
    allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
        not inductor_config.force_same_precision
        or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0)
    )
    return dict(
        GROUP_M=8,
        EVEN_K=even_k_symbolic,
        ALLOW_TF32=allow_tf32,
        ACC_TYPE=acc_type(layout.dtype),
        B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
        num_stages=config.num_stages,
        num_warps=config.num_warps,
        **config.kwargs,
    )


def mm_args(
    mat1,
    mat2,
    *others,
    layout=None,
    out_dtype=None,
    use_4x2_dim=False,
    mat2_transposed=False,
):
    """
    Common arg processing for mm, bmm, addmm, etc.
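
    For a plain mm, mat1 has shape [m, k] and mat2 [k, n] (or [n, k] when
    mat2_transposed=True). Returns [m, n, k, layout, mat1, mat2, *others];
    batch dims, if present, are guarded equal across the two inputs.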
    """
    mat1, mat2 = realize_inputs(mat1, mat2)
    *b1, m, k1 = mat1.get_size()
    if mat2_transposed:
        *b2, n, k2 = mat2.get_size()
    else:
        *b2, k2, n = mat2.get_size()
    b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)]
    if use_4x2_dim:
        k2 = k2 * 2
    k = V.graph.sizevars.guard_equals(k1, k2)
    if layout is None:
        from torch._inductor.ir import FixedLayout

        if out_dtype is None:
            out_dtype = mat1.get_dtype()

        layout = FixedLayout(
            mat1.get_device(),
            out_dtype,
            [*b, m, n],
        )
    else:
        assert out_dtype is None, "out_dtype is ignored if layout is specified."
    from ..lowering import expand

    others = [realize_inputs(expand(x, layout.size)) for x in others]

    return [m, n, k, layout, mat1, mat2, *others]


def addmm_epilogue(dtype, alpha, beta):
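    """
    Build an epilogue computing alpha * acc + beta * bias (addmm semantics);
    each scalar multiply is skipped when its factor is 1.
    """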
    def epilogue(acc, bias):
        if alpha != 1:
            acc = V.ops.mul(acc, V.ops.constant(alpha, dtype))
        if beta != 1:
            bias = V.ops.mul(bias, V.ops.constant(beta, dtype))
        return V.ops.add(acc, bias)

    return epilogue
