Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/init_weights.py: 15%
27 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""PyTorch implementation of on-device weight initialization."""
17from contextlib import contextmanager
19from torch import nn
@contextmanager
def init_on_device(device, include_buffers=False):
    """Monkey-patch ``nn.Module`` so that every parameter (and optionally every
    buffer) is placed on *device* at registration time.

    The patch replaces ``nn.Module.register_parameter`` (and, when requested,
    ``nn.Module.register_buffer``) on the class itself, so it applies to every
    module that registers a parameter while the context is active. The
    original methods are restored on exit, including when the body raises.

    Args:
        device (torch.device): Target device.
        include_buffers (bool): Also redirect buffers to *device*.

    Yields:
        None. The patch is active for the duration of the ``with`` body.
    """
    orig_register_parameter = nn.Module.register_parameter
    orig_register_buffer = nn.Module.register_buffer

    # pylint: disable=W0212
    def _register_parameter(module, name, param):
        # Let the stock implementation validate and store the parameter first.
        orig_register_parameter(module, name, param)
        if param is None:
            return

        registered = module._parameters[name]
        if param.device == device:
            new_param = param
        else:
            # Rebuild with the same class and instance attributes. Copy the
            # attribute dict so the original parameter's ``__dict__`` is not
            # aliased or mutated.
            kwargs = dict(registered.__dict__)
            # ``requires_grad`` must be passed to the constructor: it defaults
            # to True, which raises for non-floating-point parameters (e.g.
            # int8 weights) before it could be reset afterwards.
            kwargs["requires_grad"] = param.requires_grad
            new_param = type(registered)(registered.to(device), **kwargs)
        new_param.requires_grad = param.requires_grad
        module._parameters[name] = new_param

    # pylint: disable=W0212
    def _register_buffer(module, name, buffer, persistent=True):
        orig_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    try:
        nn.Module.register_parameter = _register_parameter
        if include_buffers:
            nn.Module.register_buffer = _register_buffer
        yield
    finally:
        # Always undo the class-level patch, even if the body raised.
        nn.Module.register_parameter = orig_register_parameter
        if include_buffers:
            nn.Module.register_buffer = orig_register_buffer