Coverage for hyper_parallel / platform / torch / function_override.py: 73%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch function override""" 

16from torch import Tensor 

17from torch.nn.modules import module 

18from torch.nn.modules import _functions 

19from torch.nn.modules._functions import BackwardHookFunction 

20from torch.utils.hooks import BackwardHook 

21from torch.utils._pytree import tree_flatten, tree_unflatten 

22 

23 

class DTensorBackwardHookFunction(BackwardHookFunction):
    """Override of ``BackwardHookFunction`` that accepts DTensor inputs.

    Arguments carrying a ``_layout`` attribute (DTensor-like objects) are
    converted to local tensors with ``to_local()`` before the original
    ``BackwardHookFunction.apply`` runs. The matching outputs are then
    re-wrapped with ``DTensor.from_local`` using that input's mesh and
    placements. All other arguments (including ``None``) pass through unchanged.
    """

    @classmethod
    def apply(cls, *args, **kwargs):
        """Run ``BackwardHookFunction.apply`` on local tensors and re-wrap DTensors.

        Args:
            *args: Positional inputs. Entries with a ``_layout`` attribute are
                replaced by ``arg.to_local()``; everything else is forwarded
                as-is.
            **kwargs: Forwarded unchanged to ``BackwardHookFunction.apply``.

        Returns:
            When the underlying call returns a tuple or list: a tuple with one
            entry per input, where entries whose input was DTensor-like are
            rebuilt via ``DTensor.from_local``. Any other return value is
            passed through untouched.

        Raises:
            RuntimeError: If a tuple/list result does not have exactly one
                entry per input.
        """
        input_layouts = []
        input_args = []
        for arg in args:
            # ``None`` and plain tensors have no ``_layout`` and pass through.
            if hasattr(arg, "_layout"):
                input_layouts.append(arg.layout)
                input_args.append(arg.to_local())
            else:
                input_layouts.append(None)
                input_args.append(arg)

        # Call the torch class imported at module load explicitly:
        # ``override_functions`` rebinds ``_functions.BackwardHookFunction``
        # to this class, so looking it up there would recurse.
        origin_output = BackwardHookFunction.apply(*input_args, **kwargs)

        # Only sequences can be matched one-to-one with the inputs; calling
        # len() on e.g. a Tensor would measure its first dimension instead.
        if not isinstance(origin_output, (tuple, list)):
            return origin_output

        if len(origin_output) != len(input_args):
            raise RuntimeError("number of output should equal to number of input")

        if all(layout is None for layout in input_layouts):
            return tuple(origin_output)

        # Imported lazily (and only when a DTensor input is present),
        # presumably to avoid a circular import with the package root.
        # pylint: disable=C0415
        from hyper_parallel import DTensor

        return tuple(
            output_item if layout is None
            else DTensor.from_local(output_item, layout.mesh, layout.placements)
            for output_item, layout in zip(origin_output, input_layouts)
        )

64 

65 

class ExtendBackwardHook(BackwardHook):
    """``BackwardHook`` variant whose output hook also accepts nested containers.

    Tuples and ``Tensor`` outputs go straight to the base implementation.
    Any other output structure (e.g. a dict or list) is flattened with
    ``tree_flatten``, handed to the base class as a tuple of leaves, and
    rebuilt afterwards with ``tree_unflatten``.
    """

    def setup_output_hook(self, args):
        """Install output hooks on ``args`` while preserving its structure.

        Args:
            args: The module output. Tuples and Tensors are forwarded to the
                base class unchanged; any other object is flattened first.

        Returns:
            The base class's result, reassembled into the original container
            structure when ``args`` had to be flattened.
        """
        if isinstance(args, (tuple, Tensor)):
            return super().setup_output_hook(args)
        leaves, spec = tree_flatten(args)
        hooked_leaves = super().setup_output_hook(tuple(leaves))
        return tree_unflatten(hooked_leaves, spec)

75 

76 

def override_functions():
    """Install the DTensor-aware backward-hook classes into torch.

    Rebinds ``torch.nn.modules._functions.BackwardHookFunction`` to
    ``DTensorBackwardHookFunction`` and ``torch.nn.modules.module.BackwardHook``
    to ``ExtendBackwardHook``, in that order. Repeated calls reassign the
    same values.
    """
    patches = (
        (_functions, "BackwardHookFunction", DTensorBackwardHookFunction),
        (module, "BackwardHook", ExtendBackwardHook),
    )
    for target_module, attr_name, replacement in patches:
        setattr(target_module, attr_name, replacement)