Coverage for hyper_parallel / platform / torch / function_override.py: 73%
44 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Torch function override"""
16from torch import Tensor
17from torch.nn.modules import module
18from torch.nn.modules import _functions
19from torch.nn.modules._functions import BackwardHookFunction
20from torch.utils.hooks import BackwardHook
21from torch.utils._pytree import tree_flatten, tree_unflatten
class DTensorBackwardHookFunction(BackwardHookFunction):
    """BackwardHookFunction that transparently unwraps/rewraps DTensor args.

    DTensor inputs are converted to their local shards before the base hook
    function runs, and the corresponding outputs are re-wrapped into DTensors
    with the original layout afterwards. Plain tensors and ``None`` pass
    through unchanged.
    """

    @classmethod
    def apply(cls, *args, **kwargs):
        """Run the base hook on local tensors, restoring DTensor outputs."""
        # Imported lazily to avoid a circular import at module load time.
        # pylint: disable=C0415
        from hyper_parallel import DTensor

        local_args = []
        layouts = []
        for arg in args:
            # NOTE(review): membership test uses "_layout" but the value read
            # below is the "layout" property — presumably the DTensor exposes
            # both; confirm against hyper_parallel.DTensor.
            if arg is None or not hasattr(arg, "_layout"):
                # Plain tensor (or None): no layout recorded, forwarded as-is.
                layouts.append(None)
                local_args.append(arg)
            else:
                # DTensor: remember layout, hand the local shard to torch.
                layouts.append(arg.layout)
                local_args.append(arg.to_local())

        raw_output = BackwardHookFunction.apply(*local_args, **kwargs)

        # The base hook is expected to return one output per input.
        if len(raw_output) != len(local_args):
            raise RuntimeError("number of output should equal to number of input")

        if not isinstance(raw_output, (tuple, list)):
            return raw_output
        # Re-wrap outputs whose inputs were DTensors; others pass through.
        return tuple(
            item if layout is None
            else DTensor.from_local(item, layout.mesh, layout.placements)
            for item, layout in zip(raw_output, layouts)
        )
class ExtendBackwardHook(BackwardHook):
    """BackwardHook variant that also accepts pytree (non-tuple) outputs."""

    def setup_output_hook(self, args):
        """Install output hooks, supporting dict/list-shaped module outputs."""
        # Tuples and bare tensors are already handled by the base class.
        if isinstance(args, (tuple, Tensor)):
            return super().setup_output_hook(args)
        # Flatten any other pytree container, hook the leaves as a tuple,
        # then restore the original container structure.
        leaves, spec = tree_flatten(args)
        hooked = super().setup_output_hook(tuple(leaves))
        return tree_unflatten(hooked, spec)
def override_functions():
    """Install the DTensor-aware overrides into torch's module internals."""
    # The two assignments are independent; order does not matter.
    module.BackwardHook = ExtendBackwardHook
    _functions.BackwardHookFunction = DTensorBackwardHookFunction