# Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/init_weights.py: 64%
# 11 statements
# coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Empty / on-device weight initialization utilities."""
from contextlib import contextmanager

from hyper_parallel.platform import get_platform

# Platform abstraction (Ascend NPU / GPU / CPU) resolved once at import time
# and shared by both context managers below.
platform = get_platform()
@contextmanager
def init_empty_weights(include_buffers=False):
    """Context manager that creates model parameters (and, optionally,
    buffers) on the meta device.

    A model built inside this context carries no real data, so no NPU/GPU/CPU
    memory is allocated. Use it when the model is too large to instantiate in
    RAM and the actual weights will be loaded from a checkpoint afterwards.

    Args:
        include_buffers (bool): When ``True``, buffers are placed on the meta
            device as well. Defaults to ``False``.

    Example::

        from hyper_parallel import init_empty_weights

        with init_empty_weights():
            model = MyLargeModel()  # no memory allocated

    Note:
        A model initialised under this context has **no weights**. You cannot
        call ``model.to(some_device)`` directly; load a checkpoint first.
    """
    # Delegate to the platform layer, targeting its meta (data-less) device.
    target = platform.meta_device
    with platform.init_on_device(target, include_buffers=include_buffers):
        yield
@contextmanager
def init_on_device(device, include_buffers=False):
    """Context manager that creates model parameters (and, optionally,
    buffers) directly on *device*.

    Args:
        device: Target device for parameter (and buffer) allocation.
        include_buffers (bool): When ``True``, buffers are also allocated on
            *device*. Defaults to ``False``.

    Example::

        from hyper_parallel import init_on_device
        import torch

        with init_on_device(torch.device("npu")):
            model = MyModel()  # parameters live on Ascend
    """
    # Thin wrapper: the platform layer owns the actual device-placement logic.
    ctx = platform.init_on_device(device, include_buffers=include_buffers)
    with ctx:
        yield