Coverage for hyper_parallel / core / fully_shard / hsdp_param_buffer.py: 12%
84 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP parameter buffer"""
class HSDPParamBuffer:
    """
    HSDP parameter buffer.

    Packs the local shards of every registered HSDP parameter into one flat
    tensor (``sharded_param_buffer``) so that a single
    ``platform.all_gather_into_tensor`` call gathers all of them at once.
    The gathered result is then sliced per parameter and handed back to the
    platform through ``platform.update_param_data``.

    Typical call order: ``add_param`` (repeatedly) -> ``init`` ->
    ``to_unsharded`` / ``prefetch_unsharded`` + ``wait_for_unsharded`` ->
    ``to_sharded``.
    """

    def __init__(self, config, init_hsdp_param, platform):
        """
        Create an empty buffer whose metadata is taken from one HSDP parameter.

        Args:
            config: Stored as ``self.config``; not otherwise used by this class.
            init_hsdp_param: HSDP parameter supplying ``shard_size``,
                ``hsdp_rank``, ``sharded_group_info`` and the dtype/device of
                its ``param``. It is NOT registered automatically; call
                ``add_param`` for that.
            platform: Backend object providing ``new_tensor``,
                ``all_gather_into_tensor`` and ``update_param_data``.
        """
        self.config = config
        self.platform = platform
        self.shard_size = init_hsdp_param.shard_size
        # hsdp_rank modulo shard_size (presumably the rank's position inside
        # its shard group -- verify against the caller).
        self.local_rank = init_hsdp_param.hsdp_rank % init_hsdp_param.shard_size
        self.dtype = init_hsdp_param.param.dtype
        self.sharded_group_info = init_hsdp_param.sharded_group_info
        self.device = init_hsdp_param.param.device
        # Parameters packed into this buffer, in registration order.
        self.hsdp_params = []
        # Total number of elements across all registered local shards.
        self.numel = 0
        # Flat tensor holding every local shard back to back.
        self.sharded_param_buffer = None
        # Gathered buffer viewed as (shard_size, -1); None while sharded.
        self.unshared_param_buffer = None
        # State of an in-flight asynchronous all-gather, if any.
        self.prefetch_handle = None
        self.prefetch_data = None

    def init(self):
        """Assign each registered parameter its slice and build the flat buffer."""
        offset = 0
        for hsdp_param in self.hsdp_params:
            hsdp_param.param_buffer_start_index = offset
            offset += hsdp_param.sharded_param.numel()
            hsdp_param.param_buffer_end_index = offset
        self.numel = offset
        self._init_param_buffer()

    def _init_param_buffer(self):
        """Allocate the flat buffer, copy every local shard in, and attach views."""
        self.sharded_param_buffer = self.platform.new_tensor((self.numel,), self.dtype, self.device)
        for hsdp_param in self.hsdp_params:
            start = hsdp_param.param_buffer_start_index
            end = hsdp_param.param_buffer_end_index
            local_shard = hsdp_param.sharded_param
            self.sharded_param_buffer[start:end] = local_shard.reshape(-1)
            # The view is taken from the buffer slice, so later writes through
            # ``sharded_param_view`` land in ``sharded_param_buffer``.
            hsdp_param.sharded_param_view = self.sharded_param_buffer[start:end].view(local_shard.shape)

    def add_param(self, hsdp_param):
        """Register ``hsdp_param`` with this buffer (takes effect on ``init``)."""
        self.hsdp_params.append(hsdp_param)

    def to_sharded(self):
        """Copy buffer contents back into each local shard and drop the gathered buffer."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param[:] = hsdp_param.sharded_param_view[:]
            hsdp_param.to_sharded()
        self.unshared_param_buffer = None

    def _update_data_view(self):
        """Refresh the flat buffer from each parameter's current ``param`` data."""
        for hsdp_param in self.hsdp_params:
            hsdp_param.sharded_param_view[:] = hsdp_param.param[:]

    def _take_prefetched(self):
        """Wait for the in-flight all-gather, clear its state, and return its output."""
        self.prefetch_handle.wait()
        gathered = self.prefetch_data
        self.prefetch_handle = None
        self.prefetch_data = None
        return gathered

    def _apply_unsharded_buffer(self, gathered):
        """Slice the gathered buffer per parameter and push each slice to the platform."""
        gathered = gathered.view((self.shard_size, -1))
        for hsdp_param in self.hsdp_params:
            start = hsdp_param.param_buffer_start_index
            end = hsdp_param.param_buffer_end_index
            # NOTE(review): a column slice of the 2-D buffer is generally not
            # contiguous; whether ``.view(param_shape)`` is legal for buffers
            # holding more than one parameter depends on the platform's tensor
            # semantics -- confirm.
            full_data = gathered[:, start:end].view(hsdp_param.param_shape)
            self.platform.update_param_data(hsdp_param.param, full_data)
        self.unshared_param_buffer = gathered

    def prefetch_unsharded(self):
        """Start an asynchronous all-gather unless one is already in flight."""
        if self.prefetch_handle is not None:
            return
        self._update_data_view()
        gathered, handle = self.platform.all_gather_into_tensor(
            self.sharded_param_buffer, self.sharded_group_info, async_op=True
        )
        self.prefetch_data = gathered
        self.prefetch_handle = handle

    def to_unsharded(self, async_op=False):
        """
        Switch the registered parameters to their gathered (unsharded) data.

        If an all-gather is already in flight, it is waited on and its result
        is applied; ``async_op`` is ignored in that case. Otherwise a new
        all-gather is issued.

        Args:
            async_op: When True and nothing is in flight, only launch the
                gather, remember its handle/output, and return without
                touching the parameters. Complete it later with
                ``wait_for_unsharded`` or another ``to_unsharded`` call.
        """
        if self.prefetch_handle is not None:
            gathered = self._take_prefetched()
        else:
            self._update_data_view()
            gathered, handle = self.platform.all_gather_into_tensor(
                self.sharded_param_buffer, self.sharded_group_info, async_op=async_op
            )
            if async_op:
                self.prefetch_handle = handle
                self.prefetch_data = gathered
                return
        self._apply_unsharded_buffer(gathered)

    def wait_for_unsharded(self):
        """Finish a pending asynchronous gather; do nothing when none is pending."""
        if self.prefetch_handle is None:
            # Nothing in flight: return instead of touching an unbound result.
            return
        self._apply_unsharded_buffer(self._take_prefetched())