Coverage for hyper_parallel/platform/mindspore/hsdp/state.py: 97% (36 statements)
Report generated by coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore HSDP cell state"""
16from mindspore.common.api import _no_grad
17from mindspore import jit_class
18from hyper_parallel.core.hsdp.hsdp_state import HSDPState
19from hyper_parallel.platform.mindspore.hsdp.param import MindSporeHSDPParam
@jit_class
class MindSporeHSDPState(HSDPState):
    """Per-cell HSDP state for the MindSpore backend.

    Discovers the parameters of the managed cell and wraps them as
    ``MindSporeHSDPParam`` objects, then exposes the shard/unshard/prefetch
    lifecycle operations of the base ``HSDPState``, each executed under
    ``_no_grad()`` so the state transitions are excluded from autograd.
    """

    def _init_hsdp_params(self):
        """Wrap each parameter of the cell (and all sub-cells) as an HSDP param.

        Populates ``self.hsdp_params`` with every wrapped parameter and
        ``self.sharded_hsdp_params`` with the subset reporting ``sharded``.
        """
        for _, sub_cell in self.cell.cells_and_names():
            # Direct access to the cell's private parameter dict is required
            # to get the (name, Parameter) pairs owned by this sub-cell.
            for name, param in sub_cell._params.items():  # pylint: disable=W0212
                if hasattr(param, "has_hsdp_param"):
                    # Already wrapped — presumably a parameter shared across
                    # sub-cells; skip to avoid double registration.
                    continue
                wrapped = MindSporeHSDPParam(sub_cell, name, param, self.config, self.platform)
                param.has_hsdp_param = True
                self.hsdp_params.append(wrapped)
                if wrapped.sharded:
                    self.sharded_hsdp_params.append(wrapped)

    @_no_grad()
    def shard(self):
        """Transition parameters into their sharded state (no autograd)."""
        super().shard()

    @_no_grad()
    def unshard(self):
        """Transition parameters into their unsharded state (no autograd)."""
        super().unshard()

    @_no_grad()
    def prefetch(self):
        """Prefetch the unsharded parameters ahead of use (no autograd)."""
        super().prefetch()

    @_no_grad()
    def zero_grads(self):
        """Reset gradients / gradient buffers to zero (no autograd)."""
        super().zero_grads()

    @_no_grad()
    def set_grad_ready(self, hsdp_param):
        """Mark the gradient of ``hsdp_param`` as ready (no autograd)."""
        super().set_grad_ready(hsdp_param)

    @_no_grad()
    def set_requires_grad_sync(self, requires_grad_sync):
        """Toggle the flag controlling whether gradients are synchronized."""
        super().set_requires_grad_sync(requires_grad_sync)