# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
from typing import Optional

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicDecision,
)


class PadMMH200(LearnedHeuristicDecision):

    def __init__(self) -> None:
        self.choices: list[Choice] = []
        self.fill_choices()

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 232448
            and str(metadata.device_capa) == "(9, 0)"
        )

    def get_confidence_threshold(self) -> float:
        return 0.7710651828298887

    def get_choice(self, idx: int) -> Optional[str]:
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
        self.choices.append('orig')
        self.choices.append('pad')

    def get_name(self) -> str:
        return 'pad_mm'

    def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
        if str(context.get_value('mat1_innermost_needs_padding')) != 'True':
            if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
                if context.get_value('n_padded_length') <= 0.5:
                    if str(context.get_value('prepadded_mat1')) != 'True':
                        if str(context.get_value('using_tf32')) != 'False':
                            return [(1.000, 0)]
                        else:
                            if context.get_value('mat1_stride_0') <= 3584.0:
                                return [(1.000, 0)]
                            else:
                                if context.get_value('mat2_stride_0') <= 3584.0:
                                    return [(1.000, 0)]
                                else:
                                    return [(0.528, 0), (0.472, 1)]
                    else:
                        if context.get_value('n') <= 2304.0:
                            if context.get_value('m*k') <= 25198592.0:
                                if context.get_value('arith_intensity') <= 1103.9319458007812:
                                    return [(1.000, 0)]
                                else:
                                    return [(0.885, 0), (0.115, 1)]
                            else:
                                if context.get_value('m*k') <= 25688064.0:
                                    return [(1.000, 1)]
                                else:
                                    return [(0.771, 0), (0.229, 1)]
                        else:
                            if str(context.get_value('using_tf32')) != 'False':
                                if context.get_value('m') <= 27825.0:
                                    return [(0.948, 0), (0.052, 1)]
                                else:
                                    return [(0.855, 0), (0.145, 1)]
                            else:
                                if context.get_value('mat2_stride_0') <= 3584.0:
                                    return [(1.000, 0)]
                                else:
                                    return [(0.917, 1), (0.083, 0)]
                else:
                    if context.get_value('m') <= 1823.5:
                        if str(context.get_value('n_multiple_2')) != 'False':
                            if context.get_value('k*n') <= 7859200.0:
                                return [(0.600, 0), (0.400, 1)]
                            else:
                                return [(1.000, 0)]
                        else:
                            if context.get_value('k/(m*n)') <= 0.00040277576772496104:
                                return [(1.000, 1)]
                            else:
                                return [(0.800, 1), (0.200, 0)]
                    else:
                        if context.get_value('n') <= 3602.0:
                            return [(0.800, 1), (0.200, 0)]
                        else:
                            return [(1.000, 1)]
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if str(context.get_value('n_multiple_16')) != 'False':
                        if str(context.get_value('k_multiple_2')) != 'True':
                            if context.get_value('arith_intensity') <= 744.8332214355469:
                                return [(0.600, 0), (0.400, 1)]
                            else:
                                return [(1.000, 1)]
                        else:
                            if context.get_value('m*n') <= 8912896.0:
                                if context.get_value('m*k') <= 5934080.0:
                                    return [(0.800, 0), (0.200, 1)]
                                else:
                                    return [(1.000, 0)]
                            else:
                                return [(1.000, 1)]
                    else:
                        return [(1.000, 1)]
                else:
                    return [(1.000, 0)]
        else:
            if context.get_value('arith_intensity') <= 895.8767395019531:
                if str(context.get_value('m_multiple_2')) != 'False':
                    if context.get_value('mat1_stride_1') <= 3421.0:
                        if str(context.get_value('using_tf32')) != 'False':
                            if context.get_value('mat2_stride_1') <= 10706.5:
                                if context.get_value('mat2_stride_0') <= 1024.5:
                                    return [(0.816, 1), (0.184, 0)]
                                else:
                                    return [(1.000, 1)]
                            else:
                                if str(context.get_value('k_multiple_2')) != 'True':
                                    return [(0.905, 1), (0.095, 0)]
                                else:
                                    return [(1.000, 0)]
                        else:
                            if str(context.get_value('prepadded_mat2')) != 'True':
                                if str(context.get_value('mat2_innermost_needs_padding')) != 'False':
                                    return [(1.000, 0)]
                                else:
                                    return [(0.932, 0), (0.068, 1)]
                            else:
                                if context.get_value('arith_intensity') <= 742.1241760253906:
                                    return [(0.889, 0), (0.111, 1)]
                                else:
                                    return [(0.765, 1), (0.235, 0)]
                    else:
                        if context.get_value('n') <= 1216.0:
                            if str(context.get_value('using_tf32')) != 'True':
                                if context.get_value('mat1_stride_1') <= 5567.0:
                                    return [(0.896, 0), (0.104, 1)]
                                else:
                                    return [(0.999, 0), (0.001, 1)]
                            else:
                                return [(1.000, 1)]
                        else:
                            if str(context.get_value('using_tf32')) != 'False':
                                return [(1.000, 1)]
                            else:
                                return [(1.000, 0)]
                else:
                    if str(context.get_value('using_tf32')) != 'False':
                        return [(1.000, 1)]
                    else:
                        if context.get_value('mat2_stride_1') <= 2688.0:
                            return [(1.000, 0)]
                        else:
                            return [(0.500, 0), (0.500, 1)]
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    return [(1.000, 1)]
                else:
                    if str(context.get_value('mat2_innermost_needs_padding')) != 'True':
                        return [(1.000, 0)]
                    else:
                        return [(0.800, 0), (0.200, 1)]
