# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import TypeAlias, TypeVar

import torch
from torch._prims_common import DimsSequenceType, DimsType
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
    OpSchema,
    OpSpec,
    OpStrategy,
    OutputSharding,
    PlacementList,
    RuntimeSchemaInfo,
    StrategyType,
)
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
    _is_shard_like,
    _StridedShard,
    Partial,
    Placement,
    Replicate,
    Shard,
)


def _get_registration_wrapper(
    registration_fn,
    op: torch._ops.OpOverload | list[torch._ops.OpOverload],
    schema_info: RuntimeSchemaInfo | None,
    arg_names_that_require_specializing_cache_strategy: list[str] | None,
):
    def wrapper(impl):
        overloads = op if isinstance(op, list) else [op]
        for overload in overloads:
            curr_schema_info = None
            if (
                schema_info is None
                and arg_names_that_require_specializing_cache_strategy is not None
            ):
                specialized_args = [
                    a.name
                    for a in overload._schema.arguments
                    if a.name in arg_names_that_require_specializing_cache_strategy
                ]
                if any(specialized_args):
                    curr_schema_info = RuntimeSchemaInfo(
                        static_kwargkey=specialized_args
                    )
            else:
                curr_schema_info = schema_info
            registration_fn(overload, impl, curr_schema_info)
        return impl

    return wrapper


def register_prop_rule(
    op: torch._ops.OpOverload | list[torch._ops.OpOverload],
    schema_info: RuntimeSchemaInfo | None = None,
) -> Callable[
    [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding]
]:
    """Convenience decorator factory for registering sharding propagation rules.

    The decorated function is registered for ``op`` (or every overload in the
    list) through the DTensor dispatcher's sharding propagator, and is returned
    unchanged. ``schema_info`` is forwarded as-is; no cache-specializing
    argument names are used for propagation rules.
    """
    propagator = DTensor._op_dispatcher.sharding_propagator
    return _get_registration_wrapper(
        registration_fn=propagator.register_sharding_prop_rule,
        op=op,
        schema_info=schema_info,
        arg_names_that_require_specializing_cache_strategy=None,
    )


# Note:
# These TypeVars let the registration decorator preserve the specific types of the wrapped strategy function.
# Hardcoding the wrapper's type (e.g. Callable[[OpSchema], StrategyType]) would make mypy treat the wrapped
# strategy's return value as a plain `StrategyType`, even when it is actually a derived class such as
# MyStrategyType(StrategyType).
_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema)
_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType)
_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT]


def register_op_strategy(
    op: torch._ops.OpOverload | list[torch._ops.OpOverload],
    schema_info: RuntimeSchemaInfo | None = None,
) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]:
    """Decorator factory that registers an op strategy function for ``op``.

    The decorated function is registered for ``op`` (or every overload in the
    list) through the DTensor dispatcher's sharding propagator, and is returned
    unchanged. When ``schema_info`` is None, overloads whose schema accepts any
    of the cache-specializing argument names below get a schema info that lists
    those names as static kwarg keys.
    """
    # Arguments with these names can change the strides (and potentially the
    # sharding strategy) of the output tensor, so any ATen schema that accepts
    # one of them must be specialized on it.
    cache_specializing_arg_names = [
        "memory_format",
    ]
    propagator = DTensor._op_dispatcher.sharding_propagator
    return _get_registration_wrapper(
        registration_fn=propagator.register_op_strategy,
        op=op,
        schema_info=schema_info,
        arg_names_that_require_specializing_cache_strategy=cache_specializing_arg_names,
    )


def replicate_op_strategy(op_schema: OpSchema) -> StrategyType:
    """
    Fallback strategy that replicates every output and every input.

    Raises:
        RuntimeError: if the op is a view op, or if any of its returns is a
            ``List[Tensor]``.
    """
    inputs_strategy = op_schema.args_strategy + op_schema.kwargs_strategy

    return_types = [str(ret.type) for ret in op_schema.op._schema.returns]
    num_tensor_outputs = sum(1 for type_name in return_types if type_name == "Tensor")
    # TODO(zpcore): Confirm whether view ops can be handled properly. Refuse
    # to handle them until that is confirmed.
    if op_schema.op.is_view:
        raise RuntimeError(
            "fallback strategy is unable to handle view ops until confirmed"
        )
    if any(type_name == "List[Tensor]" for type_name in return_types):
        raise RuntimeError(
            "fallback strategy is unable to handle ops with List[Tensor] output "
            "because size of the list may depend on the op's input value"
        )

    # The mesh is taken from the first input strategy.
    mesh = inputs_strategy[0].mesh

    # One Replicate per tensor output followed by one per input strategy.
    all_replicate: PlacementList = [Replicate()] * (
        num_tensor_outputs + len(inputs_strategy)
    )
    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, [all_replicate], input_index=num_tensor_outputs
    )


def as_list(
    x: list[object] | object,
    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> list[object] | torch.fx.immutable_collections.immutable_list:  # type: ignore[valid-type]
    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
    # which is an object but treated as a list by the tracer. Therefore, keep
    # `immutable_list` intact here as well.
    if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list):
        return x
    else:
        return [x]


def normalize_dim(dim: int, ndim: int) -> int:
    """Convert a negative ``dim`` to its non-negative equivalent by adding ``ndim``.

    Non-negative values are returned unchanged; no range check is performed.
    """
    if dim < 0:
        return dim + ndim
    return dim


def normalize_dims(dims: DimsType, ndim: int) -> DimsSequenceType:
    """Normalize a dim or a sequence of dims so that they are all non-negative.

    An int becomes a one-element tuple; a list stays a list and a tuple stays a
    tuple, with each entry normalized. Any other input is returned unchanged.
    """
    if isinstance(dims, int):
        return (normalize_dim(dims, ndim),)
    if isinstance(dims, list):
        return [normalize_dim(d, ndim) for d in dims]
    if isinstance(dims, tuple):
        return tuple(normalize_dim(d, ndim) for d in dims)
    return dims


def prod(xs: Iterable[int]) -> int:
    """Return the product of all values in ``xs`` (1 for an empty iterable)."""
    result = 1
    for value in xs:
        result = result * value
    return result


def is_tensor_shardable(
    shape: Sequence[int],
    spec: DTensorSpec,
    allow_unbacked_sharding: bool | None = None,
) -> bool:
    """
    Check whether a tensor of the given shape can be sharded as ``spec`` asks.

    Both `Shard` and `_StridedShard` placements are handled:
    - `Shard`: the tensor dim size must be >= the number of shards on that dim
    - `_StridedShard`: the tensor dim size must additionally be >= the number
      of shards multiplied by the placement's `split_factor`

    A shard-like placement whose dim is >= ``len(shape)`` makes the result False.

    allow_unbacked_sharding: selects the fallback answer when unbacked shapes
    are involved and the size comparison is not statically known.

    e.g. when asking if u0 is shardable on num_shards, and u0 has generic bounds [0, inf]:

        None: will data-dependent error
        True: assumes shardability; we return True, allowing zero-size shards at runtime when u0 < num_shards.
        False: returns False, and lower-bounding u0, e.g. torch._check(u0 >= num_shards), is needed to enable sharding.

    Raises:
        AssertionError: if ``allow_unbacked_sharding`` is not None, True or False.
    """
    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    if allow_unbacked_sharding not in (None, True, False):
        raise AssertionError
    # Each mode picks how the "dim is too small" comparison is evaluated.
    too_small_evaluators = {
        None: bool,
        True: guard_or_false,
        False: guard_or_true,
    }
    is_too_small = too_small_evaluators[allow_unbacked_sharding]

    ndim = len(shape)
    # running number of shards for each tensor dimension
    shards_per_dim = [1] * ndim
    for mesh_dim, placement in enumerate(spec.placements):
        if not _is_shard_like(placement):
            continue
        tensor_dim = placement.dim
        if tensor_dim >= ndim:
            return False
        shards_per_dim[tensor_dim] *= spec.mesh.size(mesh_dim)
        min_dim_size = shards_per_dim[tensor_dim]
        if isinstance(placement, _StridedShard):
            # the dim must also remain shardable after splitting with
            # split_factor
            min_dim_size = min_dim_size * placement.split_factor
        if is_too_small(shape[tensor_dim] < min_dim_size):
            return False

    return True


def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
    """Check whether every sharded dim of ``shape`` divides evenly under ``spec``.

    For each shard-like placement, the tensor dim size must be divisible by the
    number of shards accumulated so far on that dim (for `_StridedShard`, by
    that number multiplied by `split_factor`). A shard-like placement whose dim
    is >= ``len(shape)`` makes the result False.
    """
    ndim = len(shape)
    # running number of shards for each tensor dimension
    shards_per_dim = [1] * ndim
    for mesh_dim, placement in enumerate(spec.placements):
        if not _is_shard_like(placement):
            continue
        tensor_dim = placement.dim
        if tensor_dim >= ndim:
            return False
        shards_per_dim[tensor_dim] *= spec.mesh.size(mesh_dim)
        divisor = shards_per_dim[tensor_dim]
        if isinstance(placement, _StridedShard):
            divisor = placement.split_factor * divisor
        if shape[tensor_dim] % divisor != 0:
            return False
    return True


def is_tensor_evenly_shardable_on_dim(
    shape: Sequence[int], spec: DTensorSpec, dim: int
) -> bool:
    """Check whether tensor dim ``dim`` of ``shape`` divides evenly under ``spec``.

    ``dim`` may be negative; it is normalized against ``len(shape)``. Only
    shard-like placements whose dim equals the normalized ``dim`` are considered.
    """
    target_dim = normalize_dim(dim, len(shape))

    total_shards = 1
    for mesh_dim, placement in enumerate(spec.placements):
        if not (_is_shard_like(placement) and placement.dim == target_dim):
            continue
        total_shards *= spec.mesh.size(mesh_dim)
        # _StridedShard._split_tensor first chunks into split_factor groups,
        # then into num_shards within each group, so the dim must be divisible
        # by the product of both. This is stricter than the final check below
        # and implies it. total_shards already includes this mesh dim's size,
        # so the check covers the full shard count seen so far.
        if (
            isinstance(placement, _StridedShard)
            and shape[target_dim] % (placement.split_factor * total_shards) != 0
        ):
            return False

    return shape[target_dim] % total_shards == 0


def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool:
    """Return True if any shard-like placement in ``spec`` shards tensor dim ``dim``."""
    for placement in spec.placements:
        if _is_shard_like(placement) and placement.dim == dim:
            return True
    return False


def is_tensor_partial(spec: DTensorSpec) -> bool:
    """Return True if at least one placement in ``spec`` reports itself as partial."""
    for placement in spec.placements:
        if placement.is_partial():
            return True
    return False


def infer_broadcast_dims_map(
    common_shape: torch.Size, input_shape: torch.Size
) -> list[int]:
    # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
    # this is aligned with the broadcast semantics
    # e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4],
    # broadcast_dims_map will be [-1, 0, 1, 2]
    # meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    common_ndim = len(common_shape)
    input_ndim = len(input_shape)
    broadcast_dims_map = [-1] * common_ndim
    for idx in range(-1, -1 - input_ndim, -1):
        if guard_or_false(input_shape[idx] == common_shape[idx]):
            broadcast_dims_map[common_ndim + idx] = input_ndim + idx
    return broadcast_dims_map


def map_placements_after_broadcast(
    placements: tuple[Placement, ...],
    shape: torch.Size,
    broadcast_dims_map: list[int],
    partial_to_replicate: bool = False,
) -> tuple[Placement, ...]:
    """Remap each placement's shard dim through ``broadcast_dims_map``.

    Args:
        placements: placements to remap.
        shape: shape whose length is used to normalize negative shard dims.
        broadcast_dims_map: looked up with the normalized shard dim; -1 means
            the dim has no counterpart before broadcasting.
        partial_to_replicate: when True, `Partial` placements become
            `Replicate`; otherwise they are kept as-is.

    Returns:
        A tuple with one placement per input placement.

    Raises:
        AssertionError: if a placement is neither `Partial`, `Replicate`, nor
            shard-like.
    """
    ndim = len(shape)
    mapped: list[Placement] = []
    for placement in placements:
        if isinstance(placement, Partial):
            mapped.append(Replicate() if partial_to_replicate else placement)
            continue
        if isinstance(placement, Replicate):
            mapped.append(placement)
            continue
        if not _is_shard_like(placement):
            raise AssertionError

        source_dim = broadcast_dims_map[normalize_dim(placement.dim, ndim)]
        if source_dim == -1:
            # No mapping between the common-shape shard dim and an input dim
            # before broadcasting: implicit broadcasting happens on this dim,
            # so it is marked as replicate and broadcasting expands it to the
            # sharded shape automatically.
            mapped.append(Replicate())
        elif isinstance(placement, _StridedShard):
            # A mapping exists: shard on the pre-broadcast input dim instead,
            # keeping the split factor.
            mapped.append(
                _StridedShard(source_dim, split_factor=placement.split_factor)
            )
        else:
            mapped.append(Shard(source_dim))

    return tuple(mapped)


def generate_redistribute_costs(
    src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> list[float]:
    """Generate one row of the 'redistribute_costs' matrix in an OpSpec.

    The returned list has one entry per strategy in 'src_strategy'; each entry
    is the cost of redistributing from that strategy's output spec to dst_spec.
    """
    costs: list[float] = []
    for candidate in src_strategy.strategies:
        costs.append(redistribute_cost(candidate.output_spec, dst_spec))
    return costs


def expand_to_full_mesh_op_strategy(
    mesh: DeviceMesh,
    op_schema: OpSchema,
    single_mesh_dim_strategies: list[PlacementList],
    *,
    output_tensor_meta: TensorMeta | Sequence[TensorMeta | None] | None = None,
    input_index: int = 1,
    inplace_op: bool = False,
    allow_unbacked_sharding: bool | None = None,
    allow_uneven_sharding: bool = False,
    is_valid_strategy_cb: Callable[
        [list[DTensorSpec], DTensorSpec | tuple[DTensorSpec | None, ...]], bool
    ]
    | None = None,
    different_mesh_args: list[int] | None = None,
) -> OpStrategy:
    """
    Convenience function to allow writing a sharding strategy considering only a single mesh dimension,
    and have it expanded combinatorially to all mesh dimensions.

    Args:
        mesh (DeviceMesh): the device mesh to expand the strategy to
        op_schema (OpSchema): the op schema
        single_mesh_dim_strategies (list[PlacementList]): the sharding strategies to expand. The outer list is over
            different strategies.  The inner PlacementList is over the outputs and inputs of the op. If input_index is 1,
            a PlacementList looks like [output_placement, input_placement1, input_placement2, ...].
            An entry may be None, in which case no spec is created for that position.
        output_tensor_meta: tensor metadata for the output(s), used to populate DTensorSpec.tensor_meta field.
            A single TensorMeta is used for every non-None output; a tuple/list is indexed by the running count
            of non-None outputs.
        input_index: the number of outputs of the op, defaults to 1. Positions before input_index in a
            PlacementList are outputs; positions from input_index on are inputs. 0 means the op has no outputs.
        inplace_op: whether the op is inplace or not, defaults to False. For inplace ops, combinations whose
            first input or first output placements differ from the first input's current placements are skipped.
        allow_unbacked_sharding: forwarded to `is_tensor_shardable` when checking each input, defaults to None.
        allow_uneven_sharding: if True, an input spec that fails the `is_tensor_shardable` check is still accepted
            when its placements equal the placements that input currently has, defaults to False.
        is_valid_strategy_cb: a callback function to filter out invalid sharding rules, defaults to None.
            It is called with (input_specs, output_specs) and is not called for ops without outputs.
        different_mesh_args: indices (into the combined args + kwargs input strategies) of inputs that may live
            on a mesh other than `mesh`; see Note [Multi-mesh args] below. Defaults to None.

    Returns:
        An OpStrategy holding one OpSpec per mesh-wide placement combination that passed all filters.

    Raises:
        AssertionError: if the number of non-None input specs differs from the number of input strategies.
        RuntimeError: if an input listed in `different_mesh_args` is on a different mesh and is not fully
            Replicate; if `input_index` is 1 and the output placement is None; or if every combination was
            filtered out and at least one was skipped because of the inplace placement check.

    Example: Let's say `my_op(tensor_x, tensor_y) - > output_tensor`  can support sharding or replicating tensor_x,
    but always requires tensor_y to be replicated.  We can specify these valid combinations ignoring mesh dims.
    Then, we can rely on `expand_to_full_mesh_op_strategy` to create every possible combination of these shardings
    over multiple mesh dimensions, filtering out any combinations that are invalid based on the actual mesh dim size.

        single_mesh_dim_strategies = [
            # first strategy: return output sharded on first dim, shard tensor_x on its first dim, replicate tensor_y
            [Shard(0), Shard(0), Replicate()],
            # second strategy: replicate output, and both inputs
            [Replicate(), Replicate(), Replicate()],
        ]
    """
    # Expand the single_mesh_dim_strategies to full mesh dim strategies.
    all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim

    # Each combination picks one single-mesh-dim strategy per mesh dimension.
    strategy_combs = itertools.product(*all_mesh_dim_strategies)

    args_strategy = op_schema.args_strategy
    kwargs_strategy = op_schema.kwargs_strategy
    input_args_strategy = args_strategy + kwargs_strategy

    # Propagate use_strided_shard_as_shard_order from inputs so that
    # strategy specs with _StridedShard get the correct flag (and thus
    # correct shard_order) at construction time, avoiding shard_order
    # mismatches in redistribute_cost computation.
    # Only the first input that has a _StridedShard placement is consulted.
    _input_use_strided: bool | None = None
    for input_strat in input_args_strategy:
        input_spec = input_strat.strategies[0].output_spec
        if any(isinstance(p, _StridedShard) for p in input_spec.placements):
            _input_use_strided = input_spec.use_strided_shard_as_shard_order
            break

    all_strategies = []
    # Track input placements if we skip strategies due to inplace placement mismatch
    blocking_inplace_input_placements: tuple[Placement, ...] | None = None
    for strategy_comb in strategy_combs:
        spec_list: list[DTensorSpec | None] = []
        # Track how many non-None output specs we've seen (for output_tensor_meta indexing).
        # This is needed because output_tensor_meta may contain only non-None entries,
        # so we can't use position directly when there are None entries in the output.
        output_spec_count = 0
        # Track input args separately since not all tensor inputs have OpStrategy
        # (e.g., philox_seed/offset in SDPA are scalar tensors without OpStrategy)
        input_strategy_counter = 0
        # zip(*strategy_comb) transposes the combination: `specs` holds, for one
        # output/input position, its placement on every mesh dimension.
        for position, specs in enumerate(zip(*strategy_comb, strict=True)):
            # Only the first mesh dim's entry is inspected to decide whether
            # this position has a spec at all.
            if specs[0] is not None:
                # Populate tensor_meta field for both output and input specs,
                # including for tuple output cases
                tensor_meta = None
                # Use position to determine output vs input territory
                # (position includes None entries, unlike the old spec_index)
                if position < input_index:
                    # This is an output position
                    if output_tensor_meta is not None:
                        if isinstance(output_tensor_meta, TensorMeta):
                            tensor_meta = output_tensor_meta
                        elif isinstance(output_tensor_meta, (tuple, list)):
                            if output_spec_count < len(output_tensor_meta):
                                tensor_meta = output_tensor_meta[output_spec_count]
                    output_spec_count += 1
                else:
                    # This is an input position
                    # Only get tensor_meta if we have a corresponding input_args_strategy entry
                    if input_strategy_counter < len(input_args_strategy):
                        tensor_meta = input_args_strategy[
                            input_strategy_counter
                        ].tensor_meta
                        input_strategy_counter += 1

                # pyrefly: ignore [bad-argument-type]
                use_strided = (
                    _input_use_strided
                    if _input_use_strided is not None
                    and any(isinstance(p, _StridedShard) for p in specs)
                    else None
                )
                spec_list.append(
                    DTensorSpec(
                        mesh,
                        specs,
                        tensor_meta=tensor_meta,
                        use_strided_shard_as_shard_order=use_strided,
                    )
                )
            else:
                spec_list.append(None)

        # Skip strategy combinations that would create mixed partial types
        # (except sum+avg which commute with each other).
        # We check (type, reduce_op) pairs rather than just reduce_op because
        # Partial subclasses like _MaskPartial have different reduction semantics
        # even when they share the same reduce_op string.
        has_mixed_partial = False
        for spec in spec_list:
            if spec is not None:
                partial_kinds = {
                    (type(p), p.reduce_op)
                    for p in spec.placements
                    if isinstance(p, Partial)
                }
                if len(partial_kinds) > 1:
                    reduce_ops = {ro for _, ro in partial_kinds}
                    types = {t for t, _ in partial_kinds}
                    if not (len(types) == 1 and reduce_ops == {"sum", "avg"}):
                        has_mixed_partial = True
                        break
        if has_mixed_partial:
            continue

        # Inputs are the non-None specs at and after input_index.
        input_specs: list[DTensorSpec] = [
            s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
        ]

        if len(input_specs) != len(input_args_strategy):
            raise AssertionError(
                f"input_specs({len(input_specs)}) != strategies({len(input_args_strategy)}: "
                f"{len(args_strategy)} args + {len(kwargs_strategy)} kwargs)"
            )

        # Note [Multi-mesh args]
        #
        # Some ops accept args whose DTensor lives on a different DeviceMesh
        # than the op's primary compute mesh.  We call these "multi-mesh
        # args".  They arise in fused optimizer ops (e.g. _fused_adam_)
        # where *state_steps* is a per-rank scalar counter allocated on a
        # smaller sub-mesh (e.g. 1-D DP) while params and grads live on a
        # larger mesh (e.g. 2-D DP × TP).
        #
        # Why must these args be Replicate?
        #   Sharding implies a specific partitioning of a tensor's data
        #   across the ranks of a mesh.  If a tensor doesn't even *exist*
        #   on the compute mesh, there is no meaningful way to interpret a
        #   Shard placement for it.  Replicate, on the other hand, is
        #   mesh-agnostic: every rank already holds the full data, so the
        #   op can simply read the value regardless of which mesh owns it.
        #
        # What we do here:
        #   We preserve the original mesh and Replicate placement for these
        #   args so the propagator does not try to redistribute them onto
        #   the compute mesh (which would fail or produce wrong results).
        #
        # This is distinct from the *element_mesh* handling in
        # single_dim_strategy.py, which deals with foreach ops where
        # different *elements* in a tensor list may live on different
        # sub-meshes (e.g. param group A on 2-D mesh, param group B on
        # 1-D mesh).
        # TODO: refactor fused_ops handling so that there are no longer
        # args on different meshes
        if different_mesh_args is not None:
            for idx in different_mesh_args:
                # Indices beyond the available input strategies are ignored.
                if idx < len(input_args_strategy):
                    cross_mesh_input = input_args_strategy[idx]
                    original_spec = cross_mesh_input.strategies[0].output_spec
                    if original_spec.mesh != mesh:
                        if not all(p == Replicate() for p in original_spec.placements):
                            raise RuntimeError(
                                f"Cross-mesh input at index {idx} must be Replicate, "
                                f"but got {original_spec.placements}"
                            )
                        input_specs[idx] = DTensorSpec(
                            mesh=original_spec.mesh,
                            placements=original_spec.placements,
                            tensor_meta=original_spec.tensor_meta,
                        )
        # The current (runtime) spec of the first input.
        self_spec = input_args_strategy[0].strategies[0].output_spec

        redistribute_input = self_spec.placements != input_specs[0].placements
        mismatching_input_output = (
            spec_list[0] is not None and spec_list[0].placements != self_spec.placements
        )
        if inplace_op and (redistribute_input or mismatching_input_output):
            # For inplace ops, both the proposed input[0] and the output must
            # match self's runtime placement: input[0] because self can't be
            # redistributed, output because the result IS self.
            if blocking_inplace_input_placements is None:
                blocking_inplace_input_placements = self_spec.placements
            continue

        # For out= variant ops, output placement must match the "out" kwarg's placement
        if (
            op_schema.is_out_variant_op()
            and "out" in op_schema.kwargs_schema
            and isinstance(op_schema.kwargs_schema["out"], OpStrategy)
        ):
            out_kwarg_spec = op_schema.kwargs_schema["out"].strategies[0].output_spec
            # spec_list[0] is the output spec for this strategy combination
            if spec_list[0] is not None:
                if spec_list[0].placements != out_kwarg_spec.placements:
                    continue

        output_specs: tuple[DTensorSpec | None, ...] | DTensorSpec | None
        if input_index == 0:
            # No outputs (e.g., _linalg_check_errors)
            output_specs = None
        elif input_index > 1:
            # Multiple outputs: keep them as a tuple (entries may be None).
            output_specs = tuple(spec_list[:input_index])
        else:
            if spec_list[0] is not None:
                output_specs = spec_list[0]
            else:
                raise RuntimeError("output spec is None")

        # check all inputs are shardable
        # (with allow_uneven_sharding, an input that keeps its current
        # placements is accepted even if the shardability check fails)
        if not all(
            is_tensor_shardable(
                inp.shape, s, allow_unbacked_sharding=allow_unbacked_sharding
            )
            or (
                allow_uneven_sharding
                and inp.strategies[0].output_spec.placements == s.placements
            )
            for inp, s in zip(input_args_strategy, input_specs)
        ):
            continue

        # perform additional op-specific filtering
        # Skip callback for no-output ops (output_specs is None)
        if is_valid_strategy_cb is not None and output_specs is not None:
            if not is_valid_strategy_cb(input_specs, output_specs):
                continue

        # NOTE: this local name shadows the imported `redistribute_cost`
        # function inside this function body only; `generate_redistribute_costs`
        # is a separate function and still resolves the imported one.
        redistribute_cost = [
            generate_redistribute_costs(input_strategy, input_spec)
            for input_strategy, input_spec in zip(input_args_strategy, input_specs)
        ]

        strategy = OpSpec(
            output_specs=output_specs,
            input_specs=input_specs,
            redistribute_cost=redistribute_cost,
        )
        all_strategies.append(strategy)

    # If all strategies were filtered out due to inplace placement mismatch,
    # raise a clear error message instead of returning an empty OpStrategy
    # (which would later cause a cryptic "min() arg is an empty sequence" error)
    if not all_strategies and blocking_inplace_input_placements is not None:
        raise RuntimeError(
            f"{op_schema.op}: in-place operations that require placement changes "
            f"are not supported. The input has placement {blocking_inplace_input_placements}, "
            f"but no valid strategy preserves this placement. "
            f"Please use the out-of-place version of this operation instead."
        )

    return OpStrategy(all_strategies)


def shift_shard_dims_after_insert(
    placements: Sequence[Placement], insert_dim: int = 0
) -> Sequence[Placement]:
    normalized_placements: list[Placement] = []
    for placement in placements:
        if isinstance(placement, _StridedShard) and placement.dim >= insert_dim:
            normalized_placements.append(
                _StridedShard(placement.dim + 1, split_factor=placement.split_factor)
            )
        elif isinstance(placement, Shard) and placement.dim >= insert_dim:
            normalized_placements.append(Shard(placement.dim + 1))
        else:
            normalized_placements.append(placement)
    return normalized_placements


def shift_shard_dims_after_remove(
    placements: Sequence[Placement], remove_dim: int = 0
) -> Sequence[Placement]:
    normalized_placements: list[Placement] = []
    for placement in placements:
        if isinstance(placement, _StridedShard) and placement.dim > remove_dim:
            normalized_placements.append(
                _StridedShard(placement.dim - 1, split_factor=placement.split_factor)
            )
        elif isinstance(placement, Shard) and placement.dim > remove_dim:
            normalized_placements.append(Shard(placement.dim - 1))
        else:
            normalized_placements.append(placement)
    return normalized_placements
