# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP cell state"""
from typing import List, Tuple, Union

from hyper_parallel.platform import get_platform
from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
from hyper_parallel.core.fully_shard.hsdp_utils import HSDPConfigV2

platform = get_platform()


class HSDPState:
    """HSDP state for cell"""

    # Record pending per-parameter reduce-scatter/all-reduce work across
    # fully_shard states so later backward hooks/root drains can materialize
    # gradients launched by earlier states.
    pre_reduce_scatter_params = []
    pre_all_reduce_params = []
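
    # Illustrative sketch (an assumption about how these class-level lists are
    # used, not code from this module): a backward hook in an earlier
    # fully_shard state could enqueue its pending gradient communication here,
    # and a later hook or the root-state drain would pop and finalize it.
    #
    #     HSDPState.pre_reduce_scatter_params.append(hsdp_param)    # producer
    #     ...
    #     while HSDPState.pre_reduce_scatter_params:                 # drain
    #         HSDPState.pre_reduce_scatter_params.pop(0).finish_grad_comm()
    #
    # `hsdp_param` and `finish_grad_comm` are hypothetical names used only for
    # illustration.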

    def __init__(self, cell: Union[platform.Module, Tuple[platform.Module, ...]], mesh_info,
                 config: HSDPConfigV2, platform_impl, device=None):
        """
        Initialize HSDPState.

        Args:
            cell (platform.Module or Tuple[platform.Module, ...]): The module(s) whose parameters
                are managed by this state. When a tuple is passed, all modules are
                treated as one FSDP unit.
            mesh_info: Mesh topology for shard/replicate dimensions.
            config (HSDPConfigV2): HSDP configuration (mesh, mp_policy, offload_policy, etc.).
            platform_impl: Platform abstraction layer (Torch or MindSpore).
            device (torch.device, optional): Target device for parameters.
        """
        self.modules = (cell,) if isinstance(cell, platform.Module) else tuple(cell)
        self.cell = self.modules[0]
        self.mesh_info = mesh_info
        self.config = config
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        self.platform = platform_impl
        self.device = device
        self.hsdp_params: List[HSDPParamV2] = []
        self.sharded_hsdp_params: List[HSDPParamV2] = []
        self.replicate_params: List[HSDPParamV2] = []
        # Fused-communication group; assumed to be populated elsewhere when
        # config.comm_fusion is enabled. Initialized to None here so that
        # unshard()/wait_for_unshard() never hit an AttributeError.
        self.param_group = None
        self._move_states_to_device()
        self._init_hsdp_params()
        self.is_shard = True
        self.module_name = None
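
    # Note on the `cell` argument (illustrative sketch; `ConcreteHSDPState`,
    # `block.attention`, and `block.mlp` are hypothetical names): passing a
    # tuple groups several submodules into a single FSDP unit, so they are
    # unsharded and re-sharded together, e.g.
    #
    #     ConcreteHSDPState((block.attention, block.mlp), mesh_info, config, platform)
    #
    # versus one state per submodule when each module is wrapped individually.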

    def _init_hsdp_params(self):
        """init hsdp parameters for cell"""
        raise NotImplementedError("HSDPState subclasses must implement _init_hsdp_params")

    def _move_states_to_device(self):
        """move states to device"""
        raise NotImplementedError("HSDPState subclasses must implement _move_states_to_device")

    def shard(self, shard_replicate: bool = True):
        """change parameters to sharded state"""
        if self.is_shard:
            return

        for param in self.sharded_hsdp_params:
            param.to_sharded()
        if shard_replicate:
            for param in self.replicate_params:
                param.to_sharded()
        self.is_shard = True

    def unshard(self, async_op=False, unshard_replicate: bool = True):
        """change parameters to unsharded state"""
        if not self.is_shard:
            return

        if unshard_replicate:
            for param in self.replicate_params:
                param.unshard(async_op)
        if self.config.comm_fusion and self.param_group is not None:
            self.param_group.unshard(async_op)
        else:
            for param in self.sharded_hsdp_params:
                param.unshard(async_op)
        if not async_op:
            self.wait_for_unshard(unshard_replicate)

    def prefetch(self, unshard_replicate: bool = True):
        """prefetch unsharded parameters"""
        self.unshard(async_op=True, unshard_replicate=unshard_replicate)

    def wait_for_unshard(self, wait_for_replicate: bool = True):
        """wait for all pending unshard operations to complete"""
        if not self.is_shard:
            return
        if wait_for_replicate:
            for param in self.replicate_params:
                param.wait_for_unshard()
        if self.config.comm_fusion and self.param_group is not None:
            self.param_group.wait_for_unshard()
        else:
            for param in self.sharded_hsdp_params:
                param.wait_for_unshard()
        self.is_shard = False
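
    # Illustrative use of the async path (a sketch, not code from this module;
    # `next_state` and `run_current_layer` are hypothetical): launch the
    # unshard for the next state early, overlap it with compute on the current
    # state, then block right before the full parameters are needed.
    #
    #     next_state.prefetch()            # unshard(async_op=True) under the hood
    #     run_current_layer()              # compute to overlap with communication
    #     next_state.wait_for_unshard()    # block until full parameters are ready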

    def _iter_managed_params(self):
        """Return all fully_shard-managed parameters, including replicate_params."""
        return [*self.hsdp_params, *self.replicate_params]
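

# Example lifecycle (illustrative only; `_DenseHSDPState` and its hook bodies
# are hypothetical, not part of this module): a concrete subclass implements
# the two abstract hooks, after which the shard/unshard cycle can be driven.
#
#     class _DenseHSDPState(HSDPState):
#         def _init_hsdp_params(self):
#             # build HSDPParamV2 wrappers from self.modules and fill
#             # self.hsdp_params / self.sharded_hsdp_params / self.replicate_params
#             ...
#
#         def _move_states_to_device(self):
#             # move parameters and buffers of self.modules to self.device
#             ...
#
#     state = _DenseHSDPState(cell, mesh_info, config, platform, device)
#     state.unshard()            # blocking all-gather of full parameters
#     ...                        # forward/backward with unsharded parameters
#     state.shard()              # drop full parameters, keep local shards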