Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/state.py: 61%
270 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Torch HSDP cell state"""
# pylint: disable=protected-access

from typing import Optional

import torch

from hyper_parallel.core.fully_shard.hsdp_state import HSDPState
from hyper_parallel.core.fully_shard.hsdp_utils import (
    FullyShardParamMode,
    _get_param_module_infos,
    infer_fully_shard_param_mode,
)
from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy
from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2
from hyper_parallel.platform.torch.fully_shard.pack_utils import build_rs_plan
from hyper_parallel.platform.torch.fully_shard.param_group import get_comm_ctx, HSDPParamGroup


def _to_dtype_if_needed(
    tensor: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
    """Cast tensor to the given dtype if it differs from the current dtype.

    Args:
        tensor: The input tensor to potentially cast.
        dtype: Target dtype. If None or the same as the tensor's dtype, this is a no-op.
    """
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor


class TorchHSDPStateV2(HSDPState):
    """Torch HSDP cell state"""
    # DTensor compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without ever materializing an
    # ``_unsharded_param``. Track their async all-reduce work separately from
    # the standard unsharded-grad queues.
    pre_direct_all_reduce_grads = []
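    # Each queued entry is a ``(work_handle_or_None, reduced_grad, target_grad)`` tuple,
    # matching what ``_queue_direct_compat_all_reduce`` appends and ``reduce_params``
    # drains below.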

    @staticmethod
    def _get_pending_unsharded_grad(hsdp_param):
        """Return the pending unsharded gradient tensor for all-reduce-based paths."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return hsdp_param.unsharded_accumulated_grad_data
        return hsdp_param.unsharded_grad_data

    @staticmethod
    def _has_pending_unsharded_grad(hsdp_param):
        """Whether the parameter currently has a gradient waiting for reduction."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return True
        if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
            return False
        return hsdp_param.unsharded_param.grad is not None

    @staticmethod
    def _get_local_sharded_grad(hsdp_param):
        """Return the local gradient tensor currently stored on ``sharded_param``."""
        grad = hsdp_param.sharded_param.grad
        if grad is None:
            return None
        to_local = getattr(grad, "to_local", None)
        if callable(to_local):
            return to_local()
        return grad

    def __init__(self, cell, mesh_info, config, platform, device):
        """
        Initialize TorchHSDPStateV2.

        Args:
            cell (nn.Module): The module whose parameters are managed by this state.
            mesh_info: Mesh topology for shard/replicate dimensions.
            config (HSDPConfigV2): HSDP configuration.
            platform (TorchPlatform): Torch platform abstraction.
            device (torch.device): Target device.
        """
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        self.device = device
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do reduce-scatter/all-reduce for grads
        self.reduce_grads = True
        # Reshard parameters after backward
        self.reshard_after_backward = True
        # Requires all-reduce for grads when running HSDP
        self.requires_all_reduce = True
        # Default reduce op is decided at the fully_shard-state level:
        # if any managed parameter is DTensor-backed, use SUM; otherwise AVG.
        self._user_reduce_op_type = None
        self.reduce_op_type = self._resolve_default_reduce_op()
        self._reset_sharded_params = False
        self._init_param_group()

    @staticmethod
    def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
        """Return the reason why ``hsdp_param`` cannot participate in comm_fusion."""
        if not hsdp_param.enable_fsdp_shard:
            return "non-sharded parameters such as replicate_params are not supported"
        if hsdp_param.param_mode not in (
            FullyShardParamMode.LOCAL_PARAM,
            FullyShardParamMode.DTENSOR_UNIFIED,
        ):
            return (
                "param_mode "
                f"{hsdp_param.param_mode} is not supported"
            )
        local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
        if local_shard is None:
            return "missing local shard tensor for comm_fusion plan validation"
        plan_world_size = getattr(hsdp_param, "shard_world_size", None)
        if plan_world_size is None:
            plan_world_size = getattr(hsdp_param, "shard_size", 1)
        try:
            build_rs_plan(hsdp_param, local_shard, plan_world_size)
        except NotImplementedError as exc:
            return str(exc)
        except (AssertionError, ValueError) as exc:
            return f"cannot build comm_fusion pack plan: {exc}"
        return None

    def _init_param_group(self):
        """Initialize the fused parameter group for communication fusion.

        When ``comm_fusion`` is enabled, creates an ``HSDPParamGroup`` that packs all
        parameters into a single buffer for fused all-gather and reduce-scatter,
        replacing the per-parameter communication pattern.
        """
        if self.config.comm_fusion:
            unsupported_param = next(
                (
                    hsdp_param
                    for hsdp_param in self.hsdp_params
                    if self._comm_fusion_unsupported_reason(hsdp_param) is not None
                ),
                None,
            )
            if unsupported_param is not None:
                param_fqn = getattr(unsupported_param, "_param_fqn", "<unknown>")
                reason = self._comm_fusion_unsupported_reason(unsupported_param)
                raise NotImplementedError(
                    f"comm_fusion does not support parameter {param_fqn}: {reason}."
                )
        self.param_group = None
        if self.hsdp_params:
            # pylint: disable=E1128
            self.param_group = HSDPParamGroup(
                self.hsdp_params,
                self.mesh_info,
                self.device,
                self.mp_policy,
                self.config.comm_fusion_zero_copy,
            )
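    # With ``config.comm_fusion`` enabled, construction first validates every managed
    # parameter via ``_comm_fusion_unsupported_reason`` and fails fast with
    # NotImplementedError on the first unsupported one; backward then uses the fused
    # ``foreach_reduce`` path on the packed ``HSDPParamGroup`` instead of the
    # per-parameter reduce-scatter/all-reduce calls (see ``post_backward_for_comm_fusion``).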

    def _move_states_to_device(self):
        """Move managed parameters and buffers to the target device."""
        for mod in self.modules:
            for param in mod.parameters():
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                if param.device == self.device or param.device.type == "meta":
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                if buffer.device == self.device or buffer.device.type == "meta":
                    continue
                buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Init HSDP parameters and replicate parameters for the cell."""
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        # all parameters in the module tree(s), deduplicated
        visited_params = set()
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.named_parameters():
                if param in ignored_params:
                    continue
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                if param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            enable_fsdp_shard = param not in replicate_params
            hsdp_param = TorchHSDPParamV2(
                param,
                module_info,
                self.mesh_info,
                shard_placement_fn=self.config.shard_placement_fn,
                mp_policy=self.mp_policy,
                offload_policy=self.offload_policy,
                device=self.device,
                param_mode=param_mode,
                enable_fsdp_shard=enable_fsdp_shard,
            )
            if param in replicate_params:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
                if hsdp_param.is_sharded:
                    self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Init mixed-precision dtypes for HSDP parameters and replicate parameters."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        for replicate_param in self.replicate_params:
            replicate_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[TorchHSDPParamV2] = [
            p for p in self._iter_managed_params() if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
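        # Example: mixing a bf16 trainable parameter with an fp32 trainable parameter in
        # the same fully_shard state makes ``orig_dtypes`` a two-element set and raises
        # the AssertionError above; frozen (requires_grad=False) parameters are excluded
        # from this uniformity check.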

    def _validate_cpu_offload_params(self):
        """Validate that all parameters are on CPU when the CPU offload policy is enabled."""
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device.type != "cpu"
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
                "Found the following parameters on a non-CPU device: "
                f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n"
            )

    def lazy_init(self):
        """Reset sharded params once, then validate devices and init mixed-precision dtypes."""
        if not self._reset_sharded_params:
            for hsdp_param in self.hsdp_params:
                if hsdp_param.is_sharded:
                    hsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        self._validate_no_meta_params()
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _validate_no_meta_params(self):
        """Validate that no managed parameter is still on the meta device."""
        param_names_on_meta = [
            hsdp_param._param_fqn
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device.type == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "HSDP parameters should be materialized from meta device before training, "
                f"but the following were still on meta device: {param_names_on_meta}\n"
                "For example, call module.to_empty(device) to materialize to device and "
                "call module.reset_parameters() on each module to initialize values."
            )

    def post_backward_for_comm_fusion(self):
        """Reduce gradients through the fused param-group pipeline after backward."""
        # Replicate-only params still use the non-fused compat all-reduce path.
        # Drain any pending side-path reductions before advancing the fused
        # param-group pipeline for sharded params.
        self.reduce_params()
        # Fused gradient reduction path: first apply any pending async reduction
        # from the previous module's backward (pipelined overlap), then issue
        # this module's fused reduce-scatter (+ all-reduce for HSDP).
        comm_ctx = get_comm_ctx()
        # Phase 2: apply grads for the param group whose all_reduce is done
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        # Phase 1: wait reduce_scatter, issue async all_reduce for the previous layer
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type
            )
        for hsdp_param in self.replicate_params:
            if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                continue
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if not self._has_pending_unsharded_grad(hsdp_param):
                continue
            reduce_op = self._resolve_reduce_op(hsdp_param)
            self._queue_compat_all_reduce(hsdp_param, reduce_op)
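        # Net effect (assuming foreach_reduce and wait_reduce_scatter_and_issue_all_reduce,
        # defined in param_group.py, hand the current group off via comm_ctx.pre_param_group
        # and comm_ctx.all_reduce_param_group): while this module's fused reduce-scatter is
        # in flight, the previous module's all-reduce is launched and an earlier module's
        # finished all-reduce is applied, overlapping gradient communication with backward
        # compute.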

    def _resolve_default_reduce_op(self):
        """Resolve the default reduce op for the whole fully_shard state."""
        for hsdp_param in self._iter_managed_params():
            if hsdp_param.param_mode in (
                FullyShardParamMode.DTENSOR_COMPAT,
                FullyShardParamMode.DTENSOR_UNIFIED,
            ):
                return torch.distributed.ReduceOp.SUM
        return torch.distributed.ReduceOp.AVG

    def _resolve_reduce_op(self, hsdp_param=None):
        """Resolve the gradient reduction op for the current fully_shard state."""
        if self._user_reduce_op_type is not None:
            return self._user_reduce_op_type
        return self.reduce_op_type

    def _should_run_all_reduce(self, hsdp_param) -> bool:
        """Whether the current parameter should issue an all-reduce in this backward pass."""
        return self.requires_all_reduce and hsdp_param.dp_size > 1

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param, reduce_op):
        """Queue the standard FSDP/HSDP reduction path."""
        hsdp_param.reduce_scatter_grad(
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype))
        if not self._should_run_all_reduce(hsdp_param):
            return
        reduced_grad = hsdp_param.reduce_scatter_output()
        if (
            HSDPState.pre_reduce_scatter_params
            and HSDPState.pre_reduce_scatter_params[-1][0] == hsdp_param
        ):
            HSDPState.pre_reduce_scatter_params.pop()
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))
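        # HSDP two-level reduction: the reduce-scatter runs over the shard group and, when
        # an all-reduce is also needed (dp_size > 1), its output is immediately handed to
        # all_reduce_grad over the replicate group. The param is then tracked in
        # pre_all_reduce_params instead of pre_reduce_scatter_params so that reduce_params
        # applies only the final, fully reduced gradient.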

    def _queue_compat_all_reduce(self, hsdp_param, reduce_op):
        """Queue the compatibility all-reduce path without FSDP sharding."""
        if not self._should_run_all_reduce(hsdp_param):
            return
        hsdp_param.all_reduce_grad(
            grad=self._get_pending_unsharded_grad(hsdp_param),
            dtype=self._reduce_dtype,
            reduce_op=reduce_op,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype))

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
            and hsdp_param.shard_size == 1
            and hsdp_param.sharded_param.requires_grad
            and self._should_run_all_reduce(hsdp_param)
            and self._get_local_sharded_grad(hsdp_param) is not None
        )

    def _queue_direct_compat_all_reduce(self, hsdp_param, reduce_op):
        """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``."""
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        reduced_grad = grad
        if self._reduce_dtype is not None and reduced_grad.dtype != self._reduce_dtype:
            reduced_grad = reduced_grad.to(self._reduce_dtype)
        handle = None
        if hsdp_param.unsharded_group_info.group is not None and hsdp_param.dp_size > 1:
            handle = torch.distributed.all_reduce(
                reduced_grad,
                op=reduce_op,
                group=hsdp_param.unsharded_group_info.group,
                async_op=True,
            )
        TorchHSDPStateV2.pre_direct_all_reduce_grads.append((handle, reduced_grad, grad))
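        # The async work queued here is not waited on immediately: reduce_params later
        # drains pre_direct_all_reduce_grads, waits on each handle, and copies the reduced
        # result back into the original grad tensor when the reduce-dtype cast produced a
        # separate buffer.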

    def post_backward(self, *unused):  # pylint: disable=unused-argument
        """Reduce gradients and reshard parameters after backward."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            if self.reshard_after_backward:
                self.shard()
            for hsdp_param in self._iter_managed_params():
                hsdp_param.to_accumulated_grad_if_needed()
            return
        if not self.comm_fusion:
            self.reduce_params()
            for hsdp_param in self._iter_managed_params():
                if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                    if self._can_direct_all_reduce_compat_grad(hsdp_param):
                        reduce_op = self._resolve_reduce_op(hsdp_param)
                        self._queue_direct_compat_all_reduce(hsdp_param, reduce_op)
                    continue
                # Frozen parameters produce no gradient, so there is nothing to reduce.
                if not hsdp_param.sharded_param.requires_grad:
                    continue
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                reduce_op = self._resolve_reduce_op(hsdp_param)
                if hsdp_param.shard_size > 1:
                    self._queue_reduce_scatter_then_all_reduce(hsdp_param, reduce_op)
                elif self._should_run_all_reduce(hsdp_param):
                    self._queue_compat_all_reduce(hsdp_param, reduce_op)
        else:
            self.post_backward_for_comm_fusion()
        if self.reshard_after_backward:
            self.shard()

    def reduce_params(self):
        """Apply reduced gradients from pre-staged HSDP parameters to sharded parameters.

        This function processes two lists of pre-queued HSDP parameters
        (`pre_reduce_scatter_params` and `pre_all_reduce_params`), retrieves the reduced
        gradients from asynchronous reduce-scatter/all-reduce operations, clears cached
        communication outputs, and applies the reduced gradients to the corresponding
        sharded parameters (including reshaping, dtype conversion, optional CPU
        offloading, and gradient accumulation/assignment).

        Note:
            - Parameters are processed in FIFO (first-in, first-out) order via `pop(0)`,
              so the gradient application order matches the order of the gradient
              reduction operations.
            - After retrieving the reduced gradient, the cached communication output
              (reduce_scatter_output or all_reduce_output) is cleared to free memory and
              avoid stale data.
            - Gradient application logic (in `apply_reduced_grad`) includes:
                1. Reshaping the flat reduced gradient to match the local shard shape
                2. Optional dtype conversion to `param_type`
                3. Optional CPU offloading (per the HSDP parameter's offload policy)
                4. Assigning or accumulating the gradient to `sharded_param.grad`
        """
        need_synchronize = False
        while HSDPState.pre_reduce_scatter_params:
            pre_hsdp_param, pre_orig_dtype = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = pre_hsdp_param.reduce_scatter_output()
            pre_hsdp_param.clear_reduce_scatter_output()
            need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize

        while HSDPState.pre_all_reduce_params:
            pre_hsdp_param, pre_orig_dtype = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = pre_hsdp_param.all_reduce_output()
            pre_hsdp_param.clear_all_reduce_output()
            need_synchronize = pre_hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype) or need_synchronize

        while TorchHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad = TorchHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            if handle is not None:
                handle.wait()
            if reduced_grad is not target_grad:
                if reduced_grad.dtype != target_grad.dtype:
                    reduced_grad = reduced_grad.to(target_grad.dtype)
                target_grad.copy_(reduced_grad)
        if need_synchronize:
            if self.device.type == "npu":
                torch.npu.current_stream().synchronize()
            elif self.device.type == "cuda":
                torch.cuda.current_stream().synchronize()
            else:
                raise NotImplementedError(
                    f"Unsupported device type {self.device.type} for synchronization after CPU offload."
                )

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the requires-grad-sync flag that controls gradient synchronization."""
        self.reduce_grads = requires_grad_sync

    def set_reduce_op_type(self, reduce_op_type: str):
        """Set the reduce op type used for gradient reduction."""
        fsdp_support_reduce_op = {
            "sum": torch.distributed.ReduceOp.SUM,
            "avg": torch.distributed.ReduceOp.AVG,
        }
        reduce_op: str = reduce_op_type.lower().strip()
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(
                f"Unsupported reduce op type {reduce_op_type}, "
                f"supported types are {list(fsdp_support_reduce_op.keys())}"
            )
        self._user_reduce_op_type = fsdp_support_reduce_op[reduce_op]
        self.reduce_op_type = self._user_reduce_op_type
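

# Usage sketch (illustrative only; the surrounding fully_shard framework normally drives
# these calls): after constructing ``state = TorchHSDPStateV2(cell, mesh_info, config,
# platform, device)``, a training step would typically run ``state.lazy_init()`` once
# before the first forward, let the backward hooks invoke ``state.post_backward()``, and
# may call ``state.set_reduce_op_type("sum")`` or ``state.set_requires_grad_sync(False)``
# (e.g. during gradient-accumulation micro-batches) to adjust reduction behavior.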