Coverage for hyper_parallel / core / fully_shard / hsdp_param_buffer.py: 12%

84 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP parameter buffer""" 

16 

17 

class HSDPParamBuffer:
    """
    HSDP parameter buffer.

    Packs the local (sharded) data of several HSDP parameters contiguously
    into one flat buffer. All of them can then be all-gathered with a single
    collective, and the gathered result is scattered back into each parameter.

    The methods below imply this call order: ``add_param`` for each parameter,
    then ``init``, then ``to_unsharded`` (or ``prefetch_unsharded`` followed by
    ``wait_for_unsharded``), then ``to_sharded``.
    """

    def __init__(self, config, init_hsdp_param, platform):
        """
        Args:
            config: Stored as ``self.config``; not read by the methods of
                this class.
            init_hsdp_param: Representative HSDP parameter. The buffer takes
                its shard size, rank, dtype, device and sharded group info
                from it. It is NOT added to the buffer; call ``add_param``
                for that.
            platform: Backend object providing ``new_tensor``,
                ``all_gather_into_tensor`` and ``update_param_data``.
        """
        self.config = config
        self.platform = platform
        self.shard_size = init_hsdp_param.shard_size
        # This rank's HSDP rank modulo the shard size (not read by the
        # methods below).
        self.local_rank = init_hsdp_param.hsdp_rank % init_hsdp_param.shard_size
        self.dtype = init_hsdp_param.param.dtype
        self.sharded_group_info = init_hsdp_param.sharded_group_info
        self.device = init_hsdp_param.param.device
        self.hsdp_params = []
        # Total number of elements in the flat sharded buffer; set by init().
        self.numel = 0
        self.sharded_param_buffer = None
        # Last gathered buffer, viewed as (shard_size, numel). The attribute
        # name (sic: "unshared") is kept for compatibility with external code.
        self.unshared_param_buffer = None
        # Handle and output of an in-flight asynchronous all-gather.
        self.prefetch_handle = None
        self.prefetch_data = None

    def init(self):
        """init buffer

        Assign each registered parameter a contiguous ``[start, end)`` range
        of the flat buffer, in registration order. Then allocate the buffer
        and copy every local shard into its range.
        """
        offset = 0
        for hsdp_param in self.hsdp_params:
            start_index = offset
            end_index = start_index + hsdp_param.sharded_param.numel()
            hsdp_param.param_buffer_start_index = start_index
            hsdp_param.param_buffer_end_index = end_index
            offset = end_index
        self.numel = offset
        self._init_param_buffer()

    def _init_param_buffer(self):
        """init params buffer

        Allocate the flat sharded buffer and copy each parameter's local
        shard into its range. Each parameter's ``sharded_param_view`` is set
        to a view of that range, reshaped to the shard's own shape.
        """
        self.sharded_param_buffer = self.platform.new_tensor((self.numel,), self.dtype, self.device)
        for hsdp_param in self.hsdp_params:
            start_index = hsdp_param.param_buffer_start_index
            end_index = hsdp_param.param_buffer_end_index
            self.sharded_param_buffer[start_index:end_index] = hsdp_param.sharded_param.reshape(-1)
            local_shape = hsdp_param.sharded_param.shape
            hsdp_param.sharded_param_view = self.sharded_param_buffer[start_index:end_index].view(local_shape)

    def add_param(self, hsdp_param):
        """add param to buffer (takes effect at the next ``init`` call)"""
        self.hsdp_params.append(hsdp_param)

    def to_sharded(self):
        """change parameter to sharded state

        Copy each parameter's slice of the flat buffer back into its
        ``sharded_param``, switch the parameter to its sharded state, and drop
        the reference to the gathered buffer.
        """
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param[:] = hsdp_param.sharded_param_view[:]
            hsdp_param.to_sharded()
        self.unshared_param_buffer = None

    def _update_data_view(self):
        """Refresh the flat buffer from the parameters' current (sharded) data."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param_view[:] = hsdp_param.param[:]

    def prefetch_unsharded(self):
        """prefetch unsharded params with async all gather

        Start an asynchronous all-gather of the flat buffer and remember its
        handle and output. Does nothing if a prefetch is already pending.
        Call ``wait_for_unsharded`` or ``to_unsharded`` to consume the result.
        """
        if self.prefetch_handle is not None:
            return
        self._update_data_view()

        unshared_param_buffer, handle = self.platform.all_gather_into_tensor(self.sharded_param_buffer,
                                                                             self.sharded_group_info,
                                                                             async_op=True)
        self.prefetch_data = unshared_param_buffer
        self.prefetch_handle = handle

    def to_unsharded(self, async_op=False):
        """change parameter to unsharded state

        If a prefetch is pending, wait for it and use its output. Otherwise
        all-gather the flat buffer now.

        Args:
            async_op: Only used when no prefetch is pending. When True, the
                all-gather is started asynchronously and its handle and output
                are stored as a pending prefetch; the parameters are not
                updated until ``wait_for_unsharded`` or another
                ``to_unsharded`` call consumes it.
        """
        if self.prefetch_handle is not None:
            gathered = self._take_prefetched()
        else:
            self._update_data_view()
            gathered, handle = self.platform.all_gather_into_tensor(self.sharded_param_buffer,
                                                                    self.sharded_group_info,
                                                                    async_op=async_op)
            if async_op:
                self.prefetch_handle = handle
                self.prefetch_data = gathered
                return
        self._apply_unsharded(gathered)

    def wait_for_unsharded(self):
        """wait for unsharded buffer

        Finish a pending prefetch and install its gathered data into the
        parameters. Does nothing if no prefetch is pending.
        """
        if self.prefetch_handle is None:
            # Nothing in flight. Returning early avoids touching an undefined
            # gathered buffer.
            return
        self._apply_unsharded(self._take_prefetched())

    def _take_prefetched(self):
        """Wait on the pending all-gather, clear the prefetch state, and return its output."""
        self.prefetch_handle.wait()
        gathered = self.prefetch_data
        self.prefetch_handle = None
        self.prefetch_data = None
        return gathered

    def _apply_unsharded(self, gathered):
        """Scatter a gathered flat buffer into every parameter's full data.

        NOTE(review): assumes ``all_gather_into_tensor`` concatenates the
        per-rank flat buffers in rank order, so that row ``r`` of the
        ``(shard_size, numel)`` view holds rank ``r``'s data. Confirm against
        the platform implementation.
        """
        unshared_param_buffer = gathered.view((self.shard_size, -1))
        for hsdp_param in self.hsdp_params:
            start_index = hsdp_param.param_buffer_start_index
            end_index = hsdp_param.param_buffer_end_index
            # The column slice takes this parameter's chunk from every rank's
            # row. It is non-contiguous whenever shard_size > 1 and the buffer
            # holds more than this parameter, so a strict ``view`` can fail
            # (e.g. in PyTorch). ``reshape`` returns a view when possible and
            # copies otherwise.
            unshared_param_data = unshared_param_buffer[:, start_index:end_index].reshape(hsdp_param.param_shape)
            self.platform.update_param_data(hsdp_param.param, unshared_param_data)
        self.unshared_param_buffer = unshared_param_buffer