Coverage for hyper_parallel / core / fully_shard / hsdp_param.py: 55%

60 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP parameter""" 

16 

17 

class HSDPParamV2:
    """Abstract interface for an HSDP-managed parameter.

    This class only declares the contract: every method and property below
    raises ``NotImplementedError``, and a concrete subclass is expected to
    override the entire surface. Instantiating this class directly therefore
    raises as well.
    """

    def __init__(
        self,
        param,
        module_info,
        mesh_info,
        post_forward_mesh_info,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        threshold,
    ):
        """Subclass hook: wrap *param* with HSDP sharding state."""
        raise NotImplementedError("HSDP param subclasses must implement __init__")

    def _init_sharded_param(self, param, shard_placement_fn):
        """Subclass hook: add and initialize the sharded parameter."""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param")

    def _init_sharded_post_forward_param_metadata(self, param):
        """Subclass hook: record metadata for the post-forward sharded param."""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_post_forward_param_metadata")

    def init_dtype_attrs(self, mp_policy):
        """Subclass hook: derive dtype attributes from the mixed-precision policy."""
        raise NotImplementedError("HSDP param subclasses must implement init_dtype_attrs")

    def _init_extensions(self):
        """Subclass hook: set up any optional extensions."""
        raise NotImplementedError("HSDP param subclasses must implement _init_extensions")

    def init_all_gather_outputs(
        self,
        all_gather_input_numels,
        all_gather_input_dtypes,
        world_size,
        device,
        force_recreate=False,
    ):
        """Subclass hook: allocate/describe the all-gather output buffers."""
        raise NotImplementedError("HSDP param subclasses must implement init_all_gather_outputs")

    def init_unsharded_param(self):
        """Subclass hook: build the unsharded view of the parameter."""
        raise NotImplementedError("HSDP param subclasses must implement init_unsharded_param")

    def to_sharded(self):
        """Subclass hook: switch the parameter to its sharded state."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded")

    def to_sharded_post_forward(self):
        """Subclass hook: switch to the post-forward sharded state."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_post_forward")

    def to_unsharded(self):
        """Subclass hook: switch the parameter to its unsharded state."""
        raise NotImplementedError("HSDP param subclasses must implement to_unsharded")

    def to_sharded_dtensor(self, tensor):
        """Subclass hook: view *tensor* as a sharded DTensor."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_dtensor")

    def to_sharded_post_forward_dtensor(self, tensor):
        """Subclass hook: view *tensor* as a post-forward sharded DTensor."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_post_forward_dtensor")

    def to_accumulated_grad_if_needed(self):
        """Subclass hook: move the gradient into accumulated storage when required."""
        raise NotImplementedError("HSDP param subclasses must implement to_accumulated_grad_if_needed")

    def accumulate_unsharded_grad_if_needed(self):
        """Subclass hook: fold the unsharded gradient into the accumulator when required."""
        raise NotImplementedError("HSDP param subclasses must implement accumulate_unsharded_grad_if_needed")

    def alloc_all_gather_outputs(self):
        """Subclass hook: allocate storage for the all-gather outputs."""
        raise NotImplementedError("HSDP param subclasses must implement alloc_all_gather_outputs")

    def free_unsharded_param(self):
        """Subclass hook: release the unsharded parameter's storage."""
        raise NotImplementedError("HSDP param subclasses must implement free_unsharded_param")

    @property
    def all_gather_inputs(self):
        """Subclass hook: inputs to feed into the all-gather."""
        raise NotImplementedError("HSDP param subclasses must implement all_gather_inputs")

    @property
    def unsharded_param(self):
        """Subclass hook: the unsharded parameter object."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_param")

    @property
    def unsharded_grad_data(self):
        """Subclass hook: raw data of the unsharded gradient."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_grad_data")

    @property
    def unsharded_accumulated_grad_data(self):
        """Subclass hook: raw data of the accumulated unsharded gradient."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_accumulated_grad_data")

    @property
    def _sharded_local_tensor(self):
        """Subclass hook: the local shard's underlying tensor."""
        raise NotImplementedError("HSDP param subclasses must implement _sharded_local_tensor")

    def _get_unsharded_param_data(self, async_op=False):
        """Subclass hook: fetch unsharded parameter data, optionally asynchronously."""
        raise NotImplementedError("HSDP param subclasses must implement _get_unsharded_param_data")

    def unshard(self, async_op=False):
        """Subclass hook: start (or perform) unsharding of the parameter."""
        raise NotImplementedError("HSDP param subclasses must implement unshard")

    def wait_for_unshard(self):
        """Subclass hook: block until a pending unshard completes."""
        raise NotImplementedError("HSDP param subclasses must implement wait_for_unshard")

    def shard(self):
        """Subclass hook: reshard the parameter."""
        raise NotImplementedError("HSDP param subclasses must implement shard")

    def reduce_scatter_grad(self):
        """Subclass hook: reduce-scatter the gradient across shards."""
        raise NotImplementedError("HSDP param subclasses must implement reduce_scatter_grad")

    def all_reduce_grad(self):
        """Subclass hook: all-reduce the gradient."""
        raise NotImplementedError("HSDP param subclasses must implement all_reduce_grad")