Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/function_override.py: 32%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch function override""" 

16from torch import Tensor 

17from torch.nn.modules import module 

18from torch.nn.modules import _functions 

19from torch.nn.modules._functions import BackwardHookFunction 

20from torch.utils.hooks import BackwardHook 

21from torch.utils._pytree import tree_flatten, tree_unflatten 

22 

23 

class DTensorBackwardHookFunction(BackwardHookFunction):
    """Override of ``BackwardHookFunction`` that is transparent to DTensors.

    DTensor arguments are unwrapped to their local tensors before the stock
    autograd hook function runs, and the corresponding outputs are re-wrapped
    with each input's original layout afterwards.
    """

    @classmethod
    def apply(cls, *args, **kwargs):
        """Apply the backward-hook function, unwrapping/re-wrapping DTensors.

        Args:
            *args: Positional arguments. Entries carrying a ``_layout``
                attribute are treated as DTensors and replaced by their
                local tensors; ``None`` and plain tensors pass through.
            **kwargs: Forwarded unchanged to ``BackwardHookFunction.apply``.

        Returns:
            The hook function's outputs as a tuple when the underlying call
            returns a sequence, with positions that received a DTensor
            converted back via ``DTensor.from_local`` using the recorded
            layout; otherwise the underlying output unchanged.

        Raises:
            RuntimeError: If a sequence output's length does not match the
                number of inputs.
        """
        # pylint: disable=C0415
        # Imported lazily to avoid a circular import at module load time.
        from hyper_parallel import DTensor

        input_args = []
        input_layouts = []

        for arg in args:
            # NOTE(review): DTensor detection probes the private `_layout`
            # attribute but then reads the public `arg.layout` property —
            # presumably `layout` is backed by `_layout`; confirm.
            if arg is not None and hasattr(arg, "_layout"):
                input_layouts.append(arg.layout)
                input_args.append(arg.to_local())
            else:
                input_layouts.append(None)
                input_args.append(arg)

        origin_output = BackwardHookFunction.apply(*input_args, **kwargs)

        if isinstance(origin_output, (tuple, list)):
            # Validate the one-output-per-input contract only for sequence
            # outputs: calling len() on a single Tensor would measure its
            # first dimension and raise a spurious error.
            if len(origin_output) != len(input_args):
                raise RuntimeError("number of output should equal to number of input")
            return tuple(
                item if layout is None
                else DTensor.from_local(item, layout.mesh, layout.alias_placements)
                for item, layout in zip(origin_output, input_layouts)
            )
        return origin_output

66 

67 

class ExtendBackwardHook(BackwardHook):
    """``BackwardHook`` variant that accepts arbitrary pytree outputs.

    The stock ``BackwardHook.setup_output_hook`` only understands tuples
    (and single Tensors); any other container is flattened into a tuple
    first, and the hooked leaves are re-assembled into the original
    structure afterwards.
    """

    def setup_output_hook(self, args):
        # Fast path: the base class already handles tuples and Tensors.
        if isinstance(args, (tuple, Tensor)):
            return super().setup_output_hook(args)
        # General pytree path: flatten, hook the leaves, rebuild.
        leaves, tree_spec = tree_flatten(args)
        hooked_leaves = super().setup_output_hook(tuple(leaves))
        return tree_unflatten(hooked_leaves, tree_spec)

77 

78 

def override_functions():
    """Install the DTensor-aware hook overrides into torch.

    Replaces ``torch.nn.modules.BackwardHook`` with the pytree-capable
    variant and ``torch.nn.modules._functions.BackwardHookFunction`` with
    the DTensor-aware one.
    """
    module.BackwardHook = ExtendBackwardHook
    _functions.BackwardHookFunction = DTensorBackwardHookFunction