Coverage for hyper_parallel / core / hsdp / api.py: 77%
101 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""hybrid shard data parallel interface"""
16from typing import Optional, Any
17from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
18from hyper_parallel.platform.platform import PlatformType
19from hyper_parallel.platform import get_platform
platform = get_platform()

# Cache of dynamically generated HSDP wrapper classes, keyed by the original
# cell/module class. Reused by _extend_cell_with_hsdp_interface so repeated
# hsdp() calls on instances of the same class share one wrapper subclass.
origin_class_to_extend_class = {}
# Maps the user-facing optimizer_level strings accepted by hsdp() to the
# internal OptimizerLevel enum values.
optimizer_level_map = {
    "level1": OptimizerLevel.SHARD_OPT,
    "level2": OptimizerLevel.SHARD_OPT_GRAD,
    "level3": OptimizerLevel.SHARD_OPT_GRAD_PARAM,
}
class HSDPCell:
    """
    The hsdp block of neural networks with hsdp interface.

    Supported Platforms:
        ``MindSpore`` ``torch``
    """

    def _ensure_hsdp_initialized(self):
        """Raise ValueError unless hsdp() has already attached a scheduler to this cell."""
        if not hasattr(self, "hsdp_scheduler"):
            raise ValueError("call hsdp interface first.")

    @staticmethod
    def _validate_hsdp_cell_list(hsdp_cell_list):
        """Raise ValueError unless `hsdp_cell_list` is a tuple/list containing only HSDPCell."""
        if not isinstance(hsdp_cell_list, (tuple, list)):
            raise ValueError("hsdp_cell_list must be HSDPCell list")
        for cell in hsdp_cell_list:
            if not isinstance(cell, HSDPCell):
                raise ValueError(f"hsdp_cell_list must be HSDPCell list but got {type(cell)} in list.")

    # pylint: disable=C0415
    def hsdp_init(self, platform_type, cell, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                  use_eager_hook, grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
        """Create the platform-specific HSDP scheduler and attach it as ``self.hsdp_scheduler``."""
        if platform_type == PlatformType.MINDSPORE:
            from hyper_parallel.platform.mindspore.hsdp.scheduler import MindSporeHSDPScheduler
            scheduler_class = MindSporeHSDPScheduler
        else:
            from hyper_parallel.platform.torch.hsdp.scheduler import TorchHSDPScheduler
            scheduler_class = TorchHSDPScheduler
        # NOTE(review): the scheduler is passed grad_scale BEFORE use_eager_hook,
        # the reverse of this method's own parameter order -- presumably matches the
        # scheduler constructors' signatures; confirm against both scheduler classes.
        self.hsdp_scheduler = scheduler_class(cell,
                                              shard_size,
                                              threshold,
                                              optimizer_level,
                                              enable_grad_accumulation,
                                              grad_scale,
                                              use_eager_hook,
                                              reduce_dtype,
                                              comm_async,
                                              comm_fusion,
                                              bucket_size)

    def set_requires_grad_sync(self, requires_grad_sync):
        r"""
        set requires grad sync flag.

        Args:
            requires_grad_sync(bool): requires_grad_sync is used to control gradient sync process.

        Raises:
            ValueError: If `requires_grad_sync` is not bool, or hsdp() was never applied.
        """
        if not isinstance(requires_grad_sync, bool):
            raise ValueError(f"requires_grad_sync must be bool but got {requires_grad_sync}.")
        self._ensure_hsdp_initialized()
        # Propagate the flag to every HSDP-wrapped sub-cell, including self.
        for _, cell in platform.get_cells_and_names(self):
            if isinstance(cell, HSDPCell):
                cell.hsdp_scheduler.set_requires_grad_sync(requires_grad_sync)

    def zero_grads(self):
        """zero accumulation grads of every HSDP sub-cell"""
        self._ensure_hsdp_initialized()
        for _, cell in platform.get_cells_and_names(self):
            if isinstance(cell, HSDPCell):
                cell.hsdp_scheduler.zero_grads()

    def set_forward_prefetch_cells(self, hsdp_cell_list):
        """set forward prefetch cell list to prefetch all gather for unsharded parameters"""
        self._validate_hsdp_cell_list(hsdp_cell_list)
        self._ensure_hsdp_initialized()
        self.hsdp_scheduler.set_forward_prefetch_cells(hsdp_cell_list)

    def set_backward_prefetch_cells(self, hsdp_cell_list):
        """set backward prefetch cell list to prefetch all gather for unsharded parameters"""
        self._validate_hsdp_cell_list(hsdp_cell_list)
        self._ensure_hsdp_initialized()
        self.hsdp_scheduler.set_backward_prefetch_cells(hsdp_cell_list)
def _extend_cell_with_hsdp_interface(cell):
    """Rebind `cell`'s class to a dynamically built subclass that mixes in HSDPCell.

    The generated wrapper class is cached in `origin_class_to_extend_class`, so
    all instances of the same original class share a single wrapper type.
    """
    base = cell.__class__
    wrapper = origin_class_to_extend_class.get(base)
    if wrapper is None:
        wrapper = type(f"HSDP{base.__name__}", (HSDPCell, base), {})
        origin_class_to_extend_class[base] = wrapper
    cell.__class__ = wrapper
# pylint: disable=C0415
def _check_cell_valid(platform_type, cell):
    """Ensure `cell` is the correct network container type for the active platform."""
    if platform_type == PlatformType.MINDSPORE:
        from mindspore.nn.cell import Cell
        if not isinstance(cell, Cell):
            raise ValueError(f"cell's type must be nn.cell but got {type(cell)}.")
        return
    from torch.nn import Module
    if not isinstance(cell, Module):
        raise ValueError(f"cell's type must be nn.Module but got {type(cell)}.")
# pylint: disable=C0415
def _check_hsdp_input_valid(platform_type, cell, shard_size, threshold, optimizer_level, enable_grad_accumulation,
                            use_eager_hook, grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size):
    """Validate every hsdp() argument, raising ValueError on the first invalid one.

    Fix: the shard_size and bucket_size error messages now state the actually
    accepted domains (the -1 sentinel is allowed; bucket_size also allows 0),
    matching both the checks below and hsdp()'s documented Raises section.
    """
    _check_cell_valid(platform_type, cell)
    # shard_size: positive int, or -1 meaning "use the data parallel group size".
    if not isinstance(shard_size, int) or (shard_size <= 0 and shard_size != -1):
        raise ValueError(f"shard_size must be a positive integer or -1, but got {shard_size}.")
    if not isinstance(threshold, int) or threshold < 0:
        raise ValueError(f"threshold must be a positive integer or 0, but got {threshold}.")
    if optimizer_level not in ["level1", "level2", "level3"]:
        raise ValueError(f"Optimizer level should in ['level1', 'level2', 'level3'], but got {optimizer_level}.")
    if not isinstance(enable_grad_accumulation, bool):
        raise ValueError(f"enable_grad_accumulation must be bool but got {enable_grad_accumulation}.")
    if not isinstance(grad_scale, float):
        raise ValueError(f"grad_scale must be float but got {grad_scale}.")
    if not isinstance(use_eager_hook, bool):
        raise ValueError(f"use_eager_hook must be bool but got {use_eager_hook}.")
    # reduce_dtype is optional; when given it must be the platform's dtype type.
    if platform_type == PlatformType.MINDSPORE:
        from mindspore._c_expression.typing import Type
        if reduce_dtype is not None and not isinstance(reduce_dtype, Type):
            raise ValueError(f"reduce_dtype must be mindspore.dtype but got {reduce_dtype}.")
    else:
        import torch
        if reduce_dtype is not None and not isinstance(reduce_dtype, torch.dtype):
            raise ValueError(f"reduce_dtype must be torch.dtype but got {reduce_dtype}.")
    if not isinstance(comm_async, bool):
        raise ValueError(f"comm_async must be bool but got {comm_async}.")
    if not isinstance(comm_fusion, bool):
        raise ValueError(f"comm_fusion must be bool but got {comm_fusion}.")
    # bucket_size: positive int, 0 (no fusion, one buffer per gradient), or -1 (single buffer).
    if not isinstance(bucket_size, int) or (bucket_size < 0 and bucket_size != -1):
        raise ValueError(f"bucket_size must be a positive integer, 0 or -1, but got {bucket_size}.")
def hsdp(
    cell,
    shard_size: Optional[int] = -1,
    threshold: Optional[int] = 64,
    optimizer_level: Optional[str] = "level1",
    enable_grad_accumulation: Optional[bool] = False,
    use_eager_hook: Optional[bool] = True,
    grad_scale: Optional[float] = 1.0,
    reduce_dtype: Optional[Any] = None,
    comm_async: Optional[bool] = False,
    comm_fusion: Optional[bool] = False,
    bucket_size: Optional[int] = -1
):
    r"""
    apply hybrid sharded data parallel.

    Args:
        cell(Cell|Module): The cell to apply hsdp.
        shard_size (int, optional): Set the optimizer weight shard group size if you want to specific the
            maximum group size across devices. The numerical range can be (0, device_num] or -1. Default value
            is -1, which means the optimizer weight shard group size will be
            the data parallel group of each parameter.
        threshold (int, optional): Set the threshold of parallel optimizer. When parallel optimizer is
            enabled, parameters with size smaller than this threshold will not be
            sharded across the devices. Parameter size = shape[0] \* ... \*
            shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
        optimizer_level (str, optional): optimizer_level configuration is used to specify
            the splitting level for optimizer sharding. It is important to note that the implementation
            of optimizer sharding in static graph is inconsistent with dynamic graph like megatron,
            but the memory optimization effect is the same.
            It must be one of [ ``level1``, ``level2``, ``level3`` ]. Default: ``level1``.

            - level1:
              Splitting is performed on weights and optimizer state.
            - level2:
              Splitting is performed on weights, optimizer state, and gradients.
            - level3:
              Splitting is performed on weights, optimizer state,
              gradients, additionally, before the backward pass, the weights are further applied with
              allgather communication to release the memory used by the forward pass allgather.
        enable_grad_accumulation (bool, optional): enable gradient accumulation. When gradient accumulation is
            enable, gradient synchronization should be explicitly called by `set_requires_grad_sync` interface.
        use_eager_hook (bool, optional): Controls whether to enable eager hook behavior. Default: ``True``.

            - Set to `True` for **eager mode** (both MindSpore and PyTorch).
            - Set to `False` for **MindSpore Graph mode** (static graph).
        grad_scale (float, optional): gradient will scale with grad_scale.
        reduce_dtype (dtype, optional): gradient reduce dtype (`mindspore.dtype` on MindSpore,
            `torch.dtype` on PyTorch). Default value is None, which means gradient
            will be reduced with its origin dtype.
        comm_async (bool, optional): reduce gradient with async communication op for communication overlap.
            When comm_async is enable, ``hsdp_sync_stream`` should be called before using generated
            gradient. Default value is False, which means gradient will be reduced with sync communication op.
        comm_fusion (bool, optional): fuse forward parameter allgathers and backward gradient reducescatters or
            allreduces into buffers communication to reduce the number of communication op. ``bucket_size`` will
            further control the size of backward gradient reduce buffer size.
            Default value is False, which means communication op is not fused and will run one by one.
        bucket_size (int, optional): bucket_size is used to control the size of comm fusion buffer. Unit: KB.
            Default value is -1, which means gradient will be fused into a buffer. When value is 0, which means
            gradients will not be fused, each gradient acts as a buffer.

    Raises:
        ValueError: If the `cell` is not a cell.
        ValueError: If the `shard_size` is not a positive integer or -1.
        ValueError: If `threshold` is not a positive integer or 0.
        ValueError: If `optimizer_level` is not one of the [ ``level1``, ``level2``, ``level3`` ].
        ValueError: If `enable_grad_accumulation` is not bool.
        ValueError: If `use_eager_hook` is not bool.
        ValueError: If `grad_scale` is not float.
        ValueError: If `reduce_dtype` is not mindspore.dtype (MindSpore) or torch.dtype (PyTorch).
        ValueError: If `comm_async` is not bool.
        ValueError: If `comm_fusion` is not bool.
        ValueError: If the `bucket_size` is not a positive integer, 0 or -1.
    """
    platform_type = platform.platform_type
    _check_hsdp_input_valid(
        platform_type,
        cell,
        shard_size,
        threshold,
        optimizer_level,
        enable_grad_accumulation,
        use_eager_hook,
        grad_scale,
        reduce_dtype,
        comm_async,
        comm_fusion,
        bucket_size
    )
    # Translate the user-facing level string into the internal enum.
    optimizer_level = optimizer_level_map.get(optimizer_level)
    _extend_cell_with_hsdp_interface(cell)
    # threshold and bucket_size are user-specified in KB; the scheduler expects bytes.
    cell.hsdp_init(
        platform_type,
        cell,
        shard_size,
        threshold * 1024,
        optimizer_level,
        enable_grad_accumulation,
        use_eager_hook,
        grad_scale,
        reduce_dtype,
        comm_async,
        comm_fusion,
        bucket_size * 1024
    )
    return cell
def hsdp_sync_stream():
    """Block until all outstanding hsdp gradient communication handles complete."""
    if platform is not None:
        platform.wait_grad_handle()