# flake8: noqa: B950
# fmt: off
# This file was generated by AutoHeuristic. Do not modify it manually!
# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/
from typing import Optional

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
    AHMetadata,
    Choice,
)
from torch._inductor.autoheuristic.learnedheuristic_interface import (
    LearnedHeuristicDecision,
)


class PadMMA100(LearnedHeuristicDecision):
    """Learned decision-tree heuristic for the ``pad_mm`` choice.

    Chooses between ``'orig'`` (choice index 0) and ``'pad'`` (choice index 1).
    Going by the class name, it targets A100-class GPUs. ``check_precondition``
    restricts it to devices that report capability ``(8, 0)`` and a
    shared-memory value of 166912.

    This class is generated by AutoHeuristic (see the file header). Do not
    hand-edit the split thresholds or probabilities; regenerate the file
    instead.
    """

    def __init__(self) -> None:
        # Choice names indexed by the integers returned from get_best_choices;
        # populated by fill_choices() as ['orig', 'pad'].
        self.choices: list[Choice] = []
        self.fill_choices()

    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        """Return True only if this heuristic applies to the given metadata.

        All three must hold: the metadata name equals this heuristic's name
        (``'pad_mm'``), ``metadata.shared_memory`` equals 166912, and the
        string form of ``metadata.device_capa`` is exactly ``"(8, 0)"``.
        ``context`` is not consulted here.
        """
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == 166912
            # Compared as a string; assumes device_capa stringifies like a
            # tuple, e.g. (8, 0). TODO confirm the type the caller supplies.
            and str(metadata.device_capa) == "(8, 0)"
        )

    def get_confidence_threshold(self) -> float:
        """Return the learned confidence threshold for this heuristic.

        How the threshold is applied is decided outside this class. Presumably
        the top probability from get_best_choices must reach it before the
        prediction is trusted; confirm against the AutoHeuristic framework.
        """
        return 0.9294871794871795

    def get_choice(self, idx: int) -> Optional[str]:
        """Map a choice index to its name, or return None if idx is too large.

        Index 0 is ``'orig'`` and index 1 is ``'pad'`` (see fill_choices).
        """
        # NOTE(review): a negative idx is not rejected. Python negative indexing
        # would return an entry from the end of the list instead of None.
        # Callers presumably only pass indices taken from get_best_choices
        # (0 or 1); confirm.
        if idx < len(self.choices):
            return self.choices[idx]
        return None

    def fill_choices(self) -> None:
        """Append the choice names in index order: 0 -> 'orig', 1 -> 'pad'."""
        self.choices.append('orig')
        self.choices.append('pad')

    def get_name(self) -> str:
        """Return the heuristic name that check_precondition matches against."""
        return 'pad_mm'

    def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
        """Evaluate the learned decision tree on the features in ``context``.

        Features are read by name via ``context.get_value``. Examples:
        ``'mat1_innermost_needs_padding'``, ``'arith_intensity'``, ``'m'``,
        ``'n'``, ``'k'``, ``'m*k'``, ``'k/(m*n)'`` and ``'using_tf32'``.

        Returns:
            A list of ``(probability, choice_index)`` tuples. ``choice_index``
            indexes ``self.choices`` (0 = ``'orig'``, 1 = ``'pad'``). In every
            leaf below, the entries are listed from highest to lowest
            probability. Every leaf returns a list, so this implementation never
            returns None; the Optional in the annotation presumably mirrors the
            base-class signature.
        """
        # Boolean-like features are compared through str(...) against 'True' or
        # 'False'. Presumably this lets both real bools and their string forms
        # match; TODO confirm how the context stores these values.
        # The split constants are learned values; regenerate rather than edit.
        if str(context.get_value('mat1_innermost_needs_padding')) != 'False':
            if context.get_value('arith_intensity') <= 880.0238037109375:
                if str(context.get_value('m_multiple_2')) != 'True':
                    if context.get_value('n') <= 652.0:
                        if context.get_value('m') <= 2022.0:
                            if str(context.get_value('using_tf32')) != 'False':
                                return [(0.579, 1), (0.421, 0)]
                            else:
                                return [(1.000, 0)]
                        else:
                            return [(1.000, 0)]
                    else:
                        if context.get_value('m*k') <= 107278336.0:
                            if str(context.get_value('using_tf32')) != 'False':
                                if context.get_value('m') <= 23691.0:
                                    return [(0.993, 1), (0.007, 0)]
                                else:
                                    return [(0.840, 1), (0.160, 0)]
                            else:
                                if context.get_value('arith_intensity') <= 793.3185424804688:
                                    return [(0.958, 0), (0.042, 1)]
                                else:
                                    return [(0.792, 1), (0.208, 0)]
                        else:
                            if context.get_value('arith_intensity') <= 795.6242370605469:
                                if context.get_value('mat2_stride_1') <= 2048.5:
                                    return [(0.929, 0), (0.071, 1)]
                                else:
                                    return [(1.000, 1)]
                            else:
                                if context.get_value('arith_intensity') <= 796.3460388183594:
                                    return [(0.957, 1), (0.043, 0)]
                                else:
                                    return [(0.778, 0), (0.222, 1)]
                else:
                    if context.get_value('mat2_stride_0') <= 2432.0:
                        if str(context.get_value('k_multiple_2')) != 'False':
                            if context.get_value('n') <= 1024.5:
                                if str(context.get_value('prepadded_mat1')) != 'False':
                                    return [(0.580, 0), (0.420, 1)]
                                else:
                                    return [(0.986, 0), (0.014, 1)]
                            else:
                                if context.get_value('mat1_stride_0') <= 5125.0:
                                    return [(0.551, 0), (0.449, 1)]
                                else:
                                    return [(0.916, 0), (0.084, 1)]
                        else:
                            if context.get_value('mat2_align_size') <= 6.0:
                                if str(context.get_value('using_tf32')) != 'True':
                                    return [(1.000, 0)]
                                else:
                                    return [(0.800, 0), (0.200, 1)]
                            else:
                                if context.get_value('mat2_stride_1') <= 3820.0:
                                    return [(0.986, 1), (0.014, 0)]
                                else:
                                    return [(0.532, 0), (0.468, 1)]
                    else:
                        if str(context.get_value('using_tf32')) != 'False':
                            if context.get_value('m*n') <= 5244928.0:
                                if str(context.get_value('k_multiple_2')) != 'True':
                                    return [(0.971, 1), (0.029, 0)]
                                else:
                                    return [(0.646, 1), (0.354, 0)]
                            else:
                                if context.get_value('k/(m*n)') <= 9.468618827668251e-06:
                                    return [(0.800, 1), (0.200, 0)]
                                else:
                                    return [(1.000, 1)]
                        else:
                            if context.get_value('mat1_stride_1') <= 1288.0:
                                if context.get_value('k') <= 5717.0:
                                    return [(0.983, 0), (0.017, 1)]
                                else:
                                    return [(0.800, 0), (0.200, 1)]
                            else:
                                return [(0.588, 0), (0.412, 1)]
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if context.get_value('n') <= 2640.0:
                        if context.get_value('m_padded_length') <= 1.5:
                            if context.get_value('mat2_stride_1') <= 6021.0:
                                return [(1.000, 1)]
                            else:
                                if context.get_value('mat2_stride_1') <= 6069.0:
                                    return [(0.900, 1), (0.100, 0)]
                                else:
                                    return [(1.000, 1)]
                        else:
                            if str(context.get_value('m_multiple_2')) != 'False':
                                if context.get_value('m*k') <= 24444928.0:
                                    return [(0.593, 0), (0.407, 1)]
                                else:
                                    return [(0.923, 0), (0.077, 1)]
                            else:
                                return [(1.000, 1)]
                    else:
                        if context.get_value('m*k') <= 404182016.0:
                            if str(context.get_value('mat2_innermost_needs_padding')) != 'False':
                                if context.get_value('m*k') <= 12328960.0:
                                    return [(0.732, 1), (0.268, 0)]
                                else:
                                    return [(0.989, 1), (0.011, 0)]
                            else:
                                if context.get_value('m*k') <= 389028864.0:
                                    return [(0.998, 1), (0.002, 0)]
                                else:
                                    return [(0.922, 1), (0.078, 0)]
                        else:
                            if context.get_value('m*n') <= 137631744.0:
                                if context.get_value('m*k') <= 405715968.0:
                                    return [(0.611, 1), (0.389, 0)]
                                else:
                                    return [(0.946, 1), (0.054, 0)]
                            else:
                                return [(0.714, 0), (0.286, 1)]
                else:
                    if context.get_value('mat2_stride_0') <= 3902.5:
                        return [(0.941, 0), (0.059, 1)]
                    else:
                        return [(0.583, 1), (0.417, 0)]
        else:
            # Branch taken when mat1_innermost_needs_padding stringifies to 'False'.
            if context.get_value('n_padded_length') <= 0.5:
                if str(context.get_value('mat2_innermost_needs_padding')) != 'False':
                    if str(context.get_value('k_multiple_2')) != 'True':
                        if context.get_value('arith_intensity') <= 884.5185852050781:
                            if context.get_value('arith_intensity') <= 743.931884765625:
                                return [(1.000, 0)]
                            else:
                                return [(0.583, 0), (0.417, 1)]
                        else:
                            return [(1.000, 1)]
                    else:
                        if context.get_value('k/(m*n)') <= 0.00023481580865336582:
                            return [(0.900, 0), (0.100, 1)]
                        else:
                            return [(1.000, 0)]
                else:
                    if str(context.get_value('using_tf32')) != 'False':
                        if context.get_value('m*k') <= 93734912.0:
                            if context.get_value('mat1_stride_0') <= 1344.0:
                                if context.get_value('n') <= 7168.0:
                                    return [(1.000, 0)]
                                else:
                                    return [(0.970, 0), (0.030, 1)]
                            else:
                                if context.get_value('m') <= 22883.5:
                                    return [(0.977, 0), (0.023, 1)]
                                else:
                                    return [(0.800, 0), (0.200, 1)]
                        else:
                            if context.get_value('arith_intensity') <= 1914.3681030273438:
                                if str(context.get_value('prepadded_mat1')) != 'False':
                                    return [(0.981, 0), (0.019, 1)]
                                else:
                                    return [(0.995, 0), (0.005, 1)]
                            else:
                                return [(1.000, 0)]
                    else:
                        if str(context.get_value('prepadded_mat1')) != 'False':
                            if context.get_value('mat2_stride_1') <= 256.5:
                                if context.get_value('m') <= 5880.5:
                                    return [(1.000, 0)]
                                else:
                                    return [(0.800, 0), (0.200, 1)]
                            else:
                                if context.get_value('m*k') <= 6318080.0:
                                    return [(0.618, 1), (0.382, 0)]
                                else:
                                    return [(0.880, 0), (0.120, 1)]
                        else:
                            if context.get_value('k/(m*n)') <= 0.0009747986623551697:
                                if context.get_value('k/(m*n)') <= 0.0006397514371201396:
                                    return [(0.951, 0), (0.049, 1)]
                                else:
                                    return [(0.857, 0), (0.143, 1)]
                            else:
                                return [(1.000, 0)]
            else:
                if str(context.get_value('using_tf32')) != 'False':
                    if str(context.get_value('n_multiple_2')) != 'False':
                        if context.get_value('m') <= 2024.0:
                            if context.get_value('mat2_stride_0') <= 1629.0:
                                if context.get_value('k*n') <= 1288704.0:
                                    return [(0.600, 0), (0.400, 1)]
                                else:
                                    return [(0.982, 0), (0.018, 1)]
                            else:
                                if context.get_value('mat1_stride_0') <= 768.0:
                                    return [(0.619, 0), (0.381, 1)]
                                else:
                                    return [(0.812, 1), (0.188, 0)]
                        else:
                            if context.get_value('m*n') <= 5803008.0:
                                return [(0.500, 0), (0.500, 1)]
                            else:
                                if context.get_value('mat2_stride_1') <= 896.0:
                                    return [(1.000, 1)]
                                else:
                                    return [(0.818, 1), (0.182, 0)]
                    else:
                        if context.get_value('mat2_stride_1') <= 2560.0:
                            if context.get_value('num_dims_needs_padding') <= 1.5:
                                if context.get_value('mat2_stride_1') <= 896.0:
                                    return [(1.000, 1)]
                                else:
                                    return [(0.857, 1), (0.143, 0)]
                            else:
                                return [(0.727, 1), (0.273, 0)]
                        else:
                            return [(0.667, 1), (0.333, 0)]
                else:
                    if context.get_value('k/(m*n)') <= 0.00015462777810171247:
                        return [(0.857, 0), (0.143, 1)]
                    else:
                        if context.get_value('k/(m*n)') <= 0.0019917909521609545:
                            return [(1.000, 0)]
                        else:
                            return [(0.900, 0), (0.100, 1)]
