Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/init_weights.py: 15%

27 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""PyTorch implementation of on-device weight initialization.""" 

16 

17from contextlib import contextmanager 

18 

19from torch import nn 

20 

21 

@contextmanager
def init_on_device(device, include_buffers=False):
    """Temporarily patch ``nn.Module`` so new parameters land on *device*.

    While the context is active, ``nn.Module.register_parameter`` (and, when
    *include_buffers* is true, ``nn.Module.register_buffer``) are replaced by
    wrappers that first perform the normal registration and then move the
    registered tensor to *device*. The original methods are restored on exit,
    including when the body raises.

    Args:
        device (torch.device): Target device.
        include_buffers (bool): Also redirect buffers to *device*.

    Yields:
        None. The patch is in effect for the duration of the ``with`` body.
    """
    orig_register_parameter = nn.Module.register_parameter
    orig_register_buffer = nn.Module.register_buffer

    # pylint: disable=W0212
    def _register_parameter(module, name, param):
        orig_register_parameter(module, name, param)
        # Nothing to relocate: either no parameter was supplied or it already
        # lives on the target device, so keep what registration stored.
        if param is None or param.device == device:
            return

        orig_param = module._parameters[name]
        # Copy the instance dict so extra attributes carried by Parameter
        # subclasses are forwarded without mutating the original object.
        kwargs = dict(orig_param.__dict__)
        # requires_grad must be given to the constructor: the default (True)
        # makes construction fail for non-floating-point parameters such as
        # quantized integer weights registered with requires_grad=False.
        kwargs["requires_grad"] = param.requires_grad
        new_param = type(orig_param)(orig_param.to(device), **kwargs)
        # Re-assert the flag in case a subclass constructor ignored the kwarg.
        new_param.requires_grad = param.requires_grad
        module._parameters[name] = new_param

    # pylint: disable=W0212
    def _register_buffer(module, name, buffer, persistent=True):
        orig_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    try:
        nn.Module.register_parameter = _register_parameter
        if include_buffers:
            nn.Module.register_buffer = _register_buffer
        yield
    finally:
        # Always undo the class-level patch, even if the body raised.
        nn.Module.register_parameter = orig_register_parameter
        if include_buffers:
            nn.Module.register_buffer = orig_register_buffer

59 if include_buffers: 

60 nn.Module.register_buffer = orig_register_buffer