Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/parameter_init.py: 3%
32 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Parameter init"""
def init_parameters(cell, stage_index=0):
    r"""
    Materialize the data of every parameter in `cell` that still has a pending init.

    Each parameter yielded by ``cell.get_parameters(expand=True)`` whose
    ``has_init`` flag is set gets its data produced by calling ``init_data`` on
    its initializer; the result is stored with ``set_data`` and the parameter's
    ``init_mode`` / ``init`` attributes are reset to ``None``. Parameters whose
    ``has_init`` is falsy are left untouched.

    Args:
        cell(Cell): The cell whose parameters are initialized.
        stage_index(int): Offset added to a parameter's slice index before it is
            passed to ``init_data``. Default: ``0``.

    Returns:
        The same `cell` object that was passed in.

    Raises:
        ValueError: If `cell` is not a Cell.
        ValueError: If `stage_index` is not an int.
    """
    # All framework imports happen at call time (hence the pylint suppressions).
    import mindspore as ms  # pylint: disable=import-outside-toplevel
    from mindspore.nn.cell import Cell  # pylint: disable=import-outside-toplevel
    from mindspore.parallel._tensor import _get_slice_index  # pylint: disable=import-outside-toplevel
    from hyper_parallel import DTensor  # pylint: disable=import-outside-toplevel

    if not isinstance(cell, Cell):
        raise ValueError(f"cell's type must be Cell but got {type(cell)}.")
    if not isinstance(stage_index, int):
        raise ValueError(f"stage_index's type must be int but got {type(stage_index)}.")

    for weight in cell.get_parameters(expand=True):
        is_dtensor = isinstance(weight, DTensor)
        if not weight.has_init:
            continue

        # Pick the slice index: an explicit `hsdp_init_index` attribute wins;
        # otherwise a DTensor with a layout derives it from that layout.
        if hasattr(weight, "hsdp_init_index"):
            slice_idx = weight.hsdp_init_index
        elif is_dtensor and weight.layout is not None:
            layout = weight.layout
            slice_idx = _get_slice_index(layout.mesh_shape, layout.tensor_map, None)
        else:
            slice_idx = None

        # Pick the shape and the initializer used to generate the data.
        if is_dtensor:
            target_shape = weight.local_shape
            initializer = weight.init_mode.to_local()
            # The local view may itself be a Parameter; use its own initializer then.
            if isinstance(initializer, ms.Parameter):
                initializer = initializer.init_mode
        else:
            target_shape = weight.shape
            initializer = weight.init_mode

        # `slice_index` is only forwarded when one was determined above.
        init_kwargs = {"shape": target_shape}
        if slice_idx is not None:
            init_kwargs["slice_index"] = int(slice_idx) + stage_index
        generated = initializer.init_data(**init_kwargs)

        # Clear the pending-init markers before installing the generated data.
        weight.init_mode = None
        weight.init = None
        weight.set_data(generated)

    return cell