Coverage for hyper_parallel / platform / mindspore / hsdp / param.py: 100%
29 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore HSDP parameter"""
16from mindspore import ops, jit_class
17from mindspore.common.parameter import Parameter
18from mindspore.common.initializer import initializer
19from hyper_parallel.core.dtensor import DTensor
20from hyper_parallel.core.hsdp.hsdp_param import HSDPParam
@jit_class
class MindSporeHSDPParam(HSDPParam):
    """
    MindSpore HSDP parameter.

    Backs an HSDP-managed parameter with a per-rank shard along dimension 0.
    """

    def _init_sharded_param(self):
        """Create this rank's shard of the parameter and register it.

        Two cases, depending on whether the parameter already holds data:
        - materialized parameter: slice the local data along dim 0 and keep
          only this rank's piece;
        - lazily-initialized parameter: shrink the recorded init shape so
          only the shard is ever materialized, and record the global slice
          index for deferred initialization.
        """
        shard_index = self.hsdp_rank % self.shard_size
        if not self.param.has_init:
            # Parameter already carries data: carve out this rank's slice.
            full_data = self.platform.get_param_local_data(self.param)
            split_len = full_data.shape[0] // self.shard_size
            # '+ 0' forces a fresh tensor (a copy of the selected slice).
            shard = ops.split(full_data, split_len)[shard_index] + 0
            self.platform.update_param_data(self.param, shard)
            # TODO: refactor, mindspore jit ast does not need this
            self.sharded_param = Parameter(shard,
                                           name="sharded_" + self.param.name,
                                           requires_grad=self.param.requires_grad)
        else:
            # Parameter is lazily initialized: adjust the init shape so each
            # rank materializes only its own shard, and remember which global
            # slice this rank owns.
            global_slice_index = self.tp_rank * self.shard_size + shard_index
            init_mode = self.param.init_mode
            is_dtensor = isinstance(init_mode, DTensor)
            base_shape = init_mode.local_shape if is_dtensor else init_mode.shape
            shard_shape = list(base_shape)
            shard_shape[0] = shard_shape[0] // self.shard_size
            if is_dtensor:
                # 'self.param.to_local()' and 'self.param.init_mode.to_local()' is same object.
                init_mode.to_local().shape = shard_shape
            else:
                # 'self.param' and 'self.param.init_mode' is not same object. set 'self.param.shape' manually.
                init_mode.shape = shard_shape
                self.param.shape = shard_shape
            self.param.hsdp_init_index = global_slice_index
            # TODO: refactor, mindspore jit ast does not need this
            self.sharded_param = Parameter(initializer("zeros", shard_shape, self.param.dtype),
                                           name="sharded_" + self.param.name,
                                           requires_grad=self.param.requires_grad)

    def _init_unsharded_param(self):
        """No-op on MindSpore: the unsharded view needs no extra setup here."""
        return