Coverage for hyper_parallel / platform / mindspore / parameter_init.py: 91%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Parameter init""" 

16 

def init_parameters(cell, stage_index=0):
    r"""
    Materialize the data of every parameter in ``cell`` whose ``has_init`` is set.

    For each such parameter the initializer stored in ``init_mode`` is asked to
    produce the data via ``init_data``, the parameter's ``init_mode`` and
    ``init`` attributes are reset to ``None``, and the produced data is
    installed with ``set_data``. Parameters whose ``has_init`` is falsy are
    left untouched.

    Args:
        cell(Cell): The cell whose parameters are initialized.
        stage_index(int): Offset added to the slice index passed to
            ``init_data`` when a slice index is available. Default: ``0``.

    Returns:
        Cell, the same ``cell`` object that was passed in.

    Raises:
        ValueError: If ``cell`` is not a ``Cell``.
        ValueError: If ``stage_index`` is not an ``int``.
    """
    import mindspore as ms  # pylint: disable=import-outside-toplevel
    from mindspore.nn.cell import Cell  # pylint: disable=import-outside-toplevel
    from mindspore.parallel._tensor import _get_slice_index  # pylint: disable=import-outside-toplevel
    from hyper_parallel import DTensor  # pylint: disable=import-outside-toplevel

    if not isinstance(cell, Cell):
        raise ValueError(f"cell's type must be Cell but got {type(cell)}.")
    if not isinstance(stage_index, int):
        raise ValueError(f"stage_index's type must be int but got {type(stage_index)}.")

    for param in cell.get_parameters(expand=True):
        if not param.has_init:
            continue
        is_dtensor = isinstance(param, DTensor)

        # An explicit ``hsdp_init_index`` attribute always wins (even when its
        # value is None); otherwise a DTensor with a layout derives the index
        # from that layout.
        if hasattr(param, "hsdp_init_index"):
            slice_index = param.hsdp_init_index
        elif is_dtensor and param.layout is not None:
            layout = param.layout
            slice_index = _get_slice_index(layout.mesh_shape, layout.tensor_map, None)
        else:
            slice_index = None

        if is_dtensor:
            # DTensor parameters are initialized with their local shape and
            # the local view of their initializer.
            target_shape = param.local_shape
            initializer = param.init_mode.to_local()
            # NOTE(review): nesting reconstructed from a flattened listing;
            # confirm this unwrap belongs inside the DTensor branch.
            if isinstance(initializer, ms.Parameter):
                initializer = initializer.init_mode
        else:
            target_shape = param.shape
            initializer = param.init_mode

        if slice_index is None:
            init_kwargs = {"shape": target_shape}
        else:
            init_kwargs = {
                "slice_index": int(slice_index) + stage_index,
                "shape": target_shape,
            }
        materialized = initializer.init_data(**init_kwargs)

        # Clear the pending-initializer attributes before installing the data.
        param.init_mode = None
        param.init = None
        param.set_data(materialized)

    return cell