# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP parameter"""

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.placement_types import Replicate
from hyper_parallel.core.fully_shard.hsdp_utils import (
    FullyShardParamMode,
    GroupInfo,
    get_rank_list_for_axes,
    get_split_rank_lists_for_axes,
)
from hyper_parallel.core.fully_shard.utils import DDPMeshInfo, FSDPMeshInfo
from hyper_parallel.platform import get_platform

platform = get_platform()
_GROUP_INFO_CACHE = {}
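# _GROUP_INFO_CACHE keys created groups by their sorted rank tuple so that
# parameters sharing the same rank set reuse a single communication group
# (see _build_group_info_from_rank_list below).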


def _build_group_info_from_rank_list(group_name: str, rank_list) -> GroupInfo:
    """Create group metadata from an explicit rank list."""
    normalized_rank_list = tuple(sorted(int(rank) for rank in rank_list))
    if len(normalized_rank_list) <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    if normalized_rank_list in _GROUP_INFO_CACHE:
        cached_group = _GROUP_INFO_CACHE[normalized_rank_list]
        return GroupInfo(str(normalized_rank_list), cached_group, len(normalized_rank_list))
    try:
        group = platform.create_group(list(normalized_rank_list))
    except (RuntimeError, ValueError):  # pragma: no cover - UT may run without dist init
        group = None
    _GROUP_INFO_CACHE[normalized_rank_list] = group
    return GroupInfo(str(normalized_rank_list), group, len(normalized_rank_list))
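
# Illustrative usage (a sketch with a hypothetical group name, assuming a
# distributed backend is initialized and ranks 0-3 exist; not executed here):
#
#   info_a = _build_group_info_from_rank_list("grad_sync", [3, 1, 0, 2])
#   info_b = _build_group_info_from_rank_list("grad_sync", [0, 1, 2, 3])
#   # Both calls hit the same cache entry keyed by (0, 1, 2, 3), so
#   # info_a.group is info_b.group; a single-rank list instead yields
#   # GroupInfo("grad_sync_invalid", None, 1) without creating a group.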


def _build_group_info_from_process_group(
    group_name: str,
    process_group,
    rank_size: int,
    *,
    resolved_group_name: str | None = None,
) -> GroupInfo:
    """Create group metadata from an existing process group."""
    if process_group is None or rank_size <= 1:
        return GroupInfo(f"{group_name}_invalid", None, 1)
    return GroupInfo(resolved_group_name or group_name, process_group, rank_size)
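
# Sketch of the possible outcomes, where `pg` stands for a hypothetical
# existing process group of size 4:
#
#   _build_group_info_from_process_group("dp", pg, 4)   -> GroupInfo("dp", pg, 4)
#   _build_group_info_from_process_group("dp", None, 4) -> GroupInfo("dp_invalid", None, 1)
#   _build_group_info_from_process_group("dp", pg, 1)   -> GroupInfo("dp_invalid", None, 1)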


class HSDPParamV2:
    """
    HSDP parameter.
    """

    def __init__(
        self,
        param,
        module_info,
        mesh_info,
        post_forward_mesh_info,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        threshold,
    ):
        """
        Initialize HSDPParamV2.

        Args:
            param (nn.Parameter): The original parameter to shard.
            module_info (ParamModuleInfo): Ownership and shared-weight metadata for the parameter.
            mesh_info (FSDPMeshInfo): Mesh topology describing shard/replicate dimensions.
            post_forward_mesh_info: Mesh info used after forward (reserved for subclass use).
            shard_placement_fn (Callable, optional): Returns a Shard placement for the parameter,
                or None to use default (Shard(0)).
            mp_policy (MixedPrecisionPolicy, optional): Mixed precision dtype policy.
            offload_policy (OffloadPolicy, optional): CPU offload policy.
            threshold: Minimum parameter size to enable sharding (reserved for subclass use).
        """
        raise NotImplementedError("HSDP param subclasses must implement __init__")

    def _init_sharded_param(self, param, shard_placement_fn):
        """Add and initialize the sharded parameter."""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param")

    def init_dtype_attrs(self, mp_policy):
        """Initialize dtype attributes from mixed precision policy."""
        raise NotImplementedError("HSDP param subclasses must implement init_dtype_attrs")

    def init_all_gather_outputs(
        self, all_gather_input_numels, all_gather_input_dtypes, world_size, device, force_recreate=False
    ):
        """Allocate or reuse output buffers for all-gather communication."""
        raise NotImplementedError("HSDP param subclasses must implement init_all_gather_outputs")

    def init_unsharded_param(self):
        """Reconstruct the full unsharded parameter from all-gather outputs."""
        raise NotImplementedError("HSDP param subclasses must implement init_unsharded_param")

    def to_sharded(self):
        """Transition parameter from unsharded back to sharded state and free unsharded storage."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded")

    def to_unsharded(self):
        """Transition parameter to unsharded state after all-gather completes."""
        raise NotImplementedError("HSDP param subclasses must implement to_unsharded")

    def to_sharded_dtensor(self, tensor):
        """Wrap a local sharded tensor as a DTensor with the correct mesh and placements."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_dtensor")

    def to_accumulated_grad_if_needed(self):
        """Move unsharded grad to accumulated grad buffer if dtype conversion is required."""
        raise NotImplementedError("HSDP param subclasses must implement to_accumulated_grad_if_needed")

    def accumulate_unsharded_grad_if_needed(self):
        """Accumulate unsharded param grad into accumulated grad buffer if both exist."""
        raise NotImplementedError("HSDP param subclasses must implement accumulate_unsharded_grad_if_needed")

    def alloc_all_gather_outputs(self):
        """Resize all-gather output buffers to their full capacity for communication."""
        raise NotImplementedError("HSDP param subclasses must implement alloc_all_gather_outputs")

    def free_unsharded_param(self):
        """Release storage of all-gather outputs and inner tensors to free device memory."""
        raise NotImplementedError("HSDP param subclasses must implement free_unsharded_param")

    @property
    def all_gather_inputs(self):
        """Return the local sharded tensor(s) to use as input for all-gather communication."""
        raise NotImplementedError("HSDP param subclasses must implement all_gather_inputs")

    @property
    def unsharded_param(self):
        """Return the full unsharded parameter after all-gather."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_param")

    @property
    def unsharded_grad_data(self):
        """Return the unsharded_param.grad."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_grad_data")

    @property
    def unsharded_accumulated_grad_data(self):
        """Return the unsharded accumulated gradient buffer."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_accumulated_grad_data")

    @property
    def _sharded_local_tensor(self):
        """Return the underlying local tensor of the sharded DTensor parameter."""
        raise NotImplementedError("HSDP param subclasses must implement _sharded_local_tensor")

    def _get_unsharded_param_data(self, async_op=False):
        """Perform all-gather to obtain unsharded parameter data, returning (tensor, handle)."""
        raise NotImplementedError("HSDP param subclasses must implement _get_unsharded_param_data")

    def unshard(self, async_op=False):
        """Trigger all-gather to unshard the parameter, optionally asynchronously."""
        raise NotImplementedError("HSDP param subclasses must implement unshard")

    def wait_for_unshard(self):
        """Wait for all-gather to complete and transition parameter to unsharded state."""
        raise NotImplementedError("HSDP param subclasses must implement wait_for_unshard")

    def shard(self):
        """Transition parameter from unsharded back to sharded state."""
        raise NotImplementedError("HSDP param subclasses must implement shard")

    def reduce_scatter_grad(self):
        """Perform reduce-scatter on the unsharded gradient to produce a sharded gradient."""
        raise NotImplementedError("HSDP param subclasses must implement reduce_scatter_grad")

    def all_reduce_grad(self):
        """Perform all-reduce on gradient across the replicate dimension (HSDP mode only)."""
        raise NotImplementedError("HSDP param subclasses must implement all_reduce_grad")

    def _resolve_process_group_name(self, group_name: str, process_group) -> str:
        """Resolve the name recorded in GroupInfo for an existing process group."""
        del process_group
        return group_name

    def _get_base_spmd_placements(self) -> tuple:
        """Return placements before explicit data-parallel semantics are applied."""
        if (
            getattr(self, "param_mode", None) == FullyShardParamMode.DTENSOR_UNIFIED
            and getattr(self, "_orig_param_is_dtensor", False)
        ):
            self._spmd_mesh = DeviceMesh.concatenate([self.mesh_info.mesh, self._orig_dtensor_mesh])
            dp_prefix_placements = tuple(Replicate() for _ in range(self.mesh_info.mesh.ndim))
            return dp_prefix_placements + tuple(self._orig_dtensor_placements)

        if (
            getattr(self, "param_mode", None) == FullyShardParamMode.DTENSOR_COMPAT
            and getattr(self, "_orig_param_is_dtensor", False)
        ):
            self._spmd_mesh = self._orig_dtensor_mesh
            return tuple(self._orig_dtensor_placements)

        self._spmd_mesh = self.mesh_info.mesh
        return tuple(Replicate() for _ in range(self._spmd_mesh.ndim))
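
    # Layout sketch (hypothetical shapes, for orientation only): with a 1-D
    # data-parallel mesh of size 4 and an original DTensor placed as
    # (Shard(0),) on a 1-D tensor-parallel mesh:
    #   DTENSOR_UNIFIED -> _spmd_mesh = concat(dp_mesh, tp_mesh),
    #                      base placements = (Replicate(), Shard(0))
    #   DTENSOR_COMPAT  -> _spmd_mesh = tp_mesh, base placements = (Shard(0),)
    #   plain tensor    -> _spmd_mesh = dp_mesh, base placements = (Replicate(),)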

    def _get_data_parallel_shard_placement(self, placements: list, shard_placement):
        """Return the placement to apply on the explicit fully_shard dimension."""
        del placements
        return shard_placement

    def _apply_data_parallel_placements(self, placements: list, shard_placement) -> tuple:
        """Apply explicit DDP/FSDP placements on top of the base SPMD layout."""
        if len(placements) != self._spmd_mesh.ndim:
            raise AssertionError(
                f"Expected {self._spmd_mesh.ndim} unified placements, got {len(placements)}: {placements}"
            )

        spmd_replicate_mesh_dim = getattr(self, "_spmd_replicate_mesh_dim", None)
        if (
            isinstance(self.mesh_info, DDPMeshInfo)
            and spmd_replicate_mesh_dim is not None
            and not getattr(self, "_orig_param_is_dtensor", False)
        ):
            placements[spmd_replicate_mesh_dim] = Replicate()

        spmd_shard_mesh_dim = getattr(self, "_spmd_shard_mesh_dim", None)
        if (
            getattr(self, "uses_param_shard", False)
            and isinstance(self.mesh_info, FSDPMeshInfo)
            and spmd_shard_mesh_dim is not None
        ):
            placements[spmd_shard_mesh_dim] = self._get_data_parallel_shard_placement(
                placements, shard_placement
            )
        return tuple(placements)
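
    # Continuing the sketch above: starting from base placements
    # [Replicate(), Shard(0)] with _spmd_shard_mesh_dim == 0 on an FSDP mesh,
    # axis 0 is overwritten with the fully_shard placement, giving
    # (shard_placement, Shard(0)); on a DDP mesh the replicate axis is pinned
    # to Replicate() instead. The input list is mutated in place.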

    def _init_group_infos(self) -> None:
        """Initialize sharded/unsharded communication groups from the current layout."""
        if (
            getattr(self, "uses_param_shard", False)
            and getattr(self, "is_sharded", False)
            and isinstance(self.mesh_info, FSDPMeshInfo)
        ):
            resolved_group_name = self._resolve_process_group_name(
                "fully_shard_sharded_group",
                self.mesh_info.shard_process_group,
            )
            self.sharded_group_info = _build_group_info_from_process_group(
                "fully_shard_sharded_group",
                self.mesh_info.shard_process_group,
                self.mesh_info.shard_mesh_size,
                resolved_group_name=resolved_group_name,
            )
        else:
            self.sharded_group_info = GroupInfo("fully_shard_sharded_group_invalid", None, 1)

        self.unsharded_group_info = self._build_layout_driven_group_info()
        self.shard_size = self.sharded_group_info.rank_size
        self.dp_size = self.unsharded_group_info.rank_size
        self.rank_size = max(1, self.shard_size * self.dp_size)
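
    # Size bookkeeping sketch: on a hypothetical 2 x 4 HSDP mesh
    # (2-way replicate, 4-way shard), sharded_group_info covers the 4-rank
    # shard group and unsharded_group_info the 2-rank replicate group, so
    # shard_size == 4, dp_size == 2, and rank_size == 8.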

    def _build_layout_driven_group_info(self) -> GroupInfo:
        """Build the group that should all-reduce an unsharded gradient from the final layout."""
        group_axes = [
            axis
            for axis, placement in enumerate(self._spmd_placements)
            if placement.is_replicate()
        ]
        spmd_shard_mesh_dim = getattr(self, "_spmd_shard_mesh_dim", None)
        if getattr(self, "uses_param_shard", False) and spmd_shard_mesh_dim is not None:
            group_axes = [axis for axis in group_axes if axis != spmd_shard_mesh_dim]
        if not group_axes:
            return GroupInfo("fully_shard_unsharded_group_invalid", None, 1)

        group_dim_names = getattr(self._spmd_mesh, "mesh_dim_names", None)
        if group_dim_names:
            try:
                mesh_axis_names = tuple(group_dim_names[axis] for axis in group_axes)
                if len(mesh_axis_names) == 1:
                    axis_name = mesh_axis_names[0]
                    process_group = self._spmd_mesh.get_group(axis_name)
                    if process_group is not None:
                        rank_size = self._spmd_mesh.mesh_shape[group_dim_names.index(axis_name)]
                        resolved_group_name = self._resolve_process_group_name(
                            "fully_shard_unsharded_group",
                            process_group,
                        )
                        return _build_group_info_from_process_group(
                            "fully_shard_unsharded_group",
                            process_group,
                            rank_size,
                            resolved_group_name=resolved_group_name,
                        )

                split_rank_lists = get_split_rank_lists_for_axes(self._spmd_mesh, group_axes)
                process_group = platform.split_group(split_ranks=split_rank_lists)
                if process_group is not None:
                    rank_size = 1
                    for axis in group_axes:
                        rank_size *= self._spmd_mesh.mesh_shape[axis]
                    resolved_group_name = self._resolve_process_group_name(
                        "fully_shard_unsharded_group",
                        process_group,
                    )
                    return _build_group_info_from_process_group(
                        "fully_shard_unsharded_group",
                        process_group,
                        rank_size,
                        resolved_group_name=resolved_group_name,
                    )
            except (AssertionError, AttributeError, KeyError, RuntimeError, TypeError, ValueError):
                pass

        rank_list = get_rank_list_for_axes(self._spmd_mesh, group_axes)
        return _build_group_info_from_rank_list("fully_shard_unsharded_group", rank_list)
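
    # Resolution order, in short: (1) a single named replicated axis reuses
    # the mesh's own group via get_group(); (2) otherwise the replicated axes
    # are fused with platform.split_group(); (3) on any failure the code falls
    # back to an explicit rank list, which _build_group_info_from_rank_list
    # caches.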

    def _normalize_unsharded_grad_to_local(self, grad, *, reduce_partial_dtensor: bool = True):
        """Normalize a pending gradient to the local tensor expected by fully_shard collectives."""
        if not isinstance(grad, DTensor):
            return grad

        if reduce_partial_dtensor and any(placement.is_partial() for placement in grad.placements):
            grad = grad.reduce_partial()

        orig_dtensor_mesh = getattr(self, "_orig_dtensor_mesh", None)
        orig_dtensor_placements = getattr(self, "_orig_dtensor_placements", None)
        if (
            orig_dtensor_mesh is not None
            and grad.device_mesh.to_hash() != orig_dtensor_mesh.to_hash()
        ) or (
            orig_dtensor_placements is not None
            and tuple(grad.placements) != tuple(orig_dtensor_placements)
        ):
            grad = grad.redistribute(orig_dtensor_mesh, orig_dtensor_placements)
        return grad.to_local()
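
    # Normalization sketch for a hypothetical DTensor gradient: a Partial()
    # gradient is first reduced via reduce_partial(), then redistributed back
    # to the parameter's original mesh/placements if either differs, and
    # finally unwrapped with to_local() so downstream reduce-scatter sees a
    # plain local tensor. Non-DTensor gradients pass through unchanged.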