Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/function_override.py: 32%
44 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Torch function override"""
16from torch import Tensor
17from torch.nn.modules import module
18from torch.nn.modules import _functions
19from torch.nn.modules._functions import BackwardHookFunction
20from torch.utils.hooks import BackwardHook
21from torch.utils._pytree import tree_flatten, tree_unflatten
class DTensorBackwardHookFunction(BackwardHookFunction):
    """Override of ``BackwardHookFunction`` that accepts DTensor inputs.

    Inputs that expose a ``_layout`` attribute are treated as DTensors: they
    are converted to local tensors with ``to_local()`` before the original
    ``BackwardHookFunction.apply`` runs, and the matching outputs are wrapped
    back with ``DTensor.from_local`` using the input layout's ``mesh`` and
    ``alias_placements``. All other inputs (including ``None``) pass through
    unchanged.
    """

    @classmethod
    def apply(cls, *args, **kwargs):
        """Run the original hook function on local tensors and re-wrap DTensors.

        Args:
            *args: Hook inputs. Any object with a ``_layout`` attribute is
                unwrapped via ``to_local()``; everything else is forwarded as is.
            **kwargs: Forwarded unchanged to ``BackwardHookFunction.apply``.

        Returns:
            If the original function returns a tuple or list, a tuple with one
            entry per input, where positions that held a DTensor are re-wrapped
            with ``DTensor.from_local``. Any other result is returned unchanged.

        Raises:
            RuntimeError: If the original function returns a tuple/list whose
                length differs from the number of inputs.
        """
        input_args = []
        input_layouts = []
        for arg in args:
            # NOTE(review): detection keys on ``_layout`` presumably because
            # every plain ``torch.Tensor`` already has a ``layout`` attribute.
            # ``None`` has no ``_layout`` and so also takes the pass-through path.
            if hasattr(arg, "_layout"):
                input_layouts.append(arg.layout)
                input_args.append(arg.to_local())
            else:
                input_layouts.append(None)
                input_args.append(arg)

        origin_output = BackwardHookFunction.apply(*input_args, **kwargs)

        # Only a sequence result can be matched position-by-position with the
        # inputs; checking len() before this would mis-measure a lone Tensor.
        if not isinstance(origin_output, (tuple, list)):
            return origin_output

        if len(origin_output) != len(input_args):
            raise RuntimeError("number of output should equal to number of input")

        # No DTensor inputs: nothing to re-wrap, so skip importing DTensor.
        if all(layout is None for layout in input_layouts):
            return tuple(origin_output)

        # Local import (kept from the original), presumably to avoid an import
        # cycle with the hyper_parallel package.
        # pylint: disable=C0415
        from hyper_parallel import DTensor

        return tuple(
            item if layout is None
            else DTensor.from_local(item, layout.mesh, layout.alias_placements)
            for item, layout in zip(origin_output, input_layouts)
        )
class ExtendBackwardHook(BackwardHook):
    """``BackwardHook`` subclass that also accepts non-tuple, non-Tensor outputs.

    Outputs that are neither a tuple nor a single ``Tensor`` (for example a
    list or dict) are flattened with ``tree_flatten`` into a tuple of leaves
    before being handed to the base ``setup_output_hook``. The result is then
    rebuilt into the original structure with ``tree_unflatten``.
    """

    def setup_output_hook(self, args):
        """Install output hooks, flattening nested containers first.

        Args:
            args: Module output. Tuples and single Tensors go straight to the
                base implementation; anything else is treated as a pytree.

        Returns:
            The base implementation's result for tuples/Tensors. For other
            inputs, the hooked leaves reassembled into the same container
            structure as ``args``.
        """
        if isinstance(args, (tuple, Tensor)):
            return super().setup_output_hook(args)

        # NOTE(review): presumably the base class only inspects top-level
        # tuple entries, so nested tensors are exposed as flat leaves here.
        leaves, tree_spec = tree_flatten(args)
        hooked_leaves = super().setup_output_hook(tuple(leaves))
        return tree_unflatten(hooked_leaves, tree_spec)
def override_functions():
    """Install this module's hook classes into torch's ``nn.modules`` internals.

    Rebinds ``torch.nn.modules._functions.BackwardHookFunction`` to
    ``DTensorBackwardHookFunction`` and ``torch.nn.modules.module.BackwardHook``
    to ``ExtendBackwardHook``. Because module attributes are shared through
    ``sys.modules``, the rebinding affects every later lookup of those names
    in the process. No restore function is visible in this file.
    """
    replacements = (
        (_functions, "BackwardHookFunction", DTensorBackwardHookFunction),
        (module, "BackwardHook", ExtendBackwardHook),
    )
    for target_module, attr_name, new_value in replacements:
        setattr(target_module, attr_name, new_value)