# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Packing helpers for fully_shard communication buffers."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Literal, Optional

import torch

from hyper_parallel.core.dtensor.placement_types import StridedShard


@dataclass(frozen=True)
class ReduceScatterPlan:
    """Describe how local tensors map to packed communication layouts."""

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ]
    shard_dim: int
    world_size: int
    packed_shape: torch.Size
    packed_tensor_shape: torch.Size
    unpacked_shape: torch.Size


@dataclass(frozen=True)
class _SameDimStridedLayoutContext:
    target_dim: int
    shard_mesh_dim: int
    placements: tuple[Any, ...]
    orig_placements: tuple[Any, ...]


def _has_strided_shard_layout(hsdp_param: Any) -> bool:
    placements = getattr(hsdp_param, "_spmd_placements", ()) or ()
    return any(isinstance(placement, StridedShard) for placement in placements)


def _resolve_same_dim_strided_context(
    hsdp_param: Any,
) -> Optional[_SameDimStridedLayoutContext]:
    if not _has_strided_shard_layout(hsdp_param):
        return None
    if not getattr(hsdp_param, "uses_param_shard", False):
        return None
    if not getattr(hsdp_param, "_orig_param_is_dtensor", False):
        return None
    target_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", None)
    if target_dim is None:
        return None
    shard_mesh_dim = getattr(hsdp_param, "_spmd_shard_mesh_dim", None)
    placements = tuple(getattr(hsdp_param, "_spmd_placements", ()) or ())
    if shard_mesh_dim is None or shard_mesh_dim >= len(placements):
        return None
    if not isinstance(placements[shard_mesh_dim], StridedShard):
        return None
    orig_placements = getattr(hsdp_param, "_orig_dtensor_placements", None)
    if orig_placements is None:
        return None
    return _SameDimStridedLayoutContext(
        target_dim=target_dim,
        shard_mesh_dim=shard_mesh_dim,
        placements=placements,
        orig_placements=tuple(orig_placements),
    )


def _placements_match_target_dim_only(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    return all(
        placement.is_replicate() or placement.is_shard(target_dim)
        for placement in placements
    )


def _orig_layout_is_supported(
    orig_placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    if not _placements_match_target_dim_only(orig_placements, target_dim):
        return False
    return sum(
        placement.is_shard(target_dim) for placement in orig_placements
    ) == 1


def _current_strided_layout_is_supported(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    if not _placements_match_target_dim_only(placements, target_dim):
        return False
    if sum(placement.is_shard() for placement in placements) != 2:
        return False

    strided_placements = [
        placement for placement in placements if isinstance(placement, StridedShard)
    ]
    if len(strided_placements) != 1:
        return False
    strided_placement = strided_placements[0]
    if strided_placement.dim != target_dim or strided_placement.split_factor <= 1:
        return False

    plain_shards = [
        placement
        for placement in placements
        if placement.is_shard(target_dim) and not isinstance(placement, StridedShard)
    ]
    return len(plain_shards) == 1


def supports_same_dim_strided_layout(hsdp_param: Any) -> bool:
    ctx = _resolve_same_dim_strided_context(hsdp_param)
    if ctx is None:
        return False
    if not _orig_layout_is_supported(ctx.orig_placements, ctx.target_dim):
        return False
    return _current_strided_layout_is_supported(ctx.placements, ctx.target_dim)


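# Illustrative note, not an exhaustive spec: the checks above accept layouts
# where the parameter was originally sharded exactly once on the fully_shard
# dim, e.g. (in pseudo-notation, since the concrete placement constructors
# are an assumption here)
#     orig_placements = (Shard(0), Replicate())
# and the current layout adds exactly one StridedShard on that same dim,
#     placements = (StridedShard(0, split_factor=tp), Shard(0))
# with split_factor > 1 and every other mesh dim replicated.

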
def _resolve_unpacked_shape(
    hsdp_param: Optional[Any],
    local_tensor: torch.Tensor,
) -> torch.Size:
    if hsdp_param is not None and getattr(hsdp_param, "_orig_size", None) is not None:
        return torch.Size(getattr(hsdp_param, "_orig_size"))
    return torch.Size(local_tensor.size())


def _get_packed_tensor_shape(
    unpacked_shape: torch.Size,
    shard_dim: int,
    world_size: int,
) -> torch.Size:
    if world_size == 1 or shard_dim == 0:
        return unpacked_shape
    packed_tensor_shape = list(unpacked_shape)
    packed_tensor_shape[0] *= world_size
    packed_tensor_shape[shard_dim] //= world_size
    return torch.Size(packed_tensor_shape)


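# Worked example (values illustrative): for unpacked_shape=(4, 6),
# shard_dim=1, world_size=2 the packed tensor shape is
# (4 * 2, 6 // 2) = (8, 3): dim 0 grows by world_size so each rank's chunk
# becomes a contiguous row block, while the sharded dim shrinks by the same
# factor.

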
def build_rs_plan(
    hsdp_param: Optional[Any],
    local_tensor: torch.Tensor,
    world_size: int,
    *,
    shard_dim: Optional[int] = None,
) -> ReduceScatterPlan:
    """Build the V1 reduce-scatter packing plan for a local gradient tensor."""

    if world_size <= 0:
        raise ValueError(f"world_size must be positive, but got {world_size}")

    resolved_shard_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", shard_dim)
    if resolved_shard_dim is None:
        raise ValueError("build_rs_plan requires either hsdp_param or shard_dim")
    unpacked_shape = _resolve_unpacked_shape(hsdp_param, local_tensor)
    if resolved_shard_dim < 0 or resolved_shard_dim >= len(unpacked_shape):
        raise ValueError(
            f"Invalid shard dim {resolved_shard_dim} for tensor shape {tuple(unpacked_shape)}"
        )
    if world_size == 1:
        if not local_tensor.is_contiguous():
            raise NotImplementedError(
                "reduce_scatter_grad currently expects contiguous local gradients before packing."
            )
        return ReduceScatterPlan(
            pack_kind="identity_dim0",
            shard_dim=resolved_shard_dim,
            world_size=world_size,
            packed_shape=torch.Size((1, math.prod(unpacked_shape))),
            packed_tensor_shape=unpacked_shape,
            unpacked_shape=unpacked_shape,
        )
    if local_tensor.dim() == 0:
        raise NotImplementedError("reduce_scatter_grad does not support scalar gradients.")
    if unpacked_shape[resolved_shard_dim] % world_size != 0:
        raise NotImplementedError(
            f"reduce_scatter_grad currently only supports even sharding on dim={resolved_shard_dim}."
        )
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ] = "identity_dim0"
    if hsdp_param is not None and _has_strided_shard_layout(hsdp_param):
        if not supports_same_dim_strided_layout(hsdp_param):
            raise NotImplementedError(
                "reduce_scatter_grad only supports same-dim StridedShard layouts "
                "that restore a single contiguous TP-local shard on the fully_shard dimension."
            )
        if resolved_shard_dim == 0:
            pack_kind = "same_dim_strided_identity_dim0"
        else:
            pack_kind = "chunk_cat_non_dim0"
    elif resolved_shard_dim != 0:
        pack_kind = "chunk_cat_non_dim0"

    packed_tensor_shape = _get_packed_tensor_shape(
        unpacked_shape,
        resolved_shard_dim,
        world_size,
    )
    total_numel = math.prod(unpacked_shape)

    return ReduceScatterPlan(
        pack_kind=pack_kind,
        shard_dim=resolved_shard_dim,
        world_size=world_size,
        packed_shape=torch.Size((world_size, total_numel // world_size)),
        packed_tensor_shape=packed_tensor_shape,
        unpacked_shape=unpacked_shape,
    )


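# Usage sketch (assumed calling convention when no hsdp_param is available;
# the values are illustrative):
#     plan = build_rs_plan(None, grad, world_size=2, shard_dim=1)
#     plan.pack_kind      # "chunk_cat_non_dim0" for a non-dim0 shard
#     plan.packed_shape   # (world_size, grad.numel() // world_size)
# A runnable round-trip check lives in the __main__ sketch at the bottom of
# this file.

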
def pack_for_reduce_scatter(
    local_tensor: torch.Tensor,
    plan: ReduceScatterPlan,
) -> torch.Tensor:
    """Pack one local gradient into the row-major reduce-scatter layout."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported reduce-scatter pack kind: {plan.pack_kind}")
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )
    if local_tensor.size() != plan.unpacked_shape:
        raise AssertionError(
            "pack_for_reduce_scatter expects the unsharded local tensor shape to match "
            f"plan.unpacked_shape, but got {tuple(local_tensor.size())} and "
            f"{tuple(plan.unpacked_shape)}"
        )
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = torch.chunk(local_tensor, plan.world_size, dim=plan.shard_dim)
        packed_tensor = torch.cat(chunks, dim=0)
        return packed_tensor.contiguous().view(plan.packed_shape)
    return local_tensor.view(plan.packed_shape)


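# Illustrative walk-through of the chunk_cat_non_dim0 branch: a (4, 6) tensor
# sharded on dim 1 with world_size=2 is chunked into two (4, 3) halves,
# concatenated into (8, 3), then viewed as packed_shape=(2, 12), so each
# rank's reduce-scatter slice is one contiguous row.

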
def unpack_from_all_gather(
    full_packed: torch.Tensor,
    plan: ReduceScatterPlan,
) -> torch.Tensor:
    """Inverse of the V1 reduce-scatter packing plan for all-gather outputs."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported all-gather unpack kind: {plan.pack_kind}")
    packed_tensor = full_packed.view(plan.packed_tensor_shape)
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = torch.chunk(packed_tensor, plan.world_size, dim=0)
        return torch.cat(chunks, dim=plan.shard_dim).contiguous()
    return packed_tensor.view(plan.unpacked_shape)


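# By construction this reverses pack_for_reduce_scatter: view the gathered
# buffer as packed_tensor_shape, split its dim-0 row blocks back into
# world_size chunks, and re-concatenate them along shard_dim to recover
# unpacked_shape (see the round-trip check in the __main__ sketch below).

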
__all__ = [
    "ReduceScatterPlan",
    "build_rs_plan",
    "pack_for_reduce_scatter",
    "unpack_from_all_gather",
    "supports_same_dim_strided_layout",
]
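

# ---------------------------------------------------------------------------
# Minimal self-check sketch (illustrative only, not part of the module API).
# It assumes build_rs_plan accepts hsdp_param=None together with an explicit
# shard_dim, per the resolution logic above, and exercises the
# chunk_cat_non_dim0 pack/unpack round trip on a tiny CPU tensor.
if __name__ == "__main__":
    grad = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    plan = build_rs_plan(None, grad, world_size=2, shard_dim=1)
    assert plan.pack_kind == "chunk_cat_non_dim0"
    packed = pack_for_reduce_scatter(grad, plan)
    assert packed.shape == plan.packed_shape  # (2, 12): one row per rank
    restored = unpack_from_all_gather(packed, plan)
    assert torch.equal(restored, grad)  # packing round-trips the local layout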