#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#include <ATen/native/cuda/Pow.cuh>
#include <type_traits>
#include <utility>

namespace at::native {

enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };

// Validates the dtype configuration for mixed-precision fused Adam/AdamW.
//
// Currently the only supported configuration is:
//   params/grads: float32
//   optimizer states (exp_avg, exp_avg_sq, ...): bfloat16
//
// This specific configuration (fp32 params + bf16 optimizer states) has been
// validated end-to-end in large-scale training runs (e.g. DeepSeek-V3 671B)
// and is the only one for which training convergence has been demonstrated.
// Additional mixed-precision configurations (e.g. float16 states) can be
// enabled here once convergence is verified for those as well.
//
// Only [0] is checked because within-list dtype homogeneity is guaranteed by
// _check_tensors_share_device_and_dtype (with skip_cross_list_dtype_check)
// and the Python-side grouping in
// _group_tensors_by_first_tensors_device_and_dtype.
inline void validate_mixed_precision_dtypes(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    const char* op_name) {
  TORCH_CHECK(
      params[0].scalar_type() == at::kFloat,
      op_name,
      " requires float32 params, got ",
      params[0].scalar_type());
  TORCH_CHECK(
      grads[0].scalar_type() == at::kFloat,
      op_name,
      " requires float32 grads, got ",
      grads[0].scalar_type());
  TORCH_CHECK(
      exp_avgs[0].scalar_type() == at::kBFloat16,
      op_name,
      " requires bfloat16 optimizer states, got ",
      exp_avgs[0].scalar_type());
  TORCH_CHECK(
      exp_avg_sqs[0].scalar_type() == at::kBFloat16,
      op_name,
      " requires bfloat16 optimizer states, got ",
      exp_avg_sqs[0].scalar_type());
}

inline void validate_mixed_precision_dtypes(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    const char* op_name) {
  validate_mixed_precision_dtypes(
      params, grads, exp_avgs, exp_avg_sqs, op_name);
  TORCH_CHECK(
      max_exp_avg_sqs[0].scalar_type() == at::kBFloat16,
      op_name,
      " requires bfloat16 max_exp_avg_sqs, got ",
      max_exp_avg_sqs[0].scalar_type());
}

namespace {

constexpr uint8_t kParamIdx = 0;
constexpr uint8_t kGradIdx = 1;
constexpr uint8_t kExpAvgIdx = 2;
constexpr uint8_t kExpAvgSqIdx = 3;
constexpr uint8_t kMaxExpAvgSqIdx = 4;

template <
    typename scalar_type,
    typename opmath_t,
    int depth,
    ADAM_MODE adam_mode,
    bool amsgrad>
C10_DEVICE inline void adam_math(
    scalar_type r_args[depth][kILP],
    const opmath_t& lr,
    const opmath_t& beta1,
    const opmath_t& beta2,
    const opmath_t& weight_decay,
    const opmath_t& eps,
    const bool& maximize,
    const float* grad_scale_ptr,
    const float* found_inf_ptr,
    const opmath_t& bias_correction1,
    const opmath_t& bias_correction2_sqrt) {
  static_assert(depth == 4 || depth == 5);
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    // Load values.
    auto param = static_cast<opmath_t>(r_args[kParamIdx][ii]);
    auto grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
    auto exp_avg = static_cast<opmath_t>(r_args[kExpAvgIdx][ii]);
    auto exp_avg_sq = static_cast<opmath_t>(r_args[kExpAvgSqIdx][ii]);
    opmath_t max_exp_avg_sq;
    if constexpr (amsgrad) {
      max_exp_avg_sq = static_cast<opmath_t>(r_args[kMaxExpAvgSqIdx][ii]);
    }

    // Scale gradient for AMP.
    if (grad_scale_ptr) {
      grad /= (static_cast<opmath_t>(*grad_scale_ptr));
    }
    const opmath_t grad_to_store = grad;

    if (maximize) {
      grad = -grad;
    }

    // Update param, grad, 1st and 2nd order momentum.
    if (weight_decay != 0) {
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad += param * weight_decay;
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param -= lr * weight_decay * param;
      }
    }
    exp_avg = std::fma(beta1, exp_avg, std::fma(-beta1, grad, grad));
    exp_avg_sq =
        std::fma(beta2, exp_avg_sq, std::fma(-beta2, grad * grad, grad * grad));
    const opmath_t step_size = lr / bias_correction1;
    opmath_t denom;
    if constexpr (amsgrad) {
      max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq);
      denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
    } else {
      denom = (std::sqrt(exp_avg_sq) / bias_correction2_sqrt) + eps;
    }
    param -= step_size * exp_avg / denom;

    // Store results.
    r_args[kParamIdx][ii] = param;
    if (grad_scale_ptr) {
      r_args[kGradIdx][ii] = grad_to_store;
    }
    r_args[kExpAvgIdx][ii] = exp_avg;
    r_args[kExpAvgSqIdx][ii] = exp_avg_sq;
    if constexpr (amsgrad) {
      r_args[kMaxExpAvgSqIdx][ii] = max_exp_avg_sq;
    }
  }
}

// [note: Conditional Gradient Store when `optimizer.step` is called by
// GradScaler] When a user is training their model(s) with an FP16 AMP recipe,
// parameter updates are done via `grad_scaler.step(optimizer)` instead of
// `optimizer.step()`. For most optimizers, GradScaler unscales gradients on
// behalf of those optimizers. Also, before `.step`, it makes sure that all the
// gradients involved are finite, which incurs a device sync. On the other hand,
// fused optimizers set their member variable of `_step_supports_amp_scaling` to
// `True` in order to remove the device sync above. This means that fused
// optimizers have to have their CUDA kernels (a) unscale gradients and (b) skip
// parameter updates accordingly. To be functionally on par with `torch.optim`
// optimizers and `_multi_tensor` ones, the kernel below writes out gradients
// only when `grad_scale_ptr != nullptr.
template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
struct FusedAdamMathFunctor {
  static_assert(
      depth == 4 || depth == 5,
      "depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
  using opmath_t = at::opmath_type<scalar_type>;

  C10_DEVICE __forceinline__ void operator()(
      int64_t chunk_size,
      FusedOptimizerTensorListMetadata<depth>& tl,
      const float* lr_ptr,
      const double& lr,
      const double& beta1,
      const double& beta2,
      const double& weight_decay,
      const double& eps,
      const bool& maximize,
      const float* grad_scale_ptr,
      const float* found_inf_ptr) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];

    const auto lr_opmath =
        lr_ptr ? static_cast<opmath_t>(*lr_ptr) : static_cast<opmath_t>(lr);
    const auto beta1_opmath = static_cast<opmath_t>(beta1);
    const auto beta2_opmath = static_cast<opmath_t>(beta2);
    const auto weight_decay_opmath = static_cast<opmath_t>(weight_decay);
    const auto eps_opmath = static_cast<opmath_t>(eps);

    if (found_inf_ptr && *found_inf_ptr == 1) {
      return;
    }
    const auto [bias_correction1, bias_correction2_sqrt] =
        [&]() -> std::pair<opmath_t, opmath_t> {
      auto step_count = static_cast<opmath_t>(*reinterpret_cast<const float*>(
          tl.state_steps_addresses[tensor_loc]));

      const opmath_t bias_correction1 =
          1 - at::native::pow_(beta1_opmath, step_count);

      const opmath_t bias_correction2 =
          1 - at::native::pow_(beta2_opmath, step_count);
      const opmath_t bias_correction2_sqrt = std::sqrt(bias_correction2);

      return {bias_correction1, bias_correction2_sqrt};
    }();

    scalar_type* args[depth];
    scalar_type r_args[depth][kILP];
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

    const bool all_aligned{
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};

    if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int i = 0; i < depth; i++) {
          load_store(r_args[i], args[i], 0, i_start);
        }
        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
            r_args,
            lr_opmath,
            beta1_opmath,
            beta2_opmath,
            weight_decay_opmath,
            eps_opmath,
            maximize,
            grad_scale_ptr,
            found_inf_ptr,
            bias_correction1,
            bias_correction2_sqrt);
#pragma unroll
        for (int i = 0; i < depth; i++) {
          if (i != kGradIdx || grad_scale_ptr) {
            load_store(args[i], r_args[i], i_start, 0);
          }
        }
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth>(r_args, args, i_start, chunk_size, n);
        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
            r_args,
            lr_opmath,
            beta1_opmath,
            beta2_opmath,
            weight_decay_opmath,
            eps_opmath,
            maximize,
            grad_scale_ptr,
            found_inf_ptr,
            bias_correction1,
            bias_correction2_sqrt);
#pragma unroll
        for (int i = 0; i < depth; i++) {
          if (i != kGradIdx || grad_scale_ptr) {
            store_args(args[i], r_args[i], i_start, chunk_size, n);
          }
        }
      }
    }
  }
};

template <
    typename scalar_type,
    typename param_type,
    typename grad_type,
    typename exp_avg_type,
    typename exp_avg_sq_type,
    typename max_exp_avg_sq_type,
    int depth,
    ADAM_MODE adam_mode,
    bool amsgrad>
struct FusedAdamMathFunctorMP {
  static_assert(
      depth == 4 || depth == 5,
      "depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
  using opmath_t = at::opmath_type<scalar_type>;
  C10_DEVICE __forceinline__ void operator()(
      int64_t chunk_size,
      FusedOptimizerTensorListMetadata<depth>& tl,
      const float* lr_ptr,
      const double& lr,
      const double& beta1,
      const double& beta2,
      const double& weight_decay,
      const double& eps,
      const bool& maximize,
      const float* grad_scale_ptr,
      const float* found_inf_ptr) {
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    const double lr_double = lr_ptr ? *lr_ptr : lr;

    if (found_inf_ptr && *found_inf_ptr == 1) {
      return;
    }
    const auto [bias_correction1, bias_correction2_sqrt] =
        [&]() -> std::pair<double, double> {
      auto* step_count =
          reinterpret_cast<const float*>(tl.state_steps_addresses[tensor_loc]);
      const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count);
      const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count);
      const auto bias_correction2_sqrt = std::sqrt(bias_correction2);
      return {bias_correction1, bias_correction2_sqrt};
    }();

    param_type* param_args;
    grad_type* grad_args;
    exp_avg_type* exp_avg_args;
    exp_avg_sq_type* exp_avg_sq_args;
    [[maybe_unused]] max_exp_avg_sq_type* max_exp_avg_sq_args;

    // r_args represents the state when everything is casted to scalar_type
    // to be passed into the adam_math function. scalar_type is our operation
    // math type.
    scalar_type r_args[depth][kILP];

    // n = total numel of tensor - what's already been processed
    // so n = numel in current tensor not yet processed
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

    bool all_aligned = init_args_mixed_prec<
        depth,
        param_type,
        grad_type,
        exp_avg_type,
        exp_avg_sq_type>(
        &param_args,
        &grad_args,
        &exp_avg_args,
        &exp_avg_sq_args,
        tl,
        chunk_idx,
        chunk_size,
        tensor_loc);
    if constexpr (amsgrad) {
      max_exp_avg_sq_args =
          (max_exp_avg_sq_type*)tl.addresses[kMaxExpAvgSqIdx][tensor_loc] +
          chunk_idx * chunk_size;
      all_aligned = all_aligned && is_aligned(max_exp_avg_sq_args);
    }
    if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
        if constexpr (!std::is_same_v<scalar_type, param_type>) {
          scalar_type casted_param_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_param_args[ii] =
                static_cast<scalar_type>(param_args[ii + i_start * kILP]);
          }
          load_store(r_args[kParamIdx], casted_param_args, 0, 0);
        } else {
          load_store(r_args[kParamIdx], (scalar_type*)param_args, 0, i_start);
        }
        if constexpr (!std::is_same_v<scalar_type, grad_type>) {
          scalar_type casted_grad_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_grad_args[ii] =
                static_cast<scalar_type>(grad_args[ii + i_start * kILP]);
          }
          load_store(r_args[kGradIdx], casted_grad_args, 0, 0);
        } else {
          load_store(r_args[kGradIdx], (scalar_type*)grad_args, 0, i_start);
        }
        if constexpr (!std::is_same_v<scalar_type, exp_avg_type>) {
          scalar_type casted_exp_avg_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_exp_avg_args[ii] =
                static_cast<scalar_type>(exp_avg_args[ii + i_start * kILP]);
          }
          load_store(r_args[kExpAvgIdx], casted_exp_avg_args, 0, 0);
        } else {
          load_store(
              r_args[kExpAvgIdx], (scalar_type*)exp_avg_args, 0, i_start);
        }
        if constexpr (!std::is_same_v<scalar_type, exp_avg_sq_type>) {
          scalar_type casted_exp_avg_sq_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_exp_avg_sq_args[ii] =
                static_cast<scalar_type>(exp_avg_sq_args[ii + i_start * kILP]);
          }
          load_store(r_args[kExpAvgSqIdx], casted_exp_avg_sq_args, 0, 0);
        } else {
          load_store(
              r_args[kExpAvgSqIdx], (scalar_type*)exp_avg_sq_args, 0, i_start);
        }
        if constexpr (amsgrad) {
          if constexpr (!std::is_same_v<scalar_type, max_exp_avg_sq_type>) {
            scalar_type casted[kILP];
            for (int ii = 0; ii < kILP; ii++) {
              casted[ii] = static_cast<scalar_type>(
                  max_exp_avg_sq_args[ii + i_start * kILP]);
            }
            load_store(r_args[kMaxExpAvgSqIdx], casted, 0, 0);
          } else {
            load_store(
                r_args[kMaxExpAvgSqIdx],
                (scalar_type*)max_exp_avg_sq_args,
                0,
                i_start);
          }
        }
        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
            r_args,
            lr_double,
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr,
            bias_correction1,
            bias_correction2_sqrt);
        if constexpr (!std::is_same_v<scalar_type, param_type>) {
          param_type casted_r_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_r_args[ii] = static_cast<param_type>(r_args[kParamIdx][ii]);
          }
          load_store(param_args, casted_r_args, i_start, 0);
        } else {
          load_store(param_args, (param_type*)r_args[kParamIdx], i_start, 0);
        }
        if constexpr (!std::is_same_v<scalar_type, exp_avg_type>) {
          exp_avg_type casted_r_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_r_args[ii] =
                static_cast<exp_avg_type>(r_args[kExpAvgIdx][ii]);
          }
          load_store(exp_avg_args, casted_r_args, i_start, 0);
        } else {
          load_store(
              exp_avg_args, (exp_avg_type*)r_args[kExpAvgIdx], i_start, 0);
        }
        if constexpr (!std::is_same_v<scalar_type, exp_avg_sq_type>) {
          exp_avg_sq_type casted_r_args[kILP];
          for (int ii = 0; ii < kILP; ii++) {
            casted_r_args[ii] =
                static_cast<exp_avg_sq_type>(r_args[kExpAvgSqIdx][ii]);
          }
          load_store(exp_avg_sq_args, casted_r_args, i_start, 0);
        } else {
          load_store(
              exp_avg_sq_args,
              (exp_avg_sq_type*)r_args[kExpAvgSqIdx],
              i_start,
              0);
        }
        if constexpr (amsgrad) {
          if constexpr (!std::is_same_v<scalar_type, max_exp_avg_sq_type>) {
            max_exp_avg_sq_type casted[kILP];
            for (int ii = 0; ii < kILP; ii++) {
              casted[ii] =
                  static_cast<max_exp_avg_sq_type>(r_args[kMaxExpAvgSqIdx][ii]);
            }
            load_store(max_exp_avg_sq_args, casted, i_start, 0);
          } else {
            load_store(
                max_exp_avg_sq_args,
                (max_exp_avg_sq_type*)r_args[kMaxExpAvgSqIdx],
                i_start,
                0);
          }
        }
        if (grad_scale_ptr) {
          if constexpr (!std::is_same_v<scalar_type, grad_type>) {
            grad_type casted_r_args[kILP];
            for (int ii = 0; ii < kILP; ii++) {
              casted_r_args[ii] = static_cast<grad_type>(r_args[kGradIdx][ii]);
            }
            load_store(grad_args, casted_r_args, i_start, 0);
          } else {
            load_store(grad_args, (grad_type*)r_args[kGradIdx], i_start, 0);
          }
        }
      }
    } else {
      for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<
            scalar_type,
            param_type,
            grad_type,
            exp_avg_type,
            exp_avg_sq_type>(
            r_args,
            param_args,
            grad_args,
            exp_avg_args,
            exp_avg_sq_args,
            i_start,
            chunk_size,
            n);
        if constexpr (amsgrad) {
#pragma unroll
          for (int ii = 0; ii < kILP; ii++) {
            const auto i = i_start + threadIdx.x + ii * blockDim.x;
            r_args[kMaxExpAvgSqIdx][ii] = 0;
            if (i < n && i < chunk_size) {
              r_args[kMaxExpAvgSqIdx][ii] =
                  static_cast<scalar_type>(max_exp_avg_sq_args[i]);
            }
          }
        }
        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
            r_args,
            lr_double,
            beta1,
            beta2,
            weight_decay,
            eps,
            maximize,
            grad_scale_ptr,
            found_inf_ptr,
            bias_correction1,
            bias_correction2_sqrt);
        store_args(param_args, r_args[kParamIdx], i_start, chunk_size, n);
        store_args(exp_avg_args, r_args[kExpAvgIdx], i_start, chunk_size, n);
        store_args(
            exp_avg_sq_args, r_args[kExpAvgSqIdx], i_start, chunk_size, n);
        if constexpr (amsgrad) {
          store_args(
              max_exp_avg_sq_args,
              r_args[kMaxExpAvgSqIdx],
              i_start,
              chunk_size,
              n);
        }
        if (grad_scale_ptr) {
          store_args(grad_args, r_args[kGradIdx], i_start, chunk_size, n);
        }
      }
    }
  }
};

} // namespace

} // namespace at::native

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
