Coverage for hyper_parallel / core / hsdp / hsdp_param_buffer.py: 85%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP parameter buffer""" 

16 

17 

class HSDPParamBuffer:
    """
    HSDP parameter buffer.

    Packs the sharded data of every registered HSDP parameter into one flat
    tensor so that all of them are gathered with a single
    ``platform.all_gather_into_tensor`` call, and hands each parameter a slice
    of the gathered result.

    The ``unshared_*`` attribute names (sic, "unsharded") are kept unchanged
    because they are visible attributes of the instance.
    """

    def __init__(self, config, init_hsdp_param, platform):
        """
        Create an empty buffer described by ``init_hsdp_param``.

        Args:
            config: stored on the instance; not read by this class itself.
            init_hsdp_param: HSDP parameter supplying ``shard_size``,
                ``hsdp_rank``, ``sharded_group_info`` and ``param`` (whose
                dtype and device become the buffer's dtype and device).
            platform: backend object providing ``new_tensor``,
                ``all_gather_into_tensor`` and ``update_param_data``.
        """
        self.config = config
        self.platform = platform
        self.shard_size = init_hsdp_param.shard_size
        self.local_rank = init_hsdp_param.hsdp_rank % init_hsdp_param.shard_size
        self.dtype = init_hsdp_param.param.dtype
        self.sharded_group_info = init_hsdp_param.sharded_group_info
        self.device = init_hsdp_param.param.device
        self.hsdp_params = []
        self.numel = 0
        self.sharded_param_buffer = None
        self.unshared_param_buffer = None
        # State of an in-flight prefetch started by prefetch_unsharded().
        self.prefetch_handle = None
        self.prefetch_data = None

    def init(self):
        """init buffer: assign each parameter its [start, end) range, then allocate"""
        self.numel = 0
        for hsdp_param in self.hsdp_params:
            start_index = self.numel
            end_index = start_index + hsdp_param.sharded_param.numel()
            hsdp_param.param_buffer_start_index = start_index
            hsdp_param.param_buffer_end_index = end_index
            self.numel = end_index
        self._init_param_buffer()

    def _init_param_buffer(self):
        """init params buffer: allocate the flat tensor and copy every shard into it"""
        self.sharded_param_buffer = self.platform.new_tensor((self.numel,), self.dtype, self.device)
        for hsdp_param in self.hsdp_params:
            start_index = hsdp_param.param_buffer_start_index
            end_index = hsdp_param.param_buffer_end_index
            self.sharded_param_buffer[start_index:end_index] = hsdp_param.sharded_param.reshape(-1)
            # Expose the parameter's slice of the flat buffer in the shard's own shape.
            local_shape = hsdp_param.sharded_param.shape
            hsdp_param.sharded_param_view = self.sharded_param_buffer[start_index:end_index].view(local_shape)

    def add_param(self, hsdp_param):
        """add param to buffer (call ``init`` afterwards to lay out the buffer)"""
        self.hsdp_params.append(hsdp_param)

    def to_sharded(self):
        """change parameter to sharded state"""
        for hsdp_param in self.hsdp_params:
            # Copy the buffer slice back into the parameter's own shard first.
            hsdp_param.sharded_param[:] = hsdp_param.sharded_param_view[:]
            hsdp_param.to_sharded()
        # Drop the reference to the gathered buffer.
        self.unshared_param_buffer = None

    def _update_data_view(self):
        """copy the current data of every param into its slice of the sharded buffer"""
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param_view[:] = hsdp_param.param[:]

    def _start_all_gather(self):
        """
        Refresh the sharded buffer and launch an asynchronous all gather.

        Returns:
            The ``(gathered_buffer, handle)`` pair returned by the platform.
            The buffer must not be read before ``handle.wait()`` has returned.
        """
        self._update_data_view()
        return self.platform.all_gather_into_tensor(self.sharded_param_buffer,
                                                    self.sharded_group_info,
                                                    async_op=True)

    def prefetch_unsharded(self):
        """prefetch unsharded params with async all gather"""
        if self.prefetch_handle is not None:
            # A prefetch is already in flight; do not launch a second one.
            return
        self.prefetch_data, self.prefetch_handle = self._start_all_gather()

    def to_unsharded(self):
        """
        change parameter to unsharded state

        Uses the result of a pending prefetch when there is one; otherwise
        launches the all gather here. In both cases the communication handle is
        waited on before the gathered data is handed to the parameters.
        """
        if self.prefetch_handle is not None:
            self.prefetch_handle.wait()
            unshared_param_buffer = self.prefetch_data
            self.prefetch_handle = None
            self.prefetch_data = None
        else:
            unshared_param_buffer, handle = self._start_all_gather()
            # The gather was launched asynchronously, so it has to be waited on
            # before the buffer is sliced and bound to the parameters.
            if handle is not None:
                handle.wait()
        # One row per member of the shard group.
        unshared_param_buffer = unshared_param_buffer.view((self.shard_size, -1))
        for hsdp_param in self.hsdp_params:
            start_index = hsdp_param.param_buffer_start_index
            end_index = hsdp_param.param_buffer_end_index
            unshared_param_data = unshared_param_buffer[:, start_index:end_index]
            # NOTE(review): with several params this column slice is strided;
            # whether ``view`` accepts every ``param_shape`` depends on the
            # platform's tensor type - confirm.
            unshared_param_data = unshared_param_data.view(hsdp_param.param_shape)
            self.platform.update_param_data(hsdp_param.param, unshared_param_data)
        self.unshared_param_buffer = unshared_param_buffer