Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/fully_shard/hsdp_utils.py: 79%
160 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""HSDP optimizer shared level"""
from dataclasses import dataclass, field
from enum import auto, Enum
from typing import Any, List, Optional, Sequence

import numpy as np

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.platform import get_platform
from hyper_parallel.platform.platform import PlatformType

platform = get_platform()


class HSDPConfigV2:
    """HSDPConfigV2, inspired by torch fully_shard."""

    def __init__(self,
                 mesh,
                 reshard_after_forward,
                 shard_placement_fn,
                 mp_policy,
                 offload_policy,
                 ignored_params=None,
                 replicate_params=None,
                 comm_fusion=False,
                 comm_fusion_zero_copy=False,
                 ):
        self.mesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.ignored_params = ignored_params
        self.replicate_params = replicate_params
        self.reduce_dtype = self.mp_policy.reduce_dtype if self.mp_policy else None
        self.comm_fusion = comm_fusion
        self.comm_fusion_zero_copy = comm_fusion_zero_copy
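
# Hedged construction sketch (not part of the original module; all argument
# values are hypothetical stand-ins for real mesh/policy objects):
#
#     config = HSDPConfigV2(mesh=dp_mesh, reshard_after_forward=True,
#                           shard_placement_fn=None, mp_policy=None,
#                           offload_policy=None)
#     config.reduce_dtype is None  # no mp_policy, so no reduce dtype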


class ShardedState(Enum):
    """
    Parameter shard state.
    """
    SHARDED = auto()
    UNSHARDED = auto()


class FullyShardParamMode(Enum):
    """
    Internal fully_shard execution modes derived from the original parameter layout.

    LOCAL_PARAM:
        The parameter is a regular local tensor parameter and fully_shard owns the
        full data-parallel sharding behaviour.
    DTENSOR_COMPAT:
        The parameter already carries a DTensor layout and fully_shard is only used
        as the compatibility wrapper without adding an extra FSDP shard dimension.
    DTENSOR_UNIFIED:
        The parameter already carries a DTensor layout and fully_shard additionally
        contributes a data-parallel/FSDP mesh that must be unified with the
        existing distributed layout.
    """

    LOCAL_PARAM = auto()
    DTENSOR_COMPAT = auto()
    DTENSOR_UNIFIED = auto()


@dataclass
class GroupInfo:
    """Communication group metadata used by fully_shard."""

    group_name: str
    group: Any
    rank_size: int


class FSDPSchedulerState(Enum):
    """
    Scheduler state:
    - PRE_FORWARD:
        the pre-forward hook has already run.
    - FORWARD:
        the post-forward hook has already run.
    - PRE_BACKWARD:
        the pre-backward hook has already run.
    - BACKWARD:
        the post-backward hook has already run.
    """
    PRE_FORWARD = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()
    BACKWARD = auto()


@dataclass
class ParamModuleInfo:
    """
    Tracks parameter ownership and supports shared weights in HSDP.

    This dataclass maintains the mapping between a parameter and its module(s),
    enabling parameter swapping during sharding/unsharding transitions. Shared
    weights are parameters referenced by multiple modules (e.g., tied embeddings);
    all such references are tracked so the parameter can be replaced consistently
    during sharding/unsharding operations.

    Attributes:
        module: The module that owns this parameter.
        param_name: Attribute name of the parameter in the module (e.g., "weight").
        shared_modules: List of other modules sharing this same parameter object.
        shared_param_names: Corresponding parameter names in shared_modules (aligned by index).
    """
    module: platform.Module
    param_name: str
    shared_modules: List[platform.Module] = field(default_factory=list)
    shared_param_names: List[str] = field(default_factory=list)


def _named_parameters_with_duplicates(
    module: platform.Module, **kwargs: Any
) -> list[tuple[str, platform.Parameter]]:
    """
    This API is required because some modules override `named_parameters()` but
    do not support `remove_duplicate`.
    """
    if "remove_duplicate" in kwargs:
        raise AssertionError(
            "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
        )

    def get_named_parameters(module, **kwargs):
        if platform.platform_type == PlatformType.PYTORCH:
            return module.named_parameters(**kwargs)
        return module.parameters_and_names(expand=False)

    kwargs["remove_duplicate"] = False
    try:
        ret = list(get_named_parameters(module, **kwargs))
    except AssertionError:
        kwargs.pop("remove_duplicate")
        ret = list(get_named_parameters(module, **kwargs))
    return ret


def _get_param_module_infos(
    params: list[platform.Parameter], modules: tuple[platform.Module, ...]
) -> list['ParamModuleInfo']:
    """
    Shared parameter: lin1.weight = lin2.weight
    Shared module: mlp.lin1 = mlp.lin2
    We do not remove duplicates when traversing both modules and parameters to
    find shared modules' parameters and shared parameters within a module.
    """
    params_set = set(params)
    param_to_module_info: dict[platform.Parameter, ParamModuleInfo] = {}

    def get_named_modules(module):
        if platform.platform_type == PlatformType.PYTORCH:
            return module.named_modules(remove_duplicate=False)
        return module.cells_and_names()

    for module in modules:
        for _, submodule in get_named_modules(module):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param in params_set:
                    if param not in param_to_module_info:
                        param_to_module_info[param] = ParamModuleInfo(
                            submodule, param_name
                        )
                    else:
                        param_to_module_info[param].shared_modules.append(submodule)
                        param_to_module_info[param].shared_param_names.append(
                            param_name
                        )
    if len(param_to_module_info) != len(params):
        raise AssertionError(f"Some parameters are not in the module tree of {modules}")
    return [param_to_module_info[param] for param in params]
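
# Illustrative sketch (hypothetical module names): with tied weights such as
#
#     model.lm_head.weight = model.embedding.weight
#
# the shared parameter resolves to a single ParamModuleInfo whose ``module`` /
# ``param_name`` point at the first owner encountered (``embedding`` /
# "weight") and whose ``shared_modules`` / ``shared_param_names`` record the
# additional reference (``lm_head`` / "weight").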


def get_managed_modules_parameters(
    modules: Sequence[platform.Module],
    ignored_params: Optional[Sequence[platform.Parameter]] = None,
) -> list[platform.Parameter]:
    """Collect deduplicated parameters from ``modules`` while skipping ignored params.

    Parameters that were already initialized by an inner ``fully_shard`` instance
    are intentionally excluded so nested ``fully_shard(mesh=None)`` resolves mesh
    mode from the parameters that the current wrapper will actually manage.
    """
    params: list[platform.Parameter] = []
    ignored_params_set = set(ignored_params or ())
    visited_params: set[platform.Parameter] = set()
    for mod in modules:
        for _, param in platform.parameters_dict(mod):
            if param in ignored_params_set or param in visited_params:
                continue
            if getattr(param, "_hsdp_param_initialized", False):
                continue
            visited_params.add(param)
            params.append(param)
    return params
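
# Hedged usage sketch (hypothetical names): parameters already wrapped by an
# inner fully_shard call carry ``_hsdp_param_initialized`` and are skipped, so
# an outer wrapper only sees the parameters it will actually manage:
#
#     params = get_managed_modules_parameters([outer_module],
#                                             ignored_params=[frozen_bias])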


def infer_fully_shard_param_mode(
    mesh: Optional[DeviceMesh],
    params: Optional[Sequence[Any]] = None,
) -> FullyShardParamMode:
    """
    Infer the internal fully_shard execution mode from parameter layout and mesh.

    The mode is intentionally phrased around whether parameters already carry a
    distributed layout instead of assuming the layout came from TP only. DTensor
    parameters may originate from TP, EP, or other distributed sharding paths.
    """
    has_dtensor_param = any(is_dtensor_managed_param(param) for param in params or ())
    if not has_dtensor_param:
        return FullyShardParamMode.LOCAL_PARAM
    if mesh is None:
        return FullyShardParamMode.DTENSOR_COMPAT
    return FullyShardParamMode.DTENSOR_UNIFIED
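
# Minimal sketch of the mode resolution above, assuming ``local_param`` is a
# plain tensor parameter and ``dtensor_param`` already carries DTensor layout
# metadata (both hypothetical):
#
#     infer_fully_shard_param_mode(mesh=None,    params=[local_param])
#     # -> FullyShardParamMode.LOCAL_PARAM
#     infer_fully_shard_param_mode(mesh=None,    params=[dtensor_param])
#     # -> FullyShardParamMode.DTENSOR_COMPAT
#     infer_fully_shard_param_mode(mesh=dp_mesh, params=[dtensor_param])
#     # -> FullyShardParamMode.DTENSOR_UNIFIED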


def unwrap_dtensor_param(param: Any) -> Optional[DTensor]:
    """Return the DTensor payload carried by ``param`` if one exists."""
    if isinstance(param, DTensor):
        return param
    param_data = getattr(param, "data", None)
    if isinstance(param_data, DTensor):
        return param_data
    if all(hasattr(param, attr) for attr in ("_device_mesh", "_placements", "_local_tensor")):
        return param
    return None


def is_dtensor_managed_param(param: Any) -> bool:
    """Return whether a parameter already carries DTensor layout metadata."""
    return unwrap_dtensor_param(param) is not None


def get_dtensor_managed_mesh(param: Any) -> Optional[DeviceMesh]:
    """Return the DTensor mesh carried by ``param`` if one exists."""
    payload = unwrap_dtensor_param(param)
    if payload is None:
        return None
    return getattr(payload, "device_mesh", getattr(payload, "_device_mesh", None))
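
# Note: ``unwrap_dtensor_param`` accepts a DTensor itself, a parameter whose
# ``.data`` is a DTensor, or any object that duck-types the DTensor internals
# (``_device_mesh`` / ``_placements`` / ``_local_tensor``), so the two helpers
# above work across parameter wrappers.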


def get_rank_list_for_axes(
    mesh: DeviceMesh,
    axes: Sequence[int],
    rank: Optional[int] = None,
) -> list[int]:
    """Return ranks that vary along ``axes`` and keep all other coordinates fixed."""
    if rank is None:
        rank = mesh.rank
    if rank not in mesh.rank_list:
        raise ValueError(f"Rank {rank} not found in mesh rank list {mesh.rank_list}.")

    normalized_axes = tuple(sorted(set(axes)))
    if len(normalized_axes) == 0:
        return [rank]

    mesh_tensor = np.array(mesh.rank_list).reshape(mesh.mesh_shape)
    rank_index = mesh.rank_list.index(rank)
    # Convert the flat rank index into an N-d coordinate in the mesh.
    coord = [0] * len(mesh.mesh_shape)
    temp = rank_index
    for i in range(len(mesh.mesh_shape) - 1, -1, -1):
        coord[i] = temp % mesh.mesh_shape[i]
        temp //= mesh.mesh_shape[i]
    # Slice along the requested axes while pinning all other coordinates.
    mesh_slice = []
    for axis, axis_coord in enumerate(coord):
        mesh_slice.append(slice(None) if axis in normalized_axes else axis_coord)
    selected = mesh_tensor[tuple(mesh_slice)]
    return [int(item) for item in np.array(selected).reshape(-1).tolist()]
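
# Worked example of the slicing above on a 2 x 4 mesh (self-contained numpy
# sketch; DeviceMesh construction is platform-specific, so a plain array
# stands in for ``mesh_tensor``):
#
#     import numpy as np
#     mesh_tensor = np.arange(8).reshape(2, 4)
#     # Rank 5 sits at coordinate (1, 1). Varying axis 1 (fixing axis 0 = 1)
#     # selects its row; varying axis 0 (fixing axis 1 = 1) selects its column:
#     mesh_tensor[1, :].tolist()  # -> [4, 5, 6, 7]
#     mesh_tensor[:, 1].tolist()  # -> [1, 5]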


def get_split_rank_lists_for_axes(
    mesh: DeviceMesh,
    axes: Sequence[int],
) -> list[list[int]]:
    """Return all rank lists induced by varying ``axes`` and fixing the complementary axes."""
    normalized_axes = tuple(sorted(set(axes)))
    if len(normalized_axes) == 0:
        return [[int(rank) for rank in mesh.rank_list]]

    mesh_tensor = np.array(mesh.rank_list).reshape(mesh.mesh_shape)
    complementary_axes = tuple(
        axis for axis in range(len(mesh.mesh_shape)) if axis not in normalized_axes
    )
    if len(complementary_axes) == 0:
        return [[int(item) for item in np.array(mesh_tensor).reshape(-1).tolist()]]

    complementary_shape = tuple(mesh.mesh_shape[axis] for axis in complementary_axes)
    split_rank_lists: list[list[int]] = []
    # Enumerate every coordinate of the complementary axes; each one pins a
    # distinct slice whose ranks vary only along the requested axes.
    for complementary_coord in np.ndindex(*complementary_shape):
        mesh_slice = []
        coord_idx = 0
        for axis in range(len(mesh.mesh_shape)):
            if axis in normalized_axes:
                mesh_slice.append(slice(None))
            else:
                mesh_slice.append(complementary_coord[coord_idx])
                coord_idx += 1
        selected = mesh_tensor[tuple(mesh_slice)]
        split_rank_lists.append([int(item) for item in np.array(selected).reshape(-1).tolist()])
    return split_rank_lists
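
# Worked example of the partition above on the same 2 x 4 stand-in mesh
# (self-contained numpy sketch, not part of the original module):
#
#     mesh_tensor = np.arange(8).reshape(2, 4)
#     # Varying axis 1 partitions the ranks by row:
#     #   [[0, 1, 2, 3], [4, 5, 6, 7]]
#     # Varying axis 0 partitions the ranks by column:
#     #   [[0, 4], [1, 5], [2, 6], [3, 7]]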


def get_hsdp_state(module):
    """Return the HSDPState for a fully_shard-managed module, or None."""
    from hyper_parallel.core.fully_shard.api import HSDPModule  # pylint: disable=C0415
    if isinstance(module, HSDPModule):
        scheduler = getattr(module, "hsdp_scheduler", None)
        if scheduler is not None:
            return scheduler.hsdp_state
    return None