Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/mindspore/init_weights.py: 26%
50 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MindSpore implementation of on-device weight initialization."""
17from contextlib import contextmanager
19import mindspore as ms
20from mindspore import mint, nn
22from hyper_parallel.platform.mindspore.utils import normalize_runtime_device
def _cell_to_empty(self, device=None, recurse=True):
    """Replacement for ``nn.Cell.to_empty`` installed by hyper_parallel.

    Re-allocates every parameter of the cell as uninitialized storage on
    *device* via ``param.set_data(mint.empty_like(...))``. ``recurse`` is
    forwarded as the ``expand`` flag of ``parameters_and_names``; when
    ``device`` is omitted it falls back to the current ``device_target``
    context, so ``net.to_empty()`` works with no arguments. DTensor
    parameters are rebuilt from their local shard. Returns ``self``.
    """
    # pylint: disable=import-outside-toplevel
    from hyper_parallel.core.dtensor.dtensor import DTensor

    target_device = ms.get_context("device_target") if device is None else device
    for _, parameter in self.parameters_and_names(expand=recurse):
        if parameter is None:
            continue
        # A DTensor's storage is described by its local shard, so size the
        # fresh empty tensor from to_local(); plain parameters size from
        # themselves. Both branches end in the same set_data call.
        template = parameter.to_local() if isinstance(parameter, DTensor) else parameter
        parameter.set_data(mint.empty_like(template, device=target_device))
    return self
def _install_cell_to_empty_patch():
    """Monkey-patch ``nn.Cell.to_empty`` exactly once per process.

    A sentinel attribute on ``nn.Cell`` makes repeated calls no-ops.
    """
    already_installed = getattr(nn.Cell, "_hyper_parallel_to_empty_installed", False)
    if not already_installed:
        nn.Cell.to_empty = _cell_to_empty
        nn.Cell._hyper_parallel_to_empty_installed = True  # pylint: disable=W0212
57def _check_valid_init_device(device: str):
58 """Validate user-provided init device for MindSpore init_on_device."""
59 if device not in {"npu", "cpu", "meta"}:
60 raise ValueError(f'Unsupported device "{device}", only "npu", "cpu", and "meta" are allowed.')
@contextmanager
def init_on_device(device, include_buffers=False):
    """Context manager that initializes model parameters (and optionally
    buffers) on *device*.

    Temporarily replaces ``nn.Cell.insert_param_to_cell`` with a wrapper that
    re-creates each inserted parameter on *device*; the original hook is
    restored on exit, even if the body raises.

    Args:
        device: ``"meta"`` to skip allocation (no real memory used), or a
            real device string (e.g. ``"cpu"``, ``"npu"``) for placement.
        include_buffers (bool): Also redirect buffers to *device*.

    Raises:
        ValueError: If ``include_buffers`` is True (unsupported on the
            MindSpore platform) or *device* is not one of ``"npu"``,
            ``"cpu"``, ``"meta"``.
    """
    if include_buffers:
        raise ValueError("MindSpore platform does not support include_buffers=True.")
    _check_valid_init_device(device)
    # Keep a reference so the hook can delegate and so we can restore it.
    orig_insert_param = nn.Cell.insert_param_to_cell

    # pylint: disable=W0212
    def _insert_param_to_cell(module, param_name, param, check_name_contain_dot=True):
        # Let MindSpore register the parameter normally first, then rewrite
        # the registered entry in module._params if it needs moving.
        orig_insert_param(module, param_name, param, check_name_contain_dot)
        if param is not None:
            def _custom_kwargs(param_obj):
                # Graph-mode bookkeeping attributes that must not be forwarded
                # as constructor kwargs when re-creating the parameter.
                ms_graph_attrs = [
                    "init_mode", "is_default_input_init", "_param_info", "is_init", "_inited_param", "_sliced",
                    "requires_aggr", "_cast_type", "_unique", "is_in_parallel", "_pipeline_stage_list", "load",
                ]
                return {k: v for k, v in param_obj.__dict__.items() if k not in ms_graph_attrs}
            orig_param = module._params[param_name]
            # get custom kwargs from orig_param, do not change the original param __dict__
            kwargs = _custom_kwargs(orig_param)
            kwargs["name"] = param.name
            kwargs["requires_grad"] = param.requires_grad
            # Preserve the exact parameter subclass when rebuilding.
            param_cls = type(orig_param)
            param_device = normalize_runtime_device(param.device)
            # NOTE(review): only param.device is normalized here; this assumes
            # the user-supplied ``device`` string is already in normalized
            # runtime form — confirm against normalize_runtime_device.
            module._params[param_name] = (param if param_device == device
                                          else param_cls(orig_param.to(device=device), **kwargs))

    try:
        nn.Cell.insert_param_to_cell = _insert_param_to_cell
        yield
    finally:
        # Restore the original hook unconditionally so a failure inside the
        # ``with`` body cannot leave nn.Cell permanently patched.
        nn.Cell.insert_param_to_cell = orig_insert_param