# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""MindSpore HSDP cell state"""
16from typing import Optional
17import mindspore as ms
18from mindspore import ops
19import mindspore.mint.distributed as dist
20from hyper_parallel.core.fully_shard.hsdp_state import HSDPState
21from hyper_parallel.core.fully_shard.hsdp_utils import (
22 _get_param_module_infos,
23 FullyShardParamMode,
24 infer_fully_shard_param_mode,
25)
26from hyper_parallel.platform.mindspore.fully_shard.pack_utils import build_rs_plan
27from hyper_parallel.platform.mindspore.fully_shard.param import MindSporeHSDPParamV2
28from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version
29from hyper_parallel.platform.mindspore.fully_shard.param_group import HSDPParamGroup, get_comm_ctx
30from hyper_parallel.platform.mindspore.utils import normalize_runtime_device
31from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy


def _to_dtype_if_needed(
    tensor: ms.Tensor, dtype: Optional[ms.Type]
) -> ms.Tensor:
    """Cast ``tensor`` to ``dtype`` if it differs from the current dtype.

    Args:
        tensor: The input tensor to potentially cast.
        dtype: Target dtype. If None or the same as ``tensor``'s dtype, this is
            a no-op and the input tensor is returned unchanged.
    """
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor
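
# A minimal sketch of the cast-or-no-op behavior (``t`` is a hypothetical
# tensor, for illustration only):
#
#   t = ms.Tensor([1.0], dtype=ms.float16)
#   _to_dtype_if_needed(t, None) is t          # True: no target dtype given
#   _to_dtype_if_needed(t, ms.float16) is t    # True: already that dtype
#   _to_dtype_if_needed(t, ms.float32).dtype   # ms.float32: cast performed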


class MindSporeHSDPStateV2(HSDPState):
    """MindSpore HSDP cell state"""

    # DTensor-compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without materializing an
    # ``_unsharded_param``. Track those async all-reduces separately from the
    # standard unsharded-gradient queues.
    pre_direct_all_reduce_grads = []

    @staticmethod
    def _get_pending_unsharded_grad(hsdp_param):
        """Return the pending unsharded gradient tensor for reduction paths."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return hsdp_param.unsharded_accumulated_grad_data
        return hsdp_param.unsharded_grad_data

    @staticmethod
    def _has_pending_unsharded_grad(hsdp_param):
        """Whether the parameter currently has a gradient waiting for reduction."""
        if hsdp_param.unsharded_accumulated_grad is not None:
            return True
        if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
            return False
        return hsdp_param.unsharded_param.grad is not None

    @staticmethod
    def _get_local_sharded_grad(hsdp_param):
        """Return the local gradient tensor currently stored on ``sharded_param``."""
        grad = hsdp_param.sharded_param.grad
        if grad is None:
            return None
        to_local = getattr(grad, "to_local", None)
        if callable(to_local):
            return to_local()
        return grad
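
    # ``to_local`` is resolved by duck typing: DTensor-like gradients expose it
    # and are unwrapped to their local shard, while plain tensors pass through
    # untouched. A sketch (``DTensorLikeGrad`` is hypothetical, for illustration):
    #
    #   class DTensorLikeGrad:
    #       def to_local(self):
    #           return self._local_shard   # unwrapped by the helper above
    #
    #   plain_grad = ms.Tensor([1.0])      # no ``to_local`` -> returned as-is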

    @staticmethod
    def _synchronize_current_stream_if_needed(need_synchronize: bool) -> None:
        """Synchronize the current device stream after a non-blocking CPU offload."""
        if not need_synchronize:
            return
        ms.runtime.current_stream().synchronize()

    def __init__(self, cell, mesh_info, config, platform, device=None):
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do ReduceScatter/AllReduce for grads.
        self.reduce_grads = True
        # Reshard parameters after backward.
        self.reshard_after_backward = True
        # HSDP requires an AllReduce for grads across the replicate dimension.
        self.requires_all_reduce = True
        # Keep historical AVG behavior for local parameters while DTensor-aware
        # paths default to SUM semantics without extra division.
        self.reduce_op_type = ops.ReduceOp.SUM
        self._need_div = not any(
            getattr(param, "param_mode", FullyShardParamMode.LOCAL_PARAM)
            != FullyShardParamMode.LOCAL_PARAM
            for param in self._iter_managed_params()
        )
        self._ignored_allreduce_works = []
        self._reset_sharded_params = False
        self._init_param_group()
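
    # A worked sketch of the averaging policy (values are illustrative): when
    # every managed parameter is a plain LOCAL_PARAM, ``_need_div`` is True and
    # SUM reductions are followed by a divide, reproducing AVG. A single
    # DTensor-aware parameter (e.g. DTENSOR_UNIFIED) flips the state to pure SUM:
    #
    #   modes = [FullyShardParamMode.LOCAL_PARAM] * 3
    #   need_div = not any(m != FullyShardParamMode.LOCAL_PARAM for m in modes)
    #   # need_div -> True (SUM + divide == historical AVG)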

    def _iter_managed_params(self):
        """Return all fully_shard-managed parameters, including replicate_params."""
        return [*self.hsdp_params, *self.replicate_params]

    @staticmethod
    def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]:
        """Return the reason why ``hsdp_param`` cannot participate in comm_fusion.

        Returns ``None`` when the parameter is supported.
        """
        if not hsdp_param.enable_fsdp_shard:
            return "non-sharded parameters such as replicate_params are not supported"
        if hsdp_param.param_mode not in (
            FullyShardParamMode.LOCAL_PARAM,
            FullyShardParamMode.DTENSOR_UNIFIED,
        ):
            return f"param_mode {hsdp_param.param_mode} is not supported"
        local_shard = getattr(hsdp_param, "_sharded_local_tensor", None)
        if local_shard is None:
            return "missing local shard tensor for comm_fusion plan validation"
        plan_world_size = getattr(hsdp_param, "shard_world_size", None)
        if plan_world_size is None:
            plan_world_size = getattr(hsdp_param, "shard_size", 1)
        try:
            build_rs_plan(hsdp_param, local_shard, plan_world_size)
        except NotImplementedError as exc:
            return str(exc)
        except (AssertionError, ValueError) as exc:
            return f"cannot build comm_fusion pack plan: {exc}"
        return None
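
    # The gate is used as a predicate: a ``None`` reason means "eligible for
    # comm_fusion". A minimal usage sketch (``p`` is a hypothetical managed
    # parameter, for illustration):
    #
    #   reason = MindSporeHSDPStateV2._comm_fusion_unsupported_reason(p)
    #   if reason is not None:
    #       raise NotImplementedError(f"comm_fusion does not support {p}: {reason}")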

    def _init_param_group(self):
        """Initialize the fused parameter group when comm_fusion is enabled."""
        self.param_group = None
        if not self.config.comm_fusion:
            return
        unsupported_param = next(
            (
                hsdp_param
                for hsdp_param in self.hsdp_params
                if self._comm_fusion_unsupported_reason(hsdp_param) is not None
            ),
            None,
        )
        if unsupported_param is not None:
            param_fqn = getattr(unsupported_param, "_param_fqn", "<unknown>")
            reason = self._comm_fusion_unsupported_reason(unsupported_param)
            raise NotImplementedError(
                f"comm_fusion does not support parameter {param_fqn}: {reason}."
            )
        if self.hsdp_params:
            self.param_group = HSDPParamGroup(
                self.hsdp_params,
                self.mesh_info,
                self.device,
                self.mp_policy,
                self.config.comm_fusion_zero_copy,
            )

    def zero_grad(self):
        """Zero the gradients of all managed parameters."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.zero_grad()
        for hsdp_param in self.replicate_params:
            hsdp_param.zero_grad()

    @staticmethod
    def _div_if_needed(x, divisor, need_div: bool):
        """Apply gradient averaging only when the caller-provided policy requires it.

        ``need_div`` may come from the current state or from metadata captured when
        async reduce work was queued, so this helper is safe for both immediate and
        deferred gradient materialization paths.
        """
        if not need_div:
            return
        if divisor == 1:
            return
        x.div_(divisor)
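
    # Worked arithmetic (illustrative): with 4 ranks each holding grad 1.0, a
    # SUM reduce leaves 4.0 on every rank; ``_div_if_needed(g, 4, True)`` turns
    # that into 1.0, i.e. SUM followed by a divide reproduces AVG exactly.
    # With ``need_div=False`` (DTensor-aware SUM semantics) or ``divisor == 1``
    # the tensor is left untouched and no kernel is launched.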

    def _move_states_to_device(self):
        """Move parameters and buffers onto the runtime device."""
        for mod in self.modules:
            for param in mod.get_parameters():
                if getattr(param, "_hsdp_param_initialized", False):
                    continue
                param_device = normalize_runtime_device(param.device)
                if param_device in (self.device, "meta"):
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                if buffer.device in (self.device, "meta"):
                    continue
                buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Init HSDP parameters and replicate parameters for the cell."""
        # All parameters in the module tree(s), deduplicated.
        visited_params = set()
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.parameters_and_names():
                if getattr(param, "_hsdp_param_initialized", False):
                    continue
                if param in ignored_params:
                    continue
                if param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            enable_fsdp_shard = param not in replicate_params
            hsdp_param = MindSporeHSDPParamV2(
                param,
                module_info,
                self.mesh_info,
                shard_placement_fn=self.config.shard_placement_fn,
                mp_policy=self.mp_policy,
                offload_policy=self.offload_policy,
                device=self.device,
                param_mode=param_mode,
                enable_fsdp_shard=enable_fsdp_shard,
            )
            if param in replicate_params:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
            if hsdp_param.is_sharded:
                self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Init mixed-precision dtypes for HSDP parameters and replicate parameters."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        for replicate_param in self.replicate_params:
            replicate_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[MindSporeHSDPParamV2] = [
            p for p in self._iter_managed_params() if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if trainable_params and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects a uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if trainable_params and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects a uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
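
    # Illustrative invariant: the dtype sets above must be singletons. Three
    # trainable params all declared float16 give ``orig_dtypes == {ms.float16}``
    # and pass; mixing in one float32 param gives ``{ms.float16, ms.float32}``
    # and trips the AssertionError.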

    def lazy_init(self):
        """Refresh parameter views and validate runtime state before first execution."""
        if not self._reset_sharded_params:
            for hsdp_param in self.hsdp_params:
                if hsdp_param.is_sharded:
                    hsdp_param.reset_sharded_param()
            self._reset_sharded_params = True
        self._validate_no_meta_params()
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _validate_cpu_offload_params(self):
        """Validate that all parameters are on CPU when the CPU offload policy is enabled."""
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self._iter_managed_params()
            if not str(hsdp_param.sharded_param.device).lower().startswith("cpu")
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                "For example, load a CPU state dict before training. "
                "Found the following parameters on a non-CPU device: "
                f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}"
            )

    def _validate_no_meta_params(self):
        """Validate that all parameters have been materialized from the meta device."""
        param_names_on_meta = [
            hsdp_param._param_fqn
            for hsdp_param in self._iter_managed_params()
            if hsdp_param.sharded_param.device == "meta"
        ]
        if param_names_on_meta:
            raise RuntimeError(
                "HSDP parameters should be materialized from the meta device before training, "
                f"but the following were still on the meta device: {param_names_on_meta}\n"
                "For example, initialize the module weights on a real device before running training."
            )

    def _allreduce_replicate_params(self, async_op=True) -> None:
        """
        DDP-style all-reduce for parameters in ``config.replicate_params``.

        Use the parameter's layout-driven unsharded group so that DTensor-aware
        compatibility and unified modes reduce over the correct axes.
        """
        for param in self.replicate_params:
            if not hasattr(param, "_unsharded_param") or param.unsharded_param is None:
                continue
            if (
                param.unsharded_accumulated_grad is None
                and param.unsharded_param.grad is None
            ):
                continue

            reduced_grad = param.unsharded_accumulated_grad_data
            if reduced_grad is None:
                reduced_grad = param.unsharded_grad_data
            reduced_grad = _to_dtype_if_needed(reduced_grad, self._reduce_dtype)
            reduce_group_info = getattr(param, "unsharded_group_info", None)
            reduce_group = reduce_group_info.group if reduce_group_info is not None else None
            reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1

            if reduce_group is not None and reduce_group_size > 1:
                param.all_reduce_handle = dist.all_reduce(
                    reduced_grad, group=reduce_group, op=self.reduce_op_type, async_op=async_op
                )
                self._ignored_allreduce_works.append((param, reduced_grad, reduce_group_size))
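
    # A sketch of the queued bookkeeping (illustrative): each eligible replicate
    # param contributes one ``(param, reduced_grad, group_size)`` entry, with the
    # matching async handle stored on ``param.all_reduce_handle``. With 2 data
    # replicas and grads [g0, g1], the async all-reduce leaves g0 + g1 on both
    # ranks; ``_finish_ignored_allreduce`` later divides by 2 when averaging.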

    def _finish_ignored_allreduce(self) -> None:
        """
        Wait for the async all-reduce of replicate_params and materialize ``param.grad``.

        For each pending work item, this:
        - waits on the associated handle to complete;
        - divides by the group size when averaging is required;
        - casts ``reduced_grad`` back to ``_orig_dtype`` if needed and assigns
          the final tensor to ``param.grad``.
        """
        if not self._ignored_allreduce_works:
            return

        need_synchronize = False
        for param, reduced_grad, reduce_group_size in self._ignored_allreduce_works:
            if param.all_reduce_handle is not None:
                param.all_reduce_handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, self._need_div)
            need_synchronize = (
                param.apply_reduced_grad(reduced_grad, self._orig_dtype)
                or need_synchronize
            )

        self._synchronize_current_stream_if_needed(need_synchronize)
        self._ignored_allreduce_works.clear()

    def reduce_params(self):
        """Drain pending sharded-parameter reductions and materialize sharded grads."""
        need_synchronize = False
        while HSDPState.pre_reduce_scatter_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = hsdp_param.reduce_scatter_output()
            self._div_if_needed(reduced_grad, hsdp_param.shard_world_size, need_div)
            hsdp_param.clear_reduce_scatter_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while HSDPState.pre_all_reduce_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = hsdp_param.all_reduce_output()
            self._div_if_needed(reduced_grad, hsdp_param.replicate_world_size, need_div)
            hsdp_param.clear_all_reduce_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        while MindSporeHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad, reduce_group_size, need_div = (
                MindSporeHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            )
            if handle is not None:
                handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, need_div)
            if reduced_grad is not target_grad:
                if reduced_grad.dtype != target_grad.dtype:
                    reduced_grad = reduced_grad.to(target_grad.dtype)
                copy_without_bumping_version(target_grad, reduced_grad)
        self._synchronize_current_stream_if_needed(need_synchronize)
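
    # Drain order matters (sketch): reduce-scatter outputs are finalized first,
    # then the HSDP all-reduce outputs, then the direct compat all-reduces. Each
    # queue entry carries the dtype/averaging metadata captured at enqueue time,
    # so a later ``set_reduce_op_type("sum")`` cannot retroactively change how
    # already-queued gradients are finalized.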

    def post_backward_for_comm_fusion(self):
        """Drive the fused gradient-reduction pipeline for sharded params."""
        self.reduce_params()
        comm_ctx = get_comm_ctx()
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type,
                needs_avg_div=self._need_div,
            )
        self._allreduce_replicate_params()

    def _post_backward_without_reduce(self):
        """Finish backward when gradient communication is disabled."""
        if self.reshard_after_backward:
            self.shard()
        for hsdp_param in self._iter_managed_params():
            hsdp_param.to_accumulated_grad_if_needed()

    def _should_run_all_reduce(self, hsdp_param) -> bool:
        """Whether the current parameter should issue an all-reduce in this backward pass."""
        return self.requires_all_reduce and hsdp_param.dp_size > 1

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param):
        """Queue the standard FSDP/HSDP reduction path."""
        hsdp_param.reduce_scatter_grad(
            async_op=True,
            dtype=self._reduce_dtype,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype, self._need_div))
        if not self._should_run_all_reduce(hsdp_param):
            return
        # HSDP: finish the reduce-scatter stage eagerly so its output can be
        # chained into the replicate-group all-reduce below, and dequeue the
        # entry that was just appended.
        reduced_grad = hsdp_param.reduce_scatter_output()
        if (
            HSDPState.pre_reduce_scatter_params
            and HSDPState.pre_reduce_scatter_params[-1][0] is hsdp_param
        ):
            HSDPState.pre_reduce_scatter_params.pop()
        hsdp_param.clear_reduce_scatter_output()
        self._div_if_needed(reduced_grad, hsdp_param.shard_size, self._need_div)
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))
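
    # Worked arithmetic for the two-stage HSDP average (illustrative): with a
    # dp world of 8 ranks arranged as 4 shards x 2 replicas and per-rank grad
    # contributions g_i, the reduce-scatter SUM is divided by the shard size
    # (4) and the follow-up all-reduce SUM by the replicate size (2), so the
    # final shard holds (sum_i g_i) / 8, matching what a single flat
    # all-reduce average over all 8 ranks would produce.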

    def _queue_compat_all_reduce(self, hsdp_param):
        """Queue the compatibility all-reduce path without FSDP sharding."""
        if not self._should_run_all_reduce(hsdp_param):
            return
        hsdp_param.all_reduce_grad(
            grad=self._get_pending_unsharded_grad(hsdp_param),
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))

    def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool:
        """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly."""
        return (
            hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT
            and hsdp_param.enable_fsdp_shard
            and not hsdp_param.is_sharded
            and hsdp_param.shard_size == 1
            and hsdp_param.sharded_param.requires_grad
            and self._should_run_all_reduce(hsdp_param)
            and self._get_local_sharded_grad(hsdp_param) is not None
        )

    def _queue_direct_compat_all_reduce(self, hsdp_param):
        """Queue an all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``."""
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        reduced_grad = _to_dtype_if_needed(grad, self._reduce_dtype)
        reduce_group_info = getattr(hsdp_param, "unsharded_group_info", None)
        reduce_group = reduce_group_info.group if reduce_group_info is not None else None
        reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1
        handle = None
        if reduce_group_size > 1:
            if reduce_group is None:
                raise RuntimeError("Expected a valid unsharded all-reduce group when rank_size > 1")
            handle = dist.all_reduce(
                reduced_grad,
                group=reduce_group,
                op=self.reduce_op_type,
                async_op=True,
            )
        MindSporeHSDPStateV2.pre_direct_all_reduce_grads.append(
            (handle, reduced_grad, grad, reduce_group_size, self._need_div)
        )

    def post_backward(self, *_):
        """Reduce gradients and reshard parameters after the backward pass."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            self._post_backward_without_reduce()
            return
        if not self.comm_fusion:
            self.reduce_params()
            self._allreduce_replicate_params()
            for hsdp_param in self.hsdp_params:
                if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                    if self._can_direct_all_reduce_compat_grad(hsdp_param):
                        self._queue_direct_compat_all_reduce(hsdp_param)
                    continue
                if not hsdp_param.sharded_param.requires_grad:
                    continue
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                if hsdp_param.shard_size > 1:
                    self._queue_reduce_scatter_then_all_reduce(hsdp_param)
                elif self._should_run_all_reduce(hsdp_param):
                    self._queue_compat_all_reduce(hsdp_param)
                else:
                    need_synchronize = hsdp_param.apply_reduced_grad(
                        self._get_pending_unsharded_grad(hsdp_param),
                        self._orig_dtype,
                    )
                    self._synchronize_current_stream_if_needed(need_synchronize)
            self._finish_ignored_allreduce()
        else:
            self.post_backward_for_comm_fusion()
        if self.reshard_after_backward:
            self.shard()
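
    # Dispatch summary (sketch of the non-fused branch above), per parameter:
    #   shard_size > 1            -> reduce-scatter, then all-reduce across replicas
    #   shard_size == 1, dp > 1   -> compat all-reduce (direct when the grad
    #                                never left ``sharded_param``)
    #   otherwise                 -> apply the local grad with no communication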

    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the requires-grad-sync flag that controls gradient synchronization."""
        self.reduce_grads = requires_grad_sync

    def set_reduce_op_type(self, reduce_op_type: str):
        """Set the reduce op type used for gradient reduction."""
        fsdp_support_reduce_op = {
            "sum": ops.ReduceOp.SUM,
            # "avg" is implemented as a SUM reduction followed by a local divide.
            "avg": ops.ReduceOp.SUM,
        }
        reduce_op = reduce_op_type.lower().strip()
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(
                f"Unsupported reduce op type {reduce_op_type}, "
                f"supported types are {list(fsdp_support_reduce_op.keys())}")
        self._need_div = reduce_op == "avg"
        self.reduce_op_type = fsdp_support_reduce_op[reduce_op]
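
    # Usage sketch (``state`` is a hypothetical MindSporeHSDPStateV2 instance):
    #
    #   state.set_reduce_op_type("AVG ")   # normalized to "avg": SUM + divide
    #   state.set_reduce_op_type("sum")    # pure SUM, no divide
    #   state.set_reduce_op_type("max")    # raises ValueError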