# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore implementation of on-device weight initialization."""

from contextlib import contextmanager

import mindspore as ms
from mindspore import mint, nn

from hyper_parallel.platform.mindspore.utils import normalize_runtime_device


def _cell_to_empty(self, device=None, recurse=True):
    """Patch for ``nn.Cell.to_empty`` (see init.md): replaces each parameter's
    data with an uninitialized tensor via ``param.set_data(mint.empty_like(...))``.

    Walks ``parameters_and_names(expand=...)``; ``recurse`` maps to MindSpore's
    ``expand`` flag. When ``device`` is omitted it defaults to the current
    ``device_target`` context, so ``net.to_empty()`` works without arguments.
    """
    # pylint: disable=import-outside-toplevel
    from hyper_parallel.core.dtensor.dtensor import DTensor

    if device is None:
        device = ms.get_context("device_target")
    for _, param in self.parameters_and_names(expand=recurse):
        if param is None:
            continue
        if isinstance(param, DTensor):
            # For DTensor parameters, allocate against the local shard only.
            local = param.to_local()
            new_tensor = mint.empty_like(local, device=device)
            param.set_data(new_tensor)
            continue
        new_tensor = mint.empty_like(param, device=device)
        param.set_data(new_tensor)
    return self


def _install_cell_to_empty_patch():
    """Install ``_cell_to_empty`` as ``nn.Cell.to_empty`` exactly once."""
    if getattr(nn.Cell, "_hyper_parallel_to_empty_installed", False):
        return
    nn.Cell.to_empty = _cell_to_empty
    nn.Cell._hyper_parallel_to_empty_installed = True  # pylint: disable=W0212
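

# A minimal usage sketch (illustrative only, not part of the public API): once
# the patch is installed, any Cell gains ``to_empty``. ``nn.Dense`` stands in
# for a real model here; with no explicit device, ``_cell_to_empty`` falls
# back to the current ``device_target`` context, per its docstring.
def _example_to_empty():
    """Hedged example: swap a small Cell's parameters for empty storage."""
    _install_cell_to_empty_patch()
    net = nn.Dense(4, 4)
    # Every parameter now points at an uninitialized tensor of the same
    # shape/dtype; the caller is expected to load real weights afterwards.
    return net.to_empty()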


def _check_valid_init_device(device: str):
    """Validate a user-provided init device for MindSpore ``init_on_device``."""
    if device not in {"npu", "cpu", "meta"}:
        raise ValueError(f'Unsupported device "{device}", only "npu", "cpu", and "meta" are allowed.')


@contextmanager
def init_on_device(device, include_buffers=False):
    """Context manager that initializes model parameters (and optionally
    buffers) on *device*.

    Args:
        device: ``"meta"`` to skip allocation (no real memory is used), or a
            real device string (e.g. ``"cpu"``, ``"npu"``) for placement.
        include_buffers (bool): Also redirect buffers to *device*.
    """
    if include_buffers:
        raise ValueError("MindSpore platform does not support include_buffers=True.")
    _check_valid_init_device(device)
    orig_insert_param = nn.Cell.insert_param_to_cell

    # pylint: disable=W0212
    def _insert_param_to_cell(module, param_name, param, check_name_contain_dot=True):
        orig_insert_param(module, param_name, param, check_name_contain_dot)
        if param is not None:
            def _custom_kwargs(param_obj):
                # Attributes managed internally by MindSpore graph mode; they
                # must not be forwarded to the Parameter constructor.
                ms_graph_attrs = [
                    "init_mode", "is_default_input_init", "_param_info", "is_init", "_inited_param", "_sliced",
                    "requires_aggr", "_cast_type", "_unique", "is_in_parallel", "_pipeline_stage_list", "load",
                ]
                return {k: v for k, v in param_obj.__dict__.items() if k not in ms_graph_attrs}

            orig_param = module._params[param_name]
            # Collect custom kwargs from orig_param without mutating the
            # original parameter's __dict__.
            kwargs = _custom_kwargs(orig_param)
            kwargs["name"] = param.name
            kwargs["requires_grad"] = param.requires_grad
            param_cls = type(orig_param)
            param_device = normalize_runtime_device(param.device)
            module._params[param_name] = (param if param_device == device
                                          else param_cls(orig_param.to(device=device), **kwargs))

    try:
        nn.Cell.insert_param_to_cell = _insert_param_to_cell
        yield
    finally:
        nn.Cell.insert_param_to_cell = orig_insert_param
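

# A minimal sketch of the intended call pattern (illustrative only): building
# a Cell inside the context routes every ``insert_param_to_cell`` call through
# the patched version above, so with ``device="meta"`` the whole model is
# constructed without allocating real device memory.
def _example_build_on_meta():
    """Hedged example: construct a small Cell as shape/dtype shells only."""
    with init_on_device("meta"):
        # nn.Dense registers its weight and bias via insert_param_to_cell,
        # so both parameters land on the meta device inside this context.
        net = nn.Dense(4, 4)
    # Attach real (uninitialized) storage later via the patched to_empty,
    # then load a checkpoint to fill in actual values.
    _install_cell_to_empty_patch()
    return net.to_empty()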