Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/parameter_init.py: 3%

32 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Parameter init""" 

16 

17 

def init_parameters(cell, stage_index=0):
    r"""
    Materialize the pending initial values of the parameters of a cell.

    Iterates over ``cell.get_parameters(expand=True)``. For every parameter whose
    ``has_init`` flag is set, it builds the initial data from the parameter's
    ``init_mode`` via ``init_data(...)`` and installs it with ``param.set_data``.
    Parameters without ``has_init`` are left untouched.

    A slice index is forwarded to ``init_data`` when one is available. It comes
    from the parameter's ``hsdp_init_index`` attribute if present, otherwise
    from the layout of a DTensor parameter. The index passed is
    ``int(slice_index) + stage_index``. DTensor parameters are initialized at
    their ``local_shape``; other parameters at their ``shape``.

    Args:
        cell (Cell): The cell whose parameters are initialized.
        stage_index (int): Offset added to the slice index forwarded to
            ``init_data``. Default: ``0``.

    Returns:
        Cell, the same ``cell`` object that was passed in, with its parameters
        initialized in place.

    Raises:
        ValueError: If `cell` is not a ``mindspore.nn.Cell``.
        ValueError: If `stage_index` is not an int.
    """
    import mindspore as ms  # pylint: disable=import-outside-toplevel
    from mindspore.nn.cell import Cell  # pylint: disable=import-outside-toplevel
    from mindspore.parallel._tensor import _get_slice_index  # pylint: disable=import-outside-toplevel
    from hyper_parallel import DTensor  # pylint: disable=import-outside-toplevel

    if not isinstance(cell, Cell):
        raise ValueError(f"cell's type must be Cell but got {type(cell)}.")
    if not isinstance(stage_index, int):
        raise ValueError(f"stage_index's type must be int but got {type(stage_index)}.")

    for param in cell.get_parameters(expand=True):
        # Only parameters that still carry an initializer need materializing.
        if not param.has_init:
            continue
        param_is_dtensor = isinstance(param, DTensor)

        # Slice index forwarded to init_data(), if one can be determined.
        # An explicit hsdp_init_index attribute takes precedence over the
        # index derived from a DTensor's layout.
        data_slice_index = None
        if hasattr(param, "hsdp_init_index"):
            data_slice_index = param.hsdp_init_index
        elif param_is_dtensor and param.layout is not None:
            data_slice_index = _get_slice_index(param.layout.mesh_shape, param.layout.tensor_map, None)

        # DTensor parameters are initialized from the local part of their
        # initializer at their local shape; other parameters use the full ones.
        if param_is_dtensor:
            local_shape = param.local_shape
            init_tensor = param.init_mode.to_local()
        else:
            local_shape = param.shape
            init_tensor = param.init_mode

        # When the initializer resolves to a Parameter, use that Parameter's own
        # init_mode instead. This is applied to DTensor and plain parameters
        # alike. NOTE(review): presumably Parameter.init_data() does not accept
        # the slice_index/shape keywords used below -- confirm against the
        # MindSpore version in use.
        if isinstance(init_tensor, ms.Parameter):
            init_tensor = init_tensor.init_mode

        if data_slice_index is not None:
            init_data = init_tensor.init_data(slice_index=int(data_slice_index) + stage_index, shape=local_shape)
        else:
            init_data = init_tensor.init_data(shape=local_shape)

        # Drop the pending initializer, then install the concrete data.
        param.init_mode = None
        param.init = None
        param.set_data(init_data)
    return cell