# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Packing helpers for MindSpore fully_shard communication buffers."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Literal, Optional

import mindspore as ms

from hyper_parallel.core.dtensor.placement_types import StridedShard


@dataclass(frozen=True)
class ReduceScatterPlan:
    """Describe how local tensors map to packed communication layouts."""

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ]
    shard_dim: int
    world_size: int
    packed_shape: tuple[int, ...]
    packed_tensor_shape: tuple[int, ...]
    unpacked_shape: tuple[int, ...]
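
# Illustrative example (not part of the module API): for a (4, 6) local gradient
# sharded on dim=1 across world_size=2, build_rs_plan would produce a plan such as
#
#     ReduceScatterPlan(
#         pack_kind="chunk_cat_non_dim0",
#         shard_dim=1,
#         world_size=2,
#         packed_shape=(2, 12),        # one contiguous row per rank
#         packed_tensor_shape=(8, 3),  # shard-dim chunks stacked along dim0
#         unpacked_shape=(4, 6),
#     )
#
# so each rank's reduce-scatter output corresponds to one flattened row.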


@dataclass(frozen=True)
class _SameDimStridedLayoutContext:
    """Resolved placement metadata for a same-dim StridedShard parameter."""

    target_dim: int
    shard_mesh_dim: int
    placements: tuple[Any, ...]
    orig_placements: tuple[Any, ...]


def _shape_tuple(shape) -> tuple[int, ...]:
    """Normalize a shape-like object into a tuple of Python ints."""
    return tuple(int(dim) for dim in shape)


def _has_strided_shard_layout(hsdp_param: Any) -> bool:
    """Return True if any SPMD placement of ``hsdp_param`` is a StridedShard."""
    placements = getattr(hsdp_param, "_spmd_placements", ()) or ()
    return any(isinstance(placement, StridedShard) for placement in placements)


def _resolve_same_dim_strided_context(
    hsdp_param: Any,
) -> Optional[_SameDimStridedLayoutContext]:
    """Collect the placement context needed to validate a same-dim StridedShard layout."""
    if not _has_strided_shard_layout(hsdp_param):
        return None
    if not getattr(hsdp_param, "uses_param_shard", False):
        return None
    if not getattr(hsdp_param, "_orig_param_is_dtensor", False):
        return None
    target_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", None)
    if target_dim is None:
        return None
    shard_mesh_dim = getattr(hsdp_param, "_spmd_shard_mesh_dim", None)
    placements = tuple(getattr(hsdp_param, "_spmd_placements", ()) or ())
    if shard_mesh_dim is None or shard_mesh_dim >= len(placements):
        return None
    if not isinstance(placements[shard_mesh_dim], StridedShard):
        return None
    orig_placements = getattr(hsdp_param, "_orig_dtensor_placements", None)
    if orig_placements is None:
        return None
    return _SameDimStridedLayoutContext(
        target_dim=target_dim,
        shard_mesh_dim=shard_mesh_dim,
        placements=placements,
        orig_placements=tuple(orig_placements),
    )


def _placements_match_target_dim_only(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    """Return True if every placement is replicate or a shard on ``target_dim``."""
    return all(
        placement.is_replicate() or placement.is_shard(target_dim)
        for placement in placements
    )


def _orig_layout_is_supported(
    orig_placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    """The original layout must shard ``target_dim`` on exactly one mesh dim, others replicated."""
    if not _placements_match_target_dim_only(orig_placements, target_dim):
        return False
    return sum(
        placement.is_shard(target_dim) for placement in orig_placements
    ) == 1


def _current_strided_layout_is_supported(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    """The current layout must pair exactly one plain shard with one StridedShard on ``target_dim``."""
    if not _placements_match_target_dim_only(placements, target_dim):
        return False
    if sum(placement.is_shard() for placement in placements) != 2:
        return False

    strided_placements = [
        placement for placement in placements if isinstance(placement, StridedShard)
    ]
    if len(strided_placements) != 1:
        return False
    strided_placement = strided_placements[0]
    if strided_placement.dim != target_dim or strided_placement.split_factor <= 1:
        return False

    plain_shards = [
        placement
        for placement in placements
        if placement.is_shard(target_dim) and not isinstance(placement, StridedShard)
    ]
    return len(plain_shards) == 1


def supports_same_dim_strided_layout(hsdp_param: Any) -> bool:
    """Return True if ``hsdp_param`` uses a supported same-dim StridedShard layout."""
    ctx = _resolve_same_dim_strided_context(hsdp_param)
    if ctx is None:
        return False
    if not _orig_layout_is_supported(ctx.orig_placements, ctx.target_dim):
        return False
    return _current_strided_layout_is_supported(ctx.placements, ctx.target_dim)
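
# Illustrative sketch of a layout that passes the checks above (placement
# constructor names such as Shard/Replicate are assumed here for illustration):
# a parameter originally placed as (Replicate(), Shard(0)) shards dim 0 on exactly
# one mesh dim, and after fully_shard re-shards it, the current layout pairs that
# plain Shard(0) with a single StridedShard on dim 0 (split_factor > 1) sitting on
# the fully_shard mesh dim, with every other placement replicated.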


def _resolve_unpacked_shape(
    hsdp_param: Optional[Any],
    local_tensor: ms.Tensor,
) -> tuple[int, ...]:
    """Prefer the parameter's recorded original size; fall back to the local tensor shape."""
    if hsdp_param is not None and getattr(hsdp_param, "_orig_size", None) is not None:
        return _shape_tuple(getattr(hsdp_param, "_orig_size"))
    return _shape_tuple(local_tensor.shape)


def _get_packed_tensor_shape(
    unpacked_shape: tuple[int, ...],
    shard_dim: int,
    world_size: int,
) -> tuple[int, ...]:
    """Grow dim0 by ``world_size`` and shrink ``shard_dim`` by it (identity for dim0 sharding)."""
    if world_size == 1 or shard_dim == 0:
        return unpacked_shape
    packed_tensor_shape = list(unpacked_shape)
    packed_tensor_shape[0] *= world_size
    packed_tensor_shape[shard_dim] //= world_size
    return tuple(packed_tensor_shape)
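
# Worked example (illustration only): for unpacked_shape=(4, 6), shard_dim=1 and
# world_size=2 the packed tensor shape is (4 * 2, 6 // 2) == (8, 3), i.e. the two
# shard-dim chunks of shape (4, 3) stacked along dim0 so that each rank's slice is
# contiguous in the flat communication buffer.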


def build_rs_plan(
    hsdp_param: Optional[Any],
    local_tensor: ms.Tensor,
    world_size: int,
    *,
    shard_dim: Optional[int] = None,
) -> ReduceScatterPlan:
    """Build the V1 reduce-scatter packing plan for a local gradient tensor."""

    if world_size <= 0:
        raise ValueError(f"world_size must be positive, but got {world_size}")

    resolved_shard_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", shard_dim)
    if resolved_shard_dim is None:
        raise ValueError("build_rs_plan requires either hsdp_param or shard_dim")
    unpacked_shape = _resolve_unpacked_shape(hsdp_param, local_tensor)
    if resolved_shard_dim < 0 or resolved_shard_dim >= len(unpacked_shape):
        raise ValueError(
            f"Invalid shard dim {resolved_shard_dim} for tensor shape {tuple(unpacked_shape)}"
        )
    if world_size == 1:
        if not local_tensor.is_contiguous():
            raise NotImplementedError(
                "reduce_scatter_grad currently expects contiguous local gradients before packing."
            )
        return ReduceScatterPlan(
            pack_kind="identity_dim0",
            shard_dim=resolved_shard_dim,
            world_size=world_size,
            packed_shape=(1, math.prod(unpacked_shape)),
            packed_tensor_shape=unpacked_shape,
            unpacked_shape=unpacked_shape,
        )
    if len(local_tensor.shape) == 0:
        raise NotImplementedError("reduce_scatter_grad does not support scalar gradients.")
    if unpacked_shape[resolved_shard_dim] % world_size != 0:
        raise NotImplementedError(
            f"reduce_scatter_grad currently only supports even sharding on dim={resolved_shard_dim}."
        )
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ] = "identity_dim0"
    if hsdp_param is not None and _has_strided_shard_layout(hsdp_param):
        if not supports_same_dim_strided_layout(hsdp_param):
            raise NotImplementedError(
                "reduce_scatter_grad only supports same-dim StridedShard layouts "
                "that restore a single contiguous TP-local shard on the fully_shard dimension."
            )
        if resolved_shard_dim == 0:
            pack_kind = "same_dim_strided_identity_dim0"
        else:
            pack_kind = "chunk_cat_non_dim0"
    elif resolved_shard_dim != 0:
        pack_kind = "chunk_cat_non_dim0"

    packed_tensor_shape = _get_packed_tensor_shape(
        unpacked_shape,
        resolved_shard_dim,
        world_size,
    )
    total_numel = math.prod(unpacked_shape)
    return ReduceScatterPlan(
        pack_kind=pack_kind,
        shard_dim=resolved_shard_dim,
        world_size=world_size,
        packed_shape=(world_size, total_numel // world_size),
        packed_tensor_shape=packed_tensor_shape,
        unpacked_shape=unpacked_shape,
    )
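
# Usage sketch (hypothetical caller; the tensor construction is only an example):
#
#     grad = ms.ops.zeros((4, 6), ms.float32)          # contiguous local gradient
#     plan = build_rs_plan(None, grad, world_size=2, shard_dim=1)
#     # plan.pack_kind == "chunk_cat_non_dim0"
#     # plan.packed_shape == (2, 12); plan.packed_tensor_shape == (8, 3)
#     packed = pack_for_reduce_scatter(grad, plan)      # shape (2, 12)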


def pack_for_reduce_scatter(
    local_tensor: ms.Tensor,
    plan: ReduceScatterPlan,
) -> ms.Tensor:
    """Pack one local gradient into the row-major reduce-scatter layout."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported reduce-scatter pack kind: {plan.pack_kind}")
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )
    if _shape_tuple(local_tensor.shape) != plan.unpacked_shape:
        raise AssertionError(
            "pack_for_reduce_scatter expects the unsharded local tensor shape to match "
            f"plan.unpacked_shape, but got {tuple(local_tensor.shape)} and "
            f"{tuple(plan.unpacked_shape)}"
        )
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = ms.mint.chunk(local_tensor, plan.world_size, dim=plan.shard_dim)
        packed_tensor = ms.mint.cat(chunks, dim=0)
        return packed_tensor.contiguous().view(plan.packed_shape)
    return local_tensor.view(plan.packed_shape)
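
# Illustration (comment only): with unpacked_shape=(4, 6), shard_dim=1 and
# world_size=2, the tensor is chunked into two (4, 3) halves along dim=1,
# concatenated along dim0 into (8, 3), and viewed as (2, 12), so row r is the
# flattened shard that rank r receives from reduce-scatter. For dim0 sharding the
# memory is already laid out that way, so a plain view suffices.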


def unpack_from_all_gather(
    full_packed: ms.Tensor,
    plan: ReduceScatterPlan,
) -> ms.Tensor:
    """Inverse of the V1 reduce-scatter packing plan for all-gather outputs."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported all-gather unpack kind: {plan.pack_kind}")
    packed_tensor = full_packed.view(plan.packed_tensor_shape)
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = ms.mint.chunk(packed_tensor, plan.world_size, dim=0)
        return ms.mint.cat(chunks, dim=plan.shard_dim).contiguous()
    return packed_tensor.view(plan.unpacked_shape)
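
# Round-trip sketch (illustration only): for any plan produced by build_rs_plan,
# unpack_from_all_gather(pack_for_reduce_scatter(t, plan), plan) should restore
# t's layout, since chunking along dim0 and concatenating on shard_dim undoes the
# shard_dim chunk / dim0 cat of packing, and the dim0 variants are pure views.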


__all__ = [
    "ReduceScatterPlan",
    "build_rs_plan",
    "pack_for_reduce_scatter",
    "unpack_from_all_gather",
    "supports_same_dim_strided_layout",
]