Coverage for hyper_parallel / platform / mindspore / parameter_init.py: 91%
32 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Parameter init"""
def init_parameters(cell, stage_index=0):
    r"""
    Materialize the data of every lazily-initialized parameter of a cell.

    Each parameter returned by ``cell.get_parameters(expand=True)`` whose
    ``has_init`` flag is set gets its data produced through ``init_data`` of
    its initializer; afterwards ``init_mode`` and ``init`` are cleared and the
    data is stored with ``set_data``. Parameters without ``has_init`` are
    skipped.

    Args:
        cell(Cell): The cell to init parameters.
        stage_index(int): Offset added to the slice index handed to
            ``init_data`` whenever a slice index is available. Default: ``0``.

    Returns:
        Cell, the same ``cell`` object that was passed in.

    Raises:
        ValueError: If the `cell` is not a cell.
        ValueError: If `stage_index` is not an int.
    """
    import mindspore as ms  # pylint: disable=import-outside-toplevel
    from mindspore.nn.cell import Cell  # pylint: disable=import-outside-toplevel
    from mindspore.parallel._tensor import _get_slice_index  # pylint: disable=import-outside-toplevel
    from hyper_parallel import DTensor  # pylint: disable=import-outside-toplevel

    if not isinstance(cell, Cell):
        raise ValueError(f"cell's type must be Cell but got {type(cell)}.")
    if not isinstance(stage_index, int):
        raise ValueError(f"stage_index's type must be int but got {type(stage_index)}.")

    for param in cell.get_parameters(expand=True):
        # Only parameters that still carry a pending initializer need work.
        if not param.has_init:
            continue
        is_dtensor = isinstance(param, DTensor)

        # An explicit hsdp_init_index on the parameter wins; otherwise the
        # index is derived from the DTensor layout when one is present.
        if hasattr(param, "hsdp_init_index"):
            slice_index = param.hsdp_init_index
        elif is_dtensor and param.layout is not None:
            layout = param.layout
            slice_index = _get_slice_index(layout.mesh_shape, layout.tensor_map, None)
        else:
            slice_index = None

        if is_dtensor:
            # DTensor parameters are initialized with their local shape, using
            # the local view of the initializer.
            target_shape = param.local_shape
            initializer = param.init_mode.to_local()
            # NOTE(review): nesting of this unwrap under the DTensor branch was
            # reconstructed from a listing without indentation -- confirm.
            if isinstance(initializer, ms.Parameter):
                initializer = initializer.init_mode
        else:
            target_shape = param.shape
            initializer = param.init_mode

        init_kwargs = {"shape": target_shape}
        if slice_index is not None:
            init_kwargs["slice_index"] = int(slice_index) + stage_index
        materialized = initializer.init_data(**init_kwargs)

        # Drop the pending-initializer markers before installing the data.
        param.init_mode = None
        param.init = None
        param.set_data(materialized)

    return cell