Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/fully_shard/hsdp_state.py: 51%

65 statements

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP cell state"""
from typing import List, Tuple, Union

from hyper_parallel.platform import get_platform
from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
from hyper_parallel.core.fully_shard.hsdp_utils import HSDPConfigV2

platform = get_platform()

class HSDPState:
    """HSDP state for cell"""
    # Record pending per-parameter reduce-scatter/all-reduce work across
    # fully_shard states so later backward hooks/root drains can materialize
    # gradients launched by earlier states.
    pre_reduce_scatter_params = []
    pre_all_reduce_params = []
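    # Illustrative sketch (assumption, not part of this module): a per-parameter
    # backward hook could append its HSDPParamV2 here after launching an async
    # collective, and a root-level drain could later finalize the gradients.
    # `reduce_scatter_grad` and `finalize_grad` are hypothetical method names
    # used only for illustration.
    #
    #     def _post_backward_hook(hsdp_param):
    #         hsdp_param.reduce_scatter_grad(async_op=True)
    #         HSDPState.pre_reduce_scatter_params.append(hsdp_param)
    #
    #     def _drain_pending_grads():
    #         for hsdp_param in HSDPState.pre_reduce_scatter_params:
    #             hsdp_param.finalize_grad()
    #         HSDPState.pre_reduce_scatter_params.clear()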

    def __init__(self, cell: Union[platform.Module, Tuple[platform.Module, ...]], mesh_info,
                 config: HSDPConfigV2, platform_impl, device=None):
        """
        Initialize HSDPState.

        Args:
            cell (platform.Module or Tuple[platform.Module, ...]): The module(s) whose parameters
                are managed by this state. When a tuple is passed, all modules are
                treated as one FSDP unit.
            mesh_info: Mesh topology for shard/replicate dimensions.
            config (HSDPConfigV2): HSDP configuration (mesh, mp_policy, offload_policy, etc.).
            platform_impl: Platform abstraction layer (Torch or MindSpore).
            device (torch.device, optional): Target device for parameters.
        """
        self.modules = (cell,) if isinstance(cell, platform.Module) else tuple(cell)
        self.cell = self.modules[0]
        self.mesh_info = mesh_info
        self.config = config
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        self.platform = platform_impl
        self.device = device
        self.hsdp_params: List[HSDPParamV2] = []
        self.sharded_hsdp_params: List[HSDPParamV2] = []
        self.replicate_params: List[HSDPParamV2] = []
        self._move_states_to_device()
        self._init_hsdp_params()
        self.is_shard = True
        self.module_name = None

    def _init_hsdp_params(self):
        """init hsdp parameters for cell"""
        raise NotImplementedError("HSDPState subclasses must implement _init_hsdp_params")

    def _move_states_to_device(self):
        """move states to device"""
        raise NotImplementedError("HSDPState subclasses must implement _move_states_to_device")
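    # Illustrative sketch (assumption, not part of this module): a subclass is
    # expected to supply the two hooks above. The name `TorchHSDPState`, the
    # HSDPParamV2 constructor arguments, and the per-parameter bucketing below
    # are assumptions for illustration only.
    #
    #     class TorchHSDPState(HSDPState):
    #         def _move_states_to_device(self):
    #             # Move module parameters/buffers onto the target device, if given.
    #             if self.device is not None:
    #                 for module in self.modules:
    #                     module.to(self.device)
    #
    #         def _init_hsdp_params(self):
    #             # Wrap each parameter in an HSDPParamV2 and record it as sharded.
    #             for module in self.modules:
    #                 for param in module.parameters(recurse=False):
    #                     hsdp_param = HSDPParamV2(param, self.mesh_info, self.config)
    #                     self.hsdp_params.append(hsdp_param)
    #                     self.sharded_hsdp_params.append(hsdp_param)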

    def shard(self, shard_replicate: bool = True):
        """change parameters to sharded state"""
        if self.is_shard:
            return

        for param in self.sharded_hsdp_params:
            param.to_sharded()
        if shard_replicate:
            for param in self.replicate_params:
                param.to_sharded()
        self.is_shard = True

    def unshard(self, async_op=False, unshard_replicate: bool = True):
        """change parameters to unsharded state"""
        if not self.is_shard:
            return

        if unshard_replicate:
            for param in self.replicate_params:
                param.unshard(async_op)
        if self.config.comm_fusion and self.param_group is not None:
            self.param_group.unshard(async_op)
        else:
            for param in self.sharded_hsdp_params:
                param.unshard(async_op)
        if not async_op:
            self.wait_for_unshard(unshard_replicate)

    def prefetch(self, unshard_replicate: bool = True):
        """prefetch unsharded parameters"""
        self.unshard(async_op=True, unshard_replicate=unshard_replicate)

    def wait_for_unshard(self, wait_for_replicate: bool = True):
        """wait for all pending unshard operations to complete"""
        if not self.is_shard:
            return
        if wait_for_replicate:
            for param in self.replicate_params:
                param.wait_for_unshard()
        if self.config.comm_fusion and self.param_group is not None:
            self.param_group.wait_for_unshard()
        else:
            for param in self.sharded_hsdp_params:
                param.wait_for_unshard()
        self.is_shard = False
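    # Illustrative usage sketch (assumption): a typical step unshards the current
    # state, prefetches the next state's parameters so the all-gather overlaps
    # with compute, then reshards once the unsharded parameters are no longer
    # needed. `current_state`, `next_state`, and `inputs` are placeholders.
    #
    #     current_state.unshard()           # blocking all-gather
    #     next_state.prefetch()             # async all-gather overlapped with compute
    #     output = current_state.cell(*inputs)
    #     current_state.shard()             # free unsharded parameters
    #     next_state.wait_for_unshard()     # ensure prefetch finished before use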

    def _iter_managed_params(self):
        """Return all fully_shard-managed parameters, including replicate_params."""
        return [*self.hsdp_params, *self.replicate_params]
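# Illustrative sketch (assumption): callers such as a gradient-zeroing routine
# can walk every managed parameter, replicated ones included. The attribute
# `sharded_param` is a hypothetical handle to the local shard, used only for
# illustration.
#
#     def zero_managed_grads(state: HSDPState):
#         for hsdp_param in state._iter_managed_params():
#             if hsdp_param.sharded_param.grad is not None:
#                 hsdp_param.sharded_param.grad = None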