Coverage for hyper_parallel / core / hsdp / hsdp_state.py: 79%
111 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP cell state"""
16from hyper_parallel.core.hsdp.hsdp_param_buffer import HSDPParamBuffer
17from hyper_parallel.core.hsdp.hsdp_grad_buffer import HSDPGradBuffer
20class HSDPState:
21 """HSDP state for cell"""
22 def __init__(self, cell, config, platform):
23 self.cell = cell
24 self.config = config
25 self.platform = platform
26 self.hsdp_params = []
27 self.sharded_hsdp_params = []
28 self.param_buffers = []
29 self.grad_buffers = []
30 self._init_hsdp_params()
31 self._init_param_buffers()
32 self._init_grad_buffers()
33 self.is_shard = True
35 def _init_hsdp_params(self):
36 """init hsdp parameters for cell"""
37 raise NotImplementedError("HSDPState subclasses must implement _init_hsdp_params")
39 def _init_param_buffers(self):
40 """init param buffers"""
41 if not self.config.comm_fusion:
42 return
44 group_to_buffer = {}
45 for hsdp_param in self.sharded_hsdp_params:
46 param_buffer_key = hsdp_param.sharded_group_info.group_name + str(hsdp_param.param.dtype)
47 if param_buffer_key not in group_to_buffer:
48 buffer = HSDPParamBuffer(self.config, hsdp_param, self.platform)
49 buffer.add_param(hsdp_param)
50 group_to_buffer[param_buffer_key] = buffer
51 else:
52 buffer = group_to_buffer[param_buffer_key]
53 buffer.add_param(hsdp_param)
54 self.param_buffers = list(group_to_buffer.values())
55 for buffer in self.param_buffers:
56 buffer.init()
58 def _init_grad_buffers(self):
59 """init grad buffers"""
60 if not self.config.grad_fusion:
61 return
63 bucket_infos = {}
64 def get_bucket_key(buffer_key, hsdp_param):
65 if self.config.bucket_size < 0:
66 return buffer_key
67 param_size = hsdp_param.param.numel() * self.platform.get_param_type_size(hsdp_param.param)
68 bucket_info = bucket_infos.get(buffer_key, None)
69 if bucket_info is None:
70 bucket_info = [0, param_size]
71 bucket_infos[buffer_key] = bucket_info
72 else:
73 bucket_size = bucket_info[1] + param_size
74 if bucket_size > self.config.bucket_size:
75 bucket_info[0] = bucket_info[0] + 1
76 bucket_info[1] = param_size
77 return buffer_key + '_' + str(bucket_info[0])
79 self.param_to_buffer = {}
80 group_to_buffer = {}
81 for hsdp_param in self.hsdp_params:
82 if not hsdp_param.param.requires_grad:
83 continue
84 buffer_key = hsdp_param.sharded_group_info.group_name + hsdp_param.unsharded_group_info.group_name \
85 + str(hsdp_param.param.dtype)
86 bucket_key = get_bucket_key(buffer_key, hsdp_param)
87 if bucket_key not in group_to_buffer:
88 buffer = HSDPGradBuffer(self.config, hsdp_param, self.platform)
89 group_to_buffer[bucket_key] = buffer
90 else:
91 buffer = group_to_buffer[bucket_key]
92 buffer.add_param(hsdp_param)
93 self.param_to_buffer[hsdp_param] = buffer
94 self.grad_buffers = list(group_to_buffer.values())
95 for buffer in self.grad_buffers:
96 buffer.init()
98 def shard(self):
99 """change parameters to sharded state"""
100 if self.is_shard:
101 return
103 if self.config.comm_fusion:
104 for buffer in self.param_buffers:
105 buffer.to_sharded()
106 else:
107 for param in self.sharded_hsdp_params:
108 param.to_sharded()
109 self.is_shard = True
111 def unshard(self):
112 """change parameters to unsharded state"""
113 if not self.is_shard:
114 return
116 if self.config.comm_fusion:
117 for buffer in self.param_buffers:
118 buffer.to_unsharded()
119 else:
120 for param in self.sharded_hsdp_params:
121 param.to_unsharded()
122 self.is_shard = False
124 def prefetch(self):
125 """prefetch unsharded parameters"""
126 if not self.is_shard:
127 return
128 if self.config.comm_fusion:
129 for buffer in self.param_buffers:
130 buffer.prefetch_unsharded()
131 else:
132 for param in self.sharded_hsdp_params:
133 param.prefetch_unsharded()
135 def zero_grads(self):
136 """zero grad or grad buffer"""
137 if not self.config.grad_fusion:
138 for hsdp_param in self.hsdp_params:
139 if not hsdp_param.param.requires_grad:
140 continue
141 hsdp_param.zero_acc_grad()
142 else:
143 for buffer in self.grad_buffers:
144 buffer.zero_grads()
146 def set_grad_ready(self, hsdp_param):
147 """set grad ready"""
148 if not self.config.grad_fusion:
149 return
150 buffer = self.param_to_buffer.get(hsdp_param, None)
151 if buffer is not None:
152 buffer.set_grad_ready()
153 else:
154 raise ValueError(f"param {hsdp_param.param} is not register to buffer.")
156 def set_requires_grad_sync(self, requires_grad_sync):
157 """set requires grad sync flag to control gradient sync."""
158 if not self.config.grad_fusion:
159 return
160 for buffer in self.grad_buffers:
161 buffer.set_requires_grad_sync(requires_grad_sync)