Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / init_weights.py: 64%

11 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Empty / on-device weight initialization utilities.""" 

16 

17from contextlib import contextmanager 

18from hyper_parallel.platform import get_platform 

19 

# Module-level backend handle, resolved once at import time so that both
# context managers below share the same platform adapter. NOTE(review):
# presumably get_platform() selects NPU/GPU/CPU at runtime — confirm against
# hyper_parallel.platform.
platform = get_platform()

21 

22 

@contextmanager
def init_empty_weights(include_buffers=False):
    """Create model parameters (and, optionally, buffers) on the meta device.

    Everything built inside this context carries no real storage, so no
    NPU/GPU/CPU memory is allocated. Typical use is instantiating a model
    too large for host RAM and then materialising it from a checkpoint.

    Args:
        include_buffers (bool): When ``True``, buffers are also created on
            the meta device. Defaults to ``False``.

    Example::

        from hyper_parallel import init_empty_weights

        with init_empty_weights():
            model = MyLargeModel()  # no memory allocated

    Note:
        The resulting model holds **no weights**; ``model.to(some_device)``
        will not work until a checkpoint has been loaded.
    """
    # Delegate to the platform backend, pinning the target to its meta device.
    target = platform.meta_device
    with platform.init_on_device(target, include_buffers=include_buffers):
        yield

49 

50 

@contextmanager
def init_on_device(device, include_buffers=False):
    """Create model parameters (and, optionally, buffers) directly on *device*.

    Args:
        device: Device on which parameters (and buffers) are allocated.
        include_buffers (bool): When ``True``, buffers are also placed on
            *device*. Defaults to ``False``.

    Example::

        from hyper_parallel import init_on_device
        import torch

        with init_on_device(torch.device("npu")):
            model = MyModel()  # parameters live on Ascend
    """
    # The platform adapter supplies the actual allocation context; we merely
    # enter it for the caller's lifetime.
    backend_ctx = platform.init_on_device(device, include_buffers=include_buffers)
    with backend_ctx:
        yield