Coverage for hyper_parallel / core / hsdp / hsdp_param_buffer.py: 85%
66 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP parameter buffer"""
class HSDPParamBuffer:
    """
    HSDP parameter buffer.

    Packs the local shards of a group of HSDP parameters into one flat tensor
    so that all of them can be unsharded with a single all-gather instead of
    one collective per parameter.

    Typical usage: register parameters with ``add_param``, call ``init`` once
    to lay out and fill the flat buffer, then switch between states with
    ``to_unsharded`` / ``to_sharded``.  ``prefetch_unsharded`` may be called
    ahead of ``to_unsharded`` to launch the all-gather early.
    """

    def __init__(self, config, init_hsdp_param, platform):
        """
        Create an empty buffer whose settings are taken from one parameter.

        Args:
            config: HSDP configuration object; stored but not read by this class.
            init_hsdp_param: HSDP parameter that supplies the shard size, rank,
                dtype, device and communication group info for the buffer.
                It is NOT added to the buffer; use ``add_param`` for that.
            platform: backend adapter providing ``new_tensor``,
                ``all_gather_into_tensor`` and ``update_param_data``.
        """
        self.config = config
        self.platform = platform
        self.shard_size = init_hsdp_param.shard_size
        # Position of this rank inside its shard group.
        self.local_rank = init_hsdp_param.hsdp_rank % init_hsdp_param.shard_size
        self.dtype = init_hsdp_param.param.dtype
        self.sharded_group_info = init_hsdp_param.sharded_group_info
        self.device = init_hsdp_param.param.device
        self.hsdp_params = []
        # Total number of elements of all registered local shards.
        self.numel = 0
        # Flat tensor holding every local shard back to back (built by init()).
        self.sharded_param_buffer = None
        # Gathered buffer that backs the unsharded parameters.  The attribute
        # name (sic: "unshared") is kept as-is for backward compatibility.
        self.unshared_param_buffer = None
        # Pending asynchronous all-gather started by prefetch_unsharded().
        self.prefetch_handle = None
        self.prefetch_data = None

    def init(self):
        """
        Assign every registered parameter its slice of the flat buffer, then
        allocate and fill the buffer.

        Each parameter receives ``param_buffer_start_index`` and
        ``param_buffer_end_index`` (end exclusive), laid out in registration
        order.
        """
        self.numel = 0
        for hsdp_param in self.hsdp_params:
            start_index = self.numel
            end_index = start_index + hsdp_param.sharded_param.numel()
            hsdp_param.param_buffer_start_index = start_index
            hsdp_param.param_buffer_end_index = end_index
            self.numel = end_index
        self._init_param_buffer()

    def _init_param_buffer(self):
        """
        Allocate the flat sharded buffer and copy each local shard into it.

        Every parameter also gets ``sharded_param_view``: a view of its slice
        of the buffer, reshaped to the shape of its local shard.
        """
        self.sharded_param_buffer = self.platform.new_tensor((self.numel,), self.dtype, self.device)
        for hsdp_param in self.hsdp_params:
            start_index = hsdp_param.param_buffer_start_index
            end_index = hsdp_param.param_buffer_end_index
            self.sharded_param_buffer[start_index:end_index] = hsdp_param.sharded_param.reshape(-1)
            local_shape = hsdp_param.sharded_param.shape
            hsdp_param.sharded_param_view = self.sharded_param_buffer[start_index:end_index].view(local_shape)

    def add_param(self, hsdp_param):
        """Register a parameter; takes effect on the next call to ``init``."""
        self.hsdp_params.append(hsdp_param)

    def to_sharded(self):
        """
        Change every parameter to sharded state.

        The content of each buffer view is copied back into the parameter's
        own ``sharded_param`` before the parameter itself is switched, and the
        reference to the gathered buffer is dropped.
        """
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param[:] = hsdp_param.sharded_param_view[:]
            hsdp_param.to_sharded()
        self.unshared_param_buffer = None

    def _update_data_view(self):
        """Copy the current value of every parameter into its buffer view."""
        # NOTE(review): this presumes each param currently holds its local
        # shard (same shape as sharded_param_view) - confirm with callers.
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param_view[:] = hsdp_param.param[:]

    def prefetch_unsharded(self):
        """
        Start an asynchronous all-gather of the sharded buffer.

        Does nothing when a prefetch is already pending.  The result is
        consumed by the next call to ``to_unsharded``.
        """
        if self.prefetch_handle is not None:
            return
        self._update_data_view()
        gathered, handle = self.platform.all_gather_into_tensor(
            self.sharded_param_buffer,
            self.sharded_group_info,
            async_op=True,
        )
        self.prefetch_data = gathered
        self.prefetch_handle = handle

    def to_unsharded(self):
        """
        Change every parameter to unsharded state.

        Uses the result of a pending prefetch when there is one; otherwise the
        all-gather is issued here.  In both cases the communication is waited
        on before the gathered data is handed to the parameters.
        """
        if self.prefetch_handle is not None:
            self.prefetch_handle.wait()
            unshared_param_buffer = self.prefetch_data
            self.prefetch_handle = None
            self.prefetch_data = None
        else:
            self._update_data_view()
            unshared_param_buffer, handle = self.platform.all_gather_into_tensor(
                self.sharded_param_buffer,
                self.sharded_group_info,
                async_op=True,
            )
            # The collective was launched asynchronously: it must complete
            # before the output buffer is viewed, sliced or installed below.
            if handle is not None:
                handle.wait()

        # View as (shard_size, elements-per-rank).
        # NOTE(review): assumes the platform concatenates rank shards in rank
        # order, so each row is one rank's flat buffer - confirm.
        unshared_param_buffer = unshared_param_buffer.view((self.shard_size, -1))
        for hsdp_param in self.hsdp_params:
            start_index = hsdp_param.param_buffer_start_index
            end_index = hsdp_param.param_buffer_end_index
            # Same column range in every row == this parameter's shard on
            # every rank.
            # NOTE(review): when the buffer holds several parameters this
            # slice is strided; verify the platform's view() accepts that.
            unshared_param_data = unshared_param_buffer[:, start_index:end_index]
            unshared_param_data = unshared_param_data.view(hsdp_param.param_shape)
            self.platform.update_param_data(hsdp_param.param, unshared_param_data)
        self.unshared_param_buffer = unshared_param_buffer