Coverage for hyper_parallel / core / fully_shard / hsdp_state.py: 37%
131 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP cell state"""
16from typing import List
17from hyper_parallel.core.fully_shard.hsdp_param import HSDPParamV2
18from hyper_parallel.core.fully_shard.hsdp_param_buffer import HSDPParamBuffer
19from hyper_parallel.core.fully_shard.hsdp_grad_buffer import HSDPGradBuffer
20from hyper_parallel.core.fully_shard.hsdp_utils import HSDPConfigV2
22class HSDPState:
23 """HSDP state for cell"""
    def __init__(self, cell, mesh_info, config: HSDPConfigV2, platform, device=None):
        """Build per-cell HSDP state and eagerly initialize params and buffers.

        Args:
            cell: the network cell whose parameters this state manages.
            mesh_info: device-mesh description used to derive sharding groups.
            config: HSDP configuration (mixed-precision / offload policies,
                comm_fusion and grad_fusion flags, bucket_size).
            platform: platform abstraction (e.g. dtype sizes, device ops).
            device: optional target device; None keeps the current placement.
        """
        self.cell = cell
        self.mesh_info = mesh_info
        self.config = config
        # Cache frequently used sub-policies for direct access.
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        self.platform = platform
        self.device = device
        # All managed params vs. the subset that actually lives sharded.
        self.hsdp_params: List[HSDPParamV2] = []
        self.sharded_hsdp_params: List[HSDPParamV2] = []
        self.param_buffers = []
        self.grad_buffers = []
        # Order matters: subclass hooks populate hsdp_params /
        # sharded_hsdp_params, which the buffer builders below consume.
        self._move_states_to_device()
        self._init_hsdp_params()
        self._init_param_buffers()
        self._init_grad_buffers()
        # Parameters start in the sharded state.
        self.is_shard = True
    def _init_hsdp_params(self):
        """Populate self.hsdp_params / self.sharded_hsdp_params for the cell.

        Abstract hook invoked from __init__; subclasses must override.
        """
        raise NotImplementedError("HSDPState subclasses must implement _init_hsdp_params")
    def _move_states_to_device(self):
        """Move the cell's states onto the target device.

        Abstract hook invoked from __init__; subclasses must override.
        """
        raise NotImplementedError("HSDPState subclasses must implement _move_states_to_device")
50 def _init_param_buffers(self):
51 """init param buffers"""
52 if not self.config.comm_fusion:
53 return
55 group_to_buffer = {}
56 for hsdp_param in self.sharded_hsdp_params:
57 param_buffer_key = hsdp_param.sharded_group_info.group_name + str(hsdp_param.param.dtype)
58 if param_buffer_key not in group_to_buffer:
59 buffer = HSDPParamBuffer(self.config, hsdp_param, self.platform)
60 buffer.add_param(hsdp_param)
61 group_to_buffer[param_buffer_key] = buffer
62 else:
63 buffer = group_to_buffer[param_buffer_key]
64 buffer.add_param(hsdp_param)
65 self.param_buffers = list(group_to_buffer.values())
66 for buffer in self.param_buffers:
67 buffer.init()
69 def _init_grad_buffers(self):
70 """init grad buffers"""
71 if not self.config.grad_fusion:
72 return
74 bucket_infos = {}
76 def get_bucket_key(buffer_key, hsdp_param):
77 if self.config.bucket_size < 0:
78 return buffer_key
79 param_size = hsdp_param.param.numel() * self.platform.get_param_type_size(hsdp_param.param)
80 bucket_info = bucket_infos.get(buffer_key, None)
81 if bucket_info is None:
82 bucket_info = [0, param_size]
83 bucket_infos[buffer_key] = bucket_info
84 else:
85 bucket_size = bucket_info[1] + param_size
86 if bucket_size > self.config.bucket_size:
87 bucket_info[0] = bucket_info[0] + 1
88 bucket_info[1] = param_size
89 return buffer_key + '_' + str(bucket_info[0])
91 self.param_to_buffer = {}
92 group_to_buffer = {}
93 for hsdp_param in self.hsdp_params:
94 if not hsdp_param.param.requires_grad:
95 continue
96 buffer_key = hsdp_param.sharded_group_info.group_name + hsdp_param.unsharded_group_info.group_name \
97 + str(hsdp_param.param.dtype)
98 bucket_key = get_bucket_key(buffer_key, hsdp_param)
99 if bucket_key not in group_to_buffer:
100 buffer = HSDPGradBuffer(self.config, hsdp_param, self.platform)
101 group_to_buffer[bucket_key] = buffer
102 else:
103 buffer = group_to_buffer[bucket_key]
104 buffer.add_param(hsdp_param)
105 self.param_to_buffer[hsdp_param] = buffer
106 self.grad_buffers = list(group_to_buffer.values())
107 for buffer in self.grad_buffers:
108 buffer.init()
110 def shard(self):
111 """change parameters to sharded state"""
112 if self.is_shard:
113 return
115 if self.config.comm_fusion:
116 for buffer in self.param_buffers:
117 buffer.to_sharded()
118 else:
119 for param in self.sharded_hsdp_params:
120 param.to_sharded()
121 self.is_shard = True
123 def unshard(self, async_op=False):
124 """change parameters to unsharded state"""
125 if not self.is_shard:
126 return
128 if self.config.comm_fusion:
129 raise ValueError(f"comm_fusion is deprecated, check config.comm_fusion.")
130 for buffer in self.param_buffers:
131 buffer.to_unsharded(async_op=async_op)
132 else:
133 for param in self.sharded_hsdp_params:
134 param.unshard()
135 param.wait_for_unshard()
136 self.is_shard = False
138 def prefetch(self):
139 """prefetch unsharded parameters"""
140 if not self.is_shard:
141 return
142 if self.config.comm_fusion:
143 for buffer in self.param_buffers:
144 buffer.prefetch_unsharded()
145 else:
146 for param in self.sharded_hsdp_params:
147 param.unshard(async_op=True)
150 def wait_for_unsharded(self):
151 """wait for all unsharded parameters"""
152 if not self.is_shard:
153 return
154 if self.config.comm_fusion:
155 for buffer in self.param_buffers:
156 if buffer.prefetch_handle is not None:
157 buffer.wait_for_unsharded()
158 else:
159 for param in self.sharded_hsdp_params:
160 if param.prefetch_handle is not None:
161 param.wait_for_unsharded()
163 def zero_grads(self):
164 """zero grad or grad buffer"""
165 if not self.config.grad_fusion:
166 for hsdp_param in self.hsdp_params:
167 if not hsdp_param.param.requires_grad:
168 continue
169 hsdp_param.zero_acc_grad()
170 else:
171 for buffer in self.grad_buffers:
172 buffer.zero_grads()
174 def set_grad_ready(self, hsdp_param):
175 """set grad ready"""
176 if not self.config.grad_fusion:
177 return
178 buffer = self.param_to_buffer.get(hsdp_param, None)
179 if buffer is not None:
180 buffer.set_grad_ready()
181 else:
182 raise ValueError(f"param {hsdp_param.param} is not register to buffer.")
184 def set_requires_grad_sync(self, requires_grad_sync):
185 """set requires grad sync flag to control gradient sync."""
186 if not self.config.grad_fusion:
187 return
188 for buffer in self.grad_buffers:
189 buffer.set_requires_grad_sync(requires_grad_sync)