# mypy: allow-untyped-defs
import logging
import threading
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from functools import lru_cache
from itertools import chain
from typing import cast

import torch
from torch._guards import detect_fake_mode
from torch._logging import LazyString
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed._functional_collectives import _are_we_tracing
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._decompositions import DecompShardingStrategy
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpInfo,
    OpSchema,
    OpSpec,
    OpStrategy,
    OutputSharding,
    OutputSpecType,
    RuntimeSchemaInfo,
    StrategyType,
    TupleStrategy,
)
from torch.distributed.tensor._ops.single_dim_strategy import (
    _expand_single_dim_strategy_to_mesh,
    _SingleDimStrategyInfo,
)
from torch.distributed.tensor._utils import (
    compute_local_shape_and_global_offset,
    compute_local_stride,
    try_find_mesh_from_args,
)
from torch.distributed.tensor.placement_types import _StridedShard, Shard
from torch.utils._pytree import tree_map


# Module logger; LocalLRUCache emits its cache HIT/MISS debug lines here.
log = logging.getLogger(__name__)

# Shorthand for the aten operator namespace used throughout this module.
aten = torch.ops.aten


def _propagate_use_strided_shard_flag(
    op_strategy: OpStrategy,
    op_schema: OpSchema,
) -> None:
    """Propagate use_strided_shard_as_shard_order from input specs to output specs.

    When inputs carry _StridedShard with an explicit flag, all output (and input)
    DTensorSpecs in the strategy that also contain _StridedShard must agree.
    Strategy functions may forget to propagate the flag; this function fixes
    them up centrally after the strategy is produced.
    """
    _use_strided: bool | None = None
    for spec in op_schema.args_spec:
        if any(isinstance(p, _StridedShard) for p in spec.placements):
            val = spec.use_strided_shard_as_shard_order
            if _use_strided is not None and _use_strided != val:
                raise ValueError(
                    "Conflicting use_strided_shard_as_shard_order across "
                    f"input specs: got both {_use_strided} and {val}"
                )
            _use_strided = val

    if _use_strided is None:
        return

    def _fixup(spec: DTensorSpec) -> None:
        if not any(isinstance(p, _StridedShard) for p in spec.placements):
            return
        if spec.use_strided_shard_as_shard_order == _use_strided:
            return
        spec.use_strided_shard_as_shard_order = _use_strided
        if _use_strided:
            spec.shard_order = None  # pyrefly: ignore[bad-assignment]
        else:
            spec.shard_order = DTensorSpec.compute_default_shard_order(spec.placements)

    for op_spec in op_strategy.strategies:
        out = op_spec.output_specs
        if out is not None:
            if isinstance(out, DTensorSpec):
                _fixup(out)
            else:
                for s in out:
                    if s is not None:
                        _fixup(s)
        if op_spec.input_specs is not None:
            for s in op_spec.input_specs:
                _fixup(s)


def _length(obj) -> int:
    if obj is None:
        return 0
    if not isinstance(obj, Sequence):
        return 1
    return len(obj)


def _get_expected_num_tensor_outputs(op: OpOverload) -> int | None:
    """
    Get the expected number of tensor outputs for an operator based on its schema.

    Returns:
        The number of tensor outputs expected. Returns 0 for ops that don't return tensors
        (e.g., _linalg_check_errors). Returns 1 for single tensor return, and >1 for
        tuple returns where each element is a tensor. Returns None for List[Tensor]
        returns where the length is unknown at schema time.
    """
    return_types = op._schema.returns
    if len(return_types) == 0:
        return 0

    first_return = return_types[0]
    if isinstance(first_return.type, torch.TensorType):
        # Could be single tensor or tuple of tensors
        return len(return_types)
    elif isinstance(first_return.type, torch.ListType):
        # List[Tensor] - we don't know the length at schema time
        return None
    else:
        # Not a tensor return type
        return 0


def _validate_tensor_meta_count(
    op_schema: OpSchema,
    tensor_meta: TensorMeta | Sequence[TensorMeta | None] | None,
) -> None:
    """
    Check ``tensor_meta`` against the output count declared by ``op_schema.op``.

    ``None`` counts as zero outputs, a lone ``TensorMeta`` as one, and any other
    sequence by its ``len``. When the schema's first return is a list type
    (length unknown at schema time), ``tensor_meta`` only has to be a ``list``.

    Raises:
        AssertionError: if a list-returning op did not produce a ``list`` of
            metadata, or if the count differs from what the schema declares.
    """
    expected = _get_expected_num_tensor_outputs(op_schema.op)

    # TensorMeta is a NamedTuple (a tuple subclass), so it must be recognized
    # before falling back to len().
    actual = (
        0
        if tensor_meta is None
        else 1
        if isinstance(tensor_meta, TensorMeta)
        else len(tensor_meta)
    )

    if expected is not None:
        if actual != expected:
            raise AssertionError(
                f"Tensor meta count mismatch for {op_schema.op}: "
                f"expected {expected} tensor output(s) based on op schema, "
                f"but _propagate_tensor_meta returned {actual}. "
                f"This usually indicates a bug in fake tensor propagation for this op."
            )
        return

    # List return: the length is unknown, but the metadata must be a list.
    if not isinstance(tensor_meta, list):
        raise AssertionError(
            f"Tensor meta for {op_schema.op} should be a list[TensorMeta] "
            f"(op returns List[Tensor]), but got {type(tensor_meta).__name__}"
        )


class LocalLRUCache(threading.local):
    """Thread-local, unbounded ``functools.lru_cache`` around ``user_function``.

    Because this derives from ``threading.local``, ``__init__`` runs again in
    every thread that touches an instance. Each thread therefore builds and
    uses its own cache. ``cache_info`` and ``cache_clear`` act on the calling
    thread's cache only.
    """

    def __init__(self, user_function: Callable) -> None:
        self.cache = lru_cache(maxsize=None)(user_function)

    def __call__(self, *args, **kwargs) -> object:
        """Call the memoized function, logging HIT/MISS at DEBUG level.

        ``log.handlers`` is checked first because it is a cheap emptiness test.
        The costlier ``isEnabledFor`` is only consulted when handlers exist.
        """
        cached_fn = self.cache
        if not (log.handlers and log.isEnabledFor(logging.DEBUG)):
            return cached_fn(*args, **kwargs)

        hits_before = cached_fn.cache_info().hits
        result = cached_fn(*args, **kwargs)
        was_hit = cached_fn.cache_info().hits > hits_before
        log.debug(
            "sharding_prop python cache %s: %s -> %s",
            "HIT" if was_hit else "MISS",
            args[0] if args else None,
            getattr(result, "output_spec", None),
        )
        return result

    def cache_info(self):
        """Return the ``lru_cache`` statistics of this thread's cache."""
        return self.cache.cache_info()

    def cache_clear(self):
        """Empty this thread's cache."""
        return self.cache.cache_clear()


def _format_unbacked_hinting_log(
    op_schema: OpSchema,
    strategies: list[OpSpec],
    strategy_index: int,
    replacements: dict,
) -> str:
    """Format log message for unbacked hinting strategy selection (only called if debug logging enabled)."""
    args_spec = tuple(str(spec) for spec in op_schema.args_schema)
    strat = strategies[strategy_index]
    if strat.input_specs is None:
        placements_in = None
    else:
        placements_in = tuple(
            spec.format_shard_order_str(spec.placements, spec.shard_order)
            for spec in strat.input_specs
        )
    placements_out = tree_map(
        lambda spec: spec.format_shard_order_str(spec.placements, spec.shard_order),
        strat.output_specs,
        is_leaf=lambda x: isinstance(x, DTensorSpec),
    )
    return (
        f"Selected strategy {placements_in} -> {placements_out} "
        f"for {op_schema.op} with input {args_spec}, using unbacked hints: {replacements}"
    )


def _select_min_redistribute_cost(
    costs: list[torch.types.FloatLikeType],
    strategies: list[OpSpec],
    op_schema: OpSchema | None = None,
) -> int:
    """
    Given a list of costs and corresponding op strategies, selects the minimum cost strategy, returning the index.
    If unbacked symbols are involved, replaces them with known upper-bound values, falling back to hardcoded values.
    """
    from torch.fx.experimental.symbolic_shapes import (
        free_unbacked_symbols,
        is_concrete_float,
    )
    from torch.utils._sympy.interp import sympy_interp
    from torch.utils._sympy.numbers import int_oo
    from torch.utils._sympy.reference import PythonReferenceAnalysis

    int_fallback = 8192
    free_unbacked = list(set(chain(*[free_unbacked_symbols(cost) for cost in costs])))

    # Easy path: no unbacked shapes involved, choose min cost strategy.
    # Doing the hard path for backed could also make sense?
    if all(is_concrete_float(c) for c in costs) or not free_unbacked:
        return costs.index(min(costs))

    # Figure out heuristic hints for unbacked shapes.
    # If available, use shape upper bound. If not, fallback to some integer (inductor size-hinting style).
    shape_env = next(iter(x for x in costs if not is_concrete_float(x))).node.shape_env  # type: ignore[arg-type]
    replacements = {}
    for sym in free_unbacked:
        # TODO(laithsakka): unify with optimization_hint API
        if (hint := shape_env.var_to_hint_override.get(sym)) is not None:
            replacements[sym] = hint
        elif (upper := shape_env.bound_sympy(sym).upper) is not int_oo:
            replacements[sym] = upper
        else:
            replacements[sym] = int_fallback

    # Use replacements for redistribute cost hints
    proxy_costs = [
        float(cost)
        if is_concrete_float(cost)
        else sympy_interp(
            PythonReferenceAnalysis,
            replacements,
            cost.node.expr.xreplace(replacements),  # type: ignore[arg-type]
        )
        for cost in costs
    ]
    min_cost = min(proxy_costs)
    strategy_index = proxy_costs.index(min_cost)

    if op_schema:
        log.debug(
            "%s",
            LazyString(
                _format_unbacked_hinting_log,
                op_schema,
                strategies,
                strategy_index,
                replacements,
            ),
        )
    return strategy_index


def _select_min_cost_strategy(
    strategy: OpStrategy, op_schema: OpSchema | None = None
) -> OpSpec:
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    if len(strategy.strategies) == 1:
        # short cut with only one possible OpSpec
        return strategy.strategies[0]

    op_spec_costs: list[torch.types.FloatLikeType] = []
    no_redistribute_strategy_index: int = -1
    negative_cost_index: int = -1
    zero_cost_index: int = -1
    for strategy_idx, op_spec in enumerate(strategy.strategies):
        if op_spec.redistribute_cost is None:
            raise AssertionError("must set redistribute cost each OpSpec!")
        redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost))
        op_spec_costs.append(redistribute_cost)

        # If there are strategies with negative/zero/no redistribute cost,
        # we record those indices.
        # TODO: Currently this only applies to OpStrategy selection. Requires extra
        # logic to make it work for TupleStrategy, if needed.
        if op_schema is not None:
            if guard_or_false(redistribute_cost < 0):
                if (
                    negative_cost_index == -1
                    or redistribute_cost < op_spec_costs[negative_cost_index]
                ):
                    negative_cost_index = strategy_idx
            elif guard_or_false(redistribute_cost == 0):
                needs_redistribute = False
                for spec_idx, input_spec in enumerate(op_schema.args_spec):
                    desired_spec = (
                        op_spec.output_spec
                        if op_spec.input_specs is None
                        else op_spec.input_specs[spec_idx]
                    )
                    if input_spec.placements != desired_spec.placements:
                        needs_redistribute = True
                        break

                if not needs_redistribute:
                    no_redistribute_strategy_index = strategy_idx
                elif zero_cost_index == -1:
                    zero_cost_index = strategy_idx

    # prioritize negative/zero/no redistribute cost strategies
    if negative_cost_index != -1:
        # If there's negative cost, we select the one with the minimal cost,
        # even if this means we need to redistribute, e.g. via local chunking.
        # E.g. this can happen for ops in self.op_to_shape_and_stride_idx
        # when the inputs / outputs are sharded.
        selected_strategy_index = negative_cost_index
    elif no_redistribute_strategy_index != -1:
        selected_strategy_index = no_redistribute_strategy_index
    elif zero_cost_index != -1:
        selected_strategy_index = zero_cost_index
    else:
        # default to choosing minimal redistribute cost
        selected_strategy_index = _select_min_redistribute_cost(
            op_spec_costs, strategy.strategies, op_schema
        )

    return strategy.strategies[selected_strategy_index]


class ShardingPropagator:
    # Lock to protect FakeTensorMode context during tensor meta propagation.
    # By default this is a no-op (nullcontext) for performance. Multi-threaded
    # tests should set this to threading.Lock() to prevent race conditions
    # when multiple threads enter different FakeTensorMode contexts.
    # It is read through the class (ShardingPropagator._fake_mode_lock) in
    # _propagate_tensor_meta_non_cached, so a replacement lock has to be
    # assigned on the class, not on an instance, to take effect.
    _fake_mode_lock = nullcontext()

    def __init__(self) -> None:
        # Operator -> registered propagation rule / strategy generator tables.
        self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
        self.op_strategy_funcs: dict[
            OpOverload, Callable[[OpSchema], StrategyType]
        ] = {}
        self.op_single_dim_strategy_funcs: dict[
            OpOverload, _SingleDimStrategyInfo
        ] = {}
        # Operator -> RuntimeSchemaInfo (static argnum etc.), used to decide
        # whether a sharding-prop cache entry can be reused or sharding prop
        # must be re-run.
        self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {}
        self.op_to_schema_info_for_single_dim_strategy: dict[
            OpOverload, RuntimeSchemaInfo
        ] = {}
        # Thread-local memoized front end of propagate_op_sharding_non_cached.
        self.propagate_op_sharding = LocalLRUCache(
            self.propagate_op_sharding_non_cached
        )
        self.decomp_strategy = DecompShardingStrategy(self)

        # Operator -> index of its shape argument, or (shape_idx, stride_idx)
        # for new_empty_strided; these args may need rewriting in sharding prop.
        new_factory_ops = (
            aten.new_empty.default,
            aten.new_full.default,
            aten.new_ones.default,
            aten.new_zeros.default,
        )
        view_like_ops = (
            aten.expand.default,
            aten.expand_copy.default,
            aten.reshape.default,
            aten.view.default,
            aten.view_copy.default,
            aten._unsafe_view.default,
            aten.select_backward.default,
            aten.slice_backward.default,
        )
        self.op_to_shape_and_stride_idx: dict[OpOverload, int | tuple[int, int]] = {
            **dict.fromkeys(new_factory_ops, 1),
            aten.new_empty_strided.default: (1, 2),
            **dict.fromkeys(view_like_ops, 1),
        }

        # Each squeeze/squeeze_ overload maps to the ``.dims`` overload of the
        # same packet, so its dim argument can be rewritten to only the
        # globally-singleton dims.
        self.squeeze_op_to_dims_variant: dict[OpOverload, OpOverload] = {
            overload: packet.dims
            for packet in (aten.squeeze, aten.squeeze_)
            for overload in (packet.default, packet.dim, packet.dims)
        }

    def register_sharding_prop_rule(
        self,
        op_overload: OpOverload,
        rule_func: Callable[[OpSchema], OutputSharding],
        schema_info: RuntimeSchemaInfo | None = None,
    ):
        """
        Register a sharding propagation rule for an operator.
        """
        self.op_to_rules[op_overload] = rule_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info

    def register_single_dim_op_strategy(
        self,
        op_overload: OpOverload,
        strategy_info: _SingleDimStrategyInfo,
        schema_info: RuntimeSchemaInfo | None = None,
    ):
        """
        Register a strategy over a single mesh-dim, relying on infra to automatically expand to the full mesh.
        """
        self.op_single_dim_strategy_funcs[op_overload] = strategy_info
        if schema_info is not None:
            self.op_to_schema_info_for_single_dim_strategy[op_overload] = schema_info

    def register_op_strategy(
        self,
        op_overload: OpOverload,
        strategy_func: Callable[[OpSchema], StrategyType],
        schema_info: RuntimeSchemaInfo | None = None,
    ):
        """
        Register a :class:`OpStrategy` generator for an operator.

        During the sharding propagation, DTensor wants to enumerate all
        acceptable sharding specs (:class:`OpSpec`) for an operator,
        and by "acceptable" we mean that the operator can be executed on
        the ``_local_tensor`` of DTensor args/kwargs (with ``OpSpec.input_specs``)
        and the output(s) constitute valid DTensor(s) (with ``OpSpec.output_specs``).

        ``strategy_func`` is the function that enumerates such acceptable specs
        for the operator ``op_overload``. One general approach to write ``strategy_func``
        is, if the operator has simple arguments structure (e.g. mm, bmm), first enumerating
        all sharding specs for the operands, and then filtering out the ones that
        are not valid. For example, for ``mm``, the operands are two 2D tensors, and
        if both ``input`` and ``mat2`` have sharding placements ``[Shard(0)]``, then this
        is not an acceptable ``input_specs``.

        Once we have a way to enumerate all acceptable sharding specs, we can use each
        of them to construct a :class:`OpSpec`. The ``OpSpec.input_specs`` directly comes
        from the sharding spec, and the ``OpSpec.output_specs`` is therefore determined
        (e.g. ``[Shard(1)]`` @ ``[Shard(0)]`` yields ``[Partial()]``). In addition,
        :class:`OpSpec` also contains ``redistribute_cost`` which records the redistribution
        cost from each :class:`OpSpec` in the source :class:`OpStrategy.strategies` to
        the target sharding spec, for each operand.

        The ``strategy_func`` should return a :class:`OpStrategy` which contains a list of
        all the :class:`OpSpec`s generated in the above.

        The optional ``schema_info`` tells which non-DTensor args/kwargs could affect the
        cache and whether ``pytree`` is needed to flatten the nested args. ``static_argnum``
        marks the starting index of the non-DTensor args that should be hashed into the
        sharding propagation hash key, and ``static_kwargkey`` marks the keys of the
        non-DTensor kwargs that should be hashed. ``needs_pytree`` should be used when
        the input arg has :class:`list` or :class:`dict` structure.

        For example, ``aten.cat.default`` op has a ``List[Tensor]`` argument ``tensors``
        and an ``int`` argument ``dim``. Because ``dim`` affects the sharding propagation
        result, we want to pass ``RuntimeSchemaInfo(static_argnum=1)`` because the argument
        index of ``dim`` is 1. Besides, we also want to set ``needs_pytree=True`` because
        ``tensors`` needs be flattened in sharding propagation. Another example is
        ``aten.histc.default``. ``histc`` has 4 arguments (self, bins, min, max) and the
        last two would affect sharding propagation along with the :class:`DTensor` argument
        ``self``. Since the argument index of ``min`` is 2, the `schema_info` should be
        `RuntimeSchemaInfo(static_argnum=2)`.
        """
        self.op_strategy_funcs[op_overload] = strategy_func
        if schema_info is not None:
            self.op_to_schema_info[op_overload] = schema_info

    def _propagate_tensor_meta_non_cached(
        self, op_schema: OpSchema
    ) -> TensorMeta | Sequence[TensorMeta | None] | None:
        """
        Run ``op_schema.op`` on fake tensors and describe its output(s).

        Returns:
            ``None`` for ``aten.equal.default``, which is data dependent and
            cannot be fake-propagated. Also ``None`` when the op's result is
            neither a tensor nor a tuple/list.

            A ``TensorMeta`` (shape, stride, dtype) when the result is a tensor.

            For a tuple/list result, a container of the same kind (tuple or
            list). It holds a ``TensorMeta`` for each tensor element and
            ``None`` for each non-tensor element.
        """
        if op_schema.op == aten.equal.default:
            # data dependent ops can't be used for fake propagation
            return None

        # Fake tensor mode avoids materializing memory. A fake mode found by
        # detect_fake_mode() is preferred; otherwise a fresh FakeTensorMode is
        # created. The class-level _fake_mode_lock (a nullcontext unless a
        # multi-threaded test installs threading.Lock()) serializes this section.
        with ShardingPropagator._fake_mode_lock:
            with detect_fake_mode() or FakeTensorMode():
                fake_out = op_schema.op(
                    *op_schema.gen_fake_args(), **op_schema.gen_fake_kwargs()
                )

        def _to_meta(t: torch.Tensor) -> TensorMeta:
            return TensorMeta(shape=t.shape, stride=t.stride(), dtype=t.dtype)

        if isinstance(fake_out, torch.Tensor):
            return _to_meta(fake_out)
        if not isinstance(fake_out, (tuple, list)):
            # not a tensor or a tuple/list of tensors: no tensor metadata
            return None

        metas: list[TensorMeta | None] = [
            _to_meta(item) if isinstance(item, torch.Tensor) else None
            for item in fake_out
        ]
        return tuple(metas) if isinstance(fake_out, tuple) else metas

    @lru_cache  # noqa: B019
    def _propagate_tensor_meta_cached(
        self, op_schema: OpSchema
    ) -> TensorMeta | Sequence[TensorMeta | None] | None:
        """
        Memoized wrapper around ``_propagate_tensor_meta_non_cached``.

        The ``lru_cache`` is keyed on ``(self, op_schema)``, so ``op_schema``
        must be hashable, and cached entries keep ``self`` alive (hence the B019
        suppression). Callers should go through ``_propagate_tensor_meta``,
        which bypasses this cache while tracing.
        """
        tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
        return tensor_meta

    def _propagate_tensor_meta(
        self, op_schema: OpSchema
    ) -> TensorMeta | Sequence[TensorMeta | None] | None:
        """
        Compute the output tensor metadata (a TensorMeta or a tuple/list of them).

        While tracing (``_are_we_tracing()``) the uncached implementation is
        used, since the schema may hold unhashable SymInts. Otherwise the
        ``lru_cache``-backed variant is used. Prefer this entry point over
        calling either implementation directly.
        """
        propagate_fn = (
            self._propagate_tensor_meta_non_cached
            if _are_we_tracing()
            else self._propagate_tensor_meta_cached
        )
        return propagate_fn(op_schema)

    def _create_output_spec_with_new_tensor_meta(
        self,
        op: OpOverload,
        output_specs: OutputSpecType,
        output_tensor_meta: TensorMeta | Sequence[TensorMeta | None] | None,
    ) -> OutputSpecType:
        """
        Wrap the output_specs with the tensor metadata from the output.
        """

        if isinstance(output_specs, DTensorSpec):
            if not isinstance(output_tensor_meta, TensorMeta):
                # Either error due to ShardingPropagator or due to incorrect OutputSpec
                if not isinstance(output_tensor_meta, (tuple, list)):
                    raise ValueError(
                        "ShardingPropagator error: output does not have an associated "
                        "TensorMeta"
                    )
                raise ValueError(
                    f"For the op {op.name()}, `output_specs` has 1 output which does "
                    "not equal the "
                    f"number of op outputs: {len(output_tensor_meta)}."
                )
            return output_specs.shallow_copy_with_tensor_meta(output_tensor_meta)
        elif isinstance(output_specs, (tuple, list)):
            new_specs: list[DTensorSpec | None] = []
            if not isinstance(output_tensor_meta, (tuple, list)) or len(
                output_specs
            ) != len(output_tensor_meta):
                raise ValueError(
                    f"For the op {op.name()}, `output_specs` has {len(output_specs)} "
                    "outputs which does not equal the "
                    f"number of op outputs {_length(output_tensor_meta)}."
                )

            # pyrefly: ignore [bad-argument-type]
            for i, spec in enumerate(output_specs):
                if isinstance(spec, DTensorSpec):
                    output_tensor_meta_i = output_tensor_meta[i]
                    if not isinstance(output_tensor_meta_i, TensorMeta):
                        # Some ops (e.g. convolution_backward, native_layer_norm_backward,
                        # _fused_rms_norm_backward) have an output_mask parameter that
                        # controls which outputs are computed. When output_mask[i] is
                        # False, the output at position i is None and has no TensorMeta.
                        if output_tensor_meta_i is None:
                            new_specs.append(None)
                            continue
                        else:
                            raise ValueError(
                                f"ShardingPropagator error: output {i} of {op.name()} "
                                "does not have an associated TensorMeta"
                            )

                    new_specs.append(
                        spec.shallow_copy_with_tensor_meta(output_tensor_meta_i)
                    )
                else:
                    new_specs.append(spec)

            return tuple(new_specs)
        else:
            if output_specs is not None:
                raise AssertionError
            return output_specs

    def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema:
        """
        wrap a op_schema that contains DTensorSpec to another op_schema that contains
        OpStrategy/TupleStrategy, the returned op_schema is then used for sharding
        strategy propagation on pytorch operators.
        """

        def spec_to_strategy(spec: object) -> object:
            if isinstance(spec, DTensorSpec):
                return OpStrategy([OpSpec(spec)])
            elif isinstance(spec, (list, tuple)) and len(spec) > 0:
                if all(isinstance(s, DTensorSpec) for s in spec):
                    # tensor list create tuple strategy
                    tuple_strategy = [spec_to_strategy(s) for s in spec]
                    tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
                    return TupleStrategy(
                        tuple(tuple_strategy)
                        if isinstance(spec, tuple)
                        else tuple_strategy
                    )
                elif any(isinstance(s, DTensorSpec) for s in spec):
                    # mixed list (e.g. [DTensorSpec, None, DTensorSpec]) for
                    # ops like aten.index.Tensor; keep as list so pytree
                    # flattening can extract OpStrategy items
                    return [spec_to_strategy(s) for s in spec]
                else:
                    return spec
            else:
                return spec

        args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]

        kwargs_op_strategy = {
            k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
        }

        return OpSchema(
            op=op_schema.op,
            args_schema=tuple(args_op_strategy),
            kwargs_schema=kwargs_op_strategy,
            schema_info=op_schema.schema_info,
        )

    def propagate(self, op_info: OpInfo) -> None:
        # NB: The logic here is duplicated in _propagate_op_sharding_dispatch_slow_path.
        # Ideally, this function would be deleted, but there are a handful of
        # one off call sites here that aren't cleaned up.

        # NOTE: schema should always be populated when calling this function,
        # as it's only called after unwrap_to_op_info (create_schema=True).
        if op_info.schema is None:
            raise AssertionError(
                "op_info.schema should not be None in propagate. "
                "This function should only be called after unwrap_to_op_info."
            )

        # We cannot use an lru cache if we know that inputs will have dynamic shapes,
        # because SymInts are not hashable.
        # This is generally ok because this only happens during tracing in torch.compile,
        # and tracing does not need to be as fast as eagermode DTensor usages.
        if _are_we_tracing():
            output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
        else:
            output_sharding = cast(
                OutputSharding, self.propagate_op_sharding(op_info.schema)
            )
        op_info.output_sharding = output_sharding

    def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.

        This is the uncached counterpart of ``propagate_op_sharding``. Strategies
        are resolved in this order:

          1. ``aten._local_scalar_dense.default`` short-circuits with no output spec.
          2. A registered single-dim strategy (expanded across every mesh dim), or
             else a registered op strategy function.
          3. Otherwise, if the op has a decomposition, the decomposition strategy.
             If that raises, the exception is kept and chained as the cause of the
             final NotImplementedError.
          4. Otherwise, a registered sharding rule from ``self.op_to_rules``.

        Returns:
            An OutputSharding whose output spec has been associated with the
            propagated output tensor metadata. When inputs must be redistributed,
            it also carries a suggested redistribute schema.

        Raises:
            NotImplementedError: if no strategy, decomposition, or rule applies
                (or if a sharding rule raises it).
            RuntimeError: if a sharding rule fails, or yields neither an output
                spec nor a redistribute suggestion.
            ValueError: if a strategy function returns an unsupported strategy type.
        """
        # no-op in OSS, logs API usage metrics in meta-internal runs
        torch._C._log_api_usage_once(
            "torch.distributed.tensor._sharding_prop.ShardingPropagator.propogate_op_sharding_non_cached"
        )
        # special case op, we don't need to propagate for local
        # scalar. TODO: figure out a better way to handle this
        if op_schema.op is aten._local_scalar_dense.default:
            return OutputSharding(None, op_schema)

        out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)

        single_dim_strategy_info = self.op_single_dim_strategy_funcs.get(op_schema.op)
        op_strategy_func = self.op_strategy_funcs.get(op_schema.op)
        # set if the decomposition path raises; chained into the final
        # NotImplementedError so the underlying failure is not lost
        decomp_exception = None
        if single_dim_strategy_info is not None or op_strategy_func is not None:
            # Validate that tensor_meta count matches expected outputs from op schema.
            # This catches bugs in fake tensor propagation early.
            if single_dim_strategy_info is not None:
                _validate_tensor_meta_count(op_schema, out_tensor_meta)
            # Inputs to the single-dim path:
            #   - the single_dim_strategy: a minimal set of valid input-output
            #     placement specifications for the operator over a single mesh
            #     dimension;
            #   - the OpSchema: the runtime input tensor placements and the mesh.
            # The single_dim_strategies are combined across mesh dims, with
            # placeholders (ShardPlaceholder) expanded to the real sharding types
            # in op_schema. The lowest-cost redistribution of inputs that matches
            # a valid strategy combination is then chosen.

            # wrap the op_schema with op strategy for sharding strategy propagation
            strategy_schema = self._wrap_with_op_strategy(op_schema)

            if single_dim_strategy_info is not None:
                mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
                if not isinstance(mesh, DeviceMesh):
                    raise AssertionError("Expected to find a valid mesh")
                # expand to generate the full set of strategy combinations, each one
                # with a redistribute cost, and then find the min strategy over those costs.
                _expanded_strategy_fn = _expand_single_dim_strategy_to_mesh(
                    mesh, strategy_schema, single_dim_strategy_info, out_tensor_meta
                )
                op_strategy = _expanded_strategy_fn(
                    op_schema.op, strategy_schema.args_meta, strategy_schema.kwargs_meta
                )
            else:
                if op_strategy_func is None:
                    raise AssertionError(
                        f"Expected a registered op strategy function for {op_schema.op}"
                    )
                op_strategy = op_strategy_func(strategy_schema)

        else:
            # try operator decomposition path

            op_strategy = None
            if DecompShardingStrategy.has_decomp(op_schema.op):
                # Ensure schema_info is registered for proper cache key computation
                self.decomp_strategy.ensure_schema_info(op_schema.op)
                try:
                    op_strategy = self.decomp_strategy.propagate_strategy(
                        op_schema,
                    )
                except Exception as e:
                    # deliberately best-effort: fall through to the rule-based
                    # path, and surface this as the cause if nothing else applies
                    decomp_exception = e

        if op_strategy is not None:
            if isinstance(op_strategy, OpStrategy):
                # reconcile use_strided_shard_as_shard_order on the strategy's
                # specs with the flag carried by the input specs
                _propagate_use_strided_shard_flag(op_strategy, op_schema)
                # single Op strategy
                output_strategy = _select_min_cost_strategy(op_strategy, op_schema)

                # check if we need to redistribute the input
                needs_redistribute = False
                # check if we want to use args value from redistribute_schema
                use_val_from_redistribute_schema = False
                expected_input_specs: list[DTensorSpec] = []

                # in case where the op does not specify input_specs and output_specs
                # is a DTensorSpec, we use output_specs as the spec for each DTensor
                # input arg.
                if output_strategy.input_specs is None:
                    if not isinstance(output_strategy.output_specs, DTensorSpec):
                        raise AssertionError(
                            "Expected a single DTensorSpec output when the strategy "
                            "does not specify input_specs"
                        )

                for idx, input_spec in enumerate(op_schema.args_spec):
                    desired_spec = (
                        output_strategy.output_spec
                        if output_strategy.input_specs is None
                        else output_strategy.input_specs[idx]
                    )
                    expected_input_specs.append(
                        desired_spec.shallow_copy_with_tensor_meta(
                            input_spec.tensor_meta
                        )
                    )
                    if input_spec.placements != desired_spec.placements:
                        needs_redistribute = True

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(expected_input_specs), {}
                    )
                    suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)

                # shape and stride args need to be modified for
                # view ops and new factory ops, potentially
                if op_schema.op in self.op_to_shape_and_stride_idx:
                    if not isinstance(output_strategy.output_spec, DTensorSpec):
                        raise AssertionError(
                            "Expected a DTensorSpec output for ops with shape/stride args"
                        )
                    # It happens when the output has the same shape as the input
                    # and the input placements are not all Replicate().
                    if any(
                        isinstance(p, Shard | _StridedShard)
                        for p in output_strategy.output_spec.placements
                    ):
                        schema = suggestion_schema or op_schema
                        if not isinstance(out_tensor_meta, TensorMeta):
                            raise AssertionError(
                                "Expected a single TensorMeta output for ops with "
                                "shape/stride args"
                            )
                        suggestion_schema = self._adjust_shape_and_stride_args(
                            out_tensor_meta, schema, output_strategy.output_spec
                        )
                        needs_redistribute = True
                        use_val_from_redistribute_schema = True

                # rewrite squeeze to use only globally-singleton dims
                if op_schema.op in self.squeeze_op_to_dims_variant:
                    schema = suggestion_schema or op_schema
                    adjusted = self._adjust_squeeze_to_global_singletons(schema)
                    if adjusted is not None:
                        suggestion_schema = adjusted
                        needs_redistribute = True
                        use_val_from_redistribute_schema = True

                # construct output spec for the op
                if op_schema.return_type_tuple_tensor_like():
                    # for ops that return multiple tensors and the output_specs is not
                    # a tuple, we use a tuple of that single output spec as the new
                    # output_specs
                    output_specs: OutputSpecType = output_strategy.output_specs
                    if isinstance(output_specs, DTensorSpec):
                        output_specs = tuple(
                            # create a new DTensorSpec with the same placement as the
                            # output_specs in output_strategy
                            DTensorSpec(
                                mesh=output_specs.mesh,
                                placements=output_specs.placements,
                                tensor_meta=output_specs.tensor_meta,
                                use_strided_shard_as_shard_order=output_specs.use_strided_shard_as_shard_order,
                            )
                            for _ in range(len(op_schema.op._schema.returns))
                        )
                elif (
                    op_schema.return_type_tensor()
                    or op_schema.return_type_list_tensor_like()
                ):
                    output_specs = output_strategy.output_specs
                else:
                    output_specs = None

                output_sharding = OutputSharding(
                    output_specs,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                    use_val_from_redistribute_schema=use_val_from_redistribute_schema,
                )
            elif isinstance(op_strategy, TupleStrategy):
                # tuple strategy output sharding processing
                # runtime select OpSpec for each TupleStrategy input arg
                selected_strategies: list[OpSpec] = []
                out_spec_list: list[DTensorSpec] = []
                for strategy in op_strategy.children:
                    if not isinstance(strategy, OpStrategy):
                        raise AssertionError(
                            "Expected every TupleStrategy child to be an OpStrategy"
                        )
                    _propagate_use_strided_shard_flag(strategy, op_schema)
                    selected_strategy = _select_min_cost_strategy(strategy)
                    selected_strategies.append(selected_strategy)
                    if selected_strategy.output_specs is not None:
                        out_spec_list.append(selected_strategy.output_spec)

                needs_redistribute = False
                suggestion_args: list[object] = []
                # counts tensor / tensor-list args seen so far; used as the
                # input_spec index into each selected strategy
                tensor_or_list_tensor_arg_idx = 0

                for arg in op_schema.args_schema:
                    if (
                        arg
                        and isinstance(arg, (list, tuple))
                        and isinstance(arg[0], DTensorSpec)
                    ):
                        # list of DTensors: element idx uses the idx-th child strategy
                        expected_input_spec_list: list[DTensorSpec] = []
                        for idx, arg_spec in enumerate(arg):
                            expected_input_spec = selected_strategies[idx].input_spec(
                                tensor_or_list_tensor_arg_idx
                            )
                            expected_input_spec = (
                                expected_input_spec.shallow_copy_with_tensor_meta(
                                    arg_spec.tensor_meta
                                )
                            )
                            if arg_spec.placements != expected_input_spec.placements:
                                needs_redistribute = True
                            expected_input_spec_list.append(expected_input_spec)
                        suggestion_args.append(
                            tuple(expected_input_spec_list)
                            if isinstance(arg, tuple)
                            else expected_input_spec_list
                        )
                        tensor_or_list_tensor_arg_idx += 1

                    elif isinstance(arg, DTensorSpec):
                        # a lone DTensor arg takes its spec from the first child strategy
                        expected_input_spec = selected_strategies[0].input_spec(
                            tensor_or_list_tensor_arg_idx
                        )
                        expected_input_spec = (
                            expected_input_spec.shallow_copy_with_tensor_meta(
                                arg.tensor_meta
                            )
                        )
                        if arg.placements != expected_input_spec.placements:
                            needs_redistribute = True
                        suggestion_args.append(expected_input_spec)
                        tensor_or_list_tensor_arg_idx += 1
                    else:
                        suggestion_args.append(arg)

                suggestion_schema = None
                if needs_redistribute:
                    suggestion_schema = OpSchema(
                        op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
                    )

                output_sharding = OutputSharding(
                    tuple(out_spec_list) if out_tensor_meta is not None else None,
                    suggestion_schema,
                    needs_redistribute=needs_redistribute,
                    use_val_from_redistribute_schema=False,
                )
            else:
                raise ValueError("Unsupported op strategy type")

            # associate the output sharding with the output tensor metadata
            new_output_spec = self._create_output_spec_with_new_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )
            output_sharding.output_spec = new_output_spec
            return output_sharding
        elif op_schema.op in self.op_to_rules:
            # propagate the sharding with rule
            sharding_prop_func = self.op_to_rules[op_schema.op]

            # step 1. there's sharding propagation rule, run
            # sharding propagation to get the output sharding
            try:
                output_sharding = sharding_prop_func(op_schema)
            except NotImplementedError:
                # re-raise unchanged (bare raise keeps the original traceback)
                raise
            except Exception as e:
                raise RuntimeError(
                    f"Sharding propagation failed on op {op_schema}.\nError: {e}"
                ) from e

            # step 2. if can't get output_spec from sharding
            # propagation (i.e. no rules apply for input
            # placements), we return the output sharding
            # with schema suggestions, which can be used to
            # decide how to do redistribute on inputs
            if output_sharding.output_spec is None:
                if output_sharding.redistribute_schema is None:
                    raise RuntimeError(
                        f"Sharding propagation failed on op {op_schema}!"
                    )
                else:
                    # we do auto redistribute on inputs if necessary
                    # run sharding propagation again with suggested schema
                    propagation_res = sharding_prop_func(
                        output_sharding.redistribute_schema
                    )
                    # we set the output sharding with the new propagation result
                    # so that dispatching know both output_spec and redistribute_schema
                    # exist, which indicates a reshard is needed
                    output_sharding.output_spec = propagation_res.output_spec
                    output_sharding.needs_redistribute = True

            # associate the output sharding with the output tensor metadata
            new_output_spec = self._create_output_spec_with_new_tensor_meta(
                op_schema.op, output_sharding.output_spec, out_tensor_meta
            )
            output_sharding.output_spec = new_output_spec

            return output_sharding
        else:
            raise NotImplementedError(
                f"Operator {op_schema.op} does not have a sharding strategy registered."
            ) from decomp_exception

    def _adjust_shape_and_stride_args(
        self,
        out_tensor_meta: TensorMeta,
        schema: OpSchema,
        spec: DTensorSpec,
    ) -> OpSchema:
        """
        Rewrite the shape (and, if registered, stride) args of ``schema`` so they
        describe the local shard instead of the global output tensor.

        Args:
            out_tensor_meta: global tensor metadata of the op's output.
            schema: the schema whose shape/stride args are rewritten. Its op must
                be a key of ``self.op_to_shape_and_stride_idx``.
            spec: the output DTensorSpec. Its mesh and placements determine the
                local shape.

        Returns:
            A new OpSchema with the same op and kwargs. The arg at the registered
            shape index is replaced by the local shape. When a stride index is
            registered, that arg is replaced by the corresponding local stride.
        """
        shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
        if isinstance(shape_stride_idx, tuple):
            shape_idx, stride_idx = shape_stride_idx
        else:
            # only a shape arg is registered for this op
            shape_idx, stride_idx = shape_stride_idx, None

        expected_input_schema = list(schema.args_schema)
        # replace the global shape arg with the local shape, i.e. the shape of
        # the output's _local_tensor under spec's placements
        local_shape, _ = compute_local_shape_and_global_offset(
            out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True
        )
        expected_input_schema[shape_idx] = local_shape

        # adjust the stride arg (e.g. for aten.new_empty_strided.default).
        # Compare against None explicitly: a stride index of 0 is a valid index
        # and must not be mistaken for "no stride arg".
        if stride_idx is not None:
            expected_input_schema[stride_idx] = compute_local_stride(
                out_tensor_meta.stride, local_shape
            )

        return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)

    def _adjust_squeeze_to_global_singletons(self, schema: OpSchema) -> OpSchema | None:
        """
        Rewrite squeeze ops to squeeze.dims with only globally-singleton dims.

        A sharded dim can have local size 1 while its global size is larger.
        Squeezing by local size would drop such a dim incorrectly. To avoid
        this, the dims to squeeze are chosen from the global shape in the input
        spec's tensor metadata.

        Args:
            schema: a squeeze-family schema. ``args_schema[0]`` is the input
                DTensorSpec. For the ``dim`` / ``dims`` overloads,
                ``args_schema[1]`` holds the requested dim(s).

        Returns:
            An OpSchema targeting ``self.squeeze_op_to_dims_variant[schema.op]``
            with the normalized, globally-singleton dims. Returns None when
            ``schema`` already targets that op with equivalent dims, so no
            rewrite is needed.

        Raises:
            RuntimeError: if the input spec has no tensor metadata.
        """
        from torch.fx.experimental.symbolic_shapes import guard_or_false

        input_spec = cast(DTensorSpec, schema.args_schema[0])
        tensor_meta = input_spec.tensor_meta
        if tensor_meta is None:
            raise RuntimeError("squeeze requires tensor metadata")
        global_shape = tensor_meta.shape
        ndim = len(global_shape)

        def normalize(d: int) -> int:
            # map a negative dim index to its non-negative equivalent
            return d if d >= 0 else d + ndim

        def is_singleton(d: int) -> bool:
            # takes a *raw* dim; out-of-range dims are never treated as singleton
            nd = normalize(d)
            return 0 <= nd < ndim and guard_or_false(global_shape[nd] == 1)

        # guard_or_false: conservatively keep dims when size is symbolic/unknown
        if schema.op in (aten.squeeze.default, aten.squeeze_.default):
            target_dims = tuple(
                i for i, s in enumerate(global_shape) if guard_or_false(s == 1)
            )
        elif schema.op in (aten.squeeze.dim, aten.squeeze_.dim):
            raw_dim = cast(int, schema.args_schema[1])
            # Normalize exactly once. Normalizing an out-of-range negative dim
            # twice (e.g. -5 -> -2 -> 1 for ndim 3) could map it onto a valid dim.
            target_dims = (normalize(raw_dim),) if is_singleton(raw_dim) else ()
        else:
            dims = cast(Sequence[int], schema.args_schema[1])
            target_dims = tuple(  # type: ignore[union-attr]
                normalize(d) for d in dims if is_singleton(d)
            )

        dims_variant = self.squeeze_op_to_dims_variant[schema.op]
        # Skip rewrite if already targeting the right op with the same dims.
        # Compare normalized tuples: the existing arg may be a list and/or use
        # negative indices, and e.g. ``[0] == (0,)`` is False in Python.
        if schema.op == dims_variant and len(schema.args_schema) > 1:
            existing_dims = cast(Sequence[int], schema.args_schema[1])
            if tuple(normalize(d) for d in existing_dims) == target_dims:
                return None
        return OpSchema(dims_variant, (input_spec, target_dims), {})
