Coverage for hyper_parallel / platform / torch / fully_shard / hook_function.py: 71%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""add post backward hook function"""
import torch



class PostBackwardFunction(torch.autograd.Function):
    """Autograd function that fires an HSDP scheduler hook during backward.

    ``forward`` is an identity on its tensor inputs; it stashes the scheduler
    on ``ctx`` so that ``backward`` can invoke the scheduler's post-backward
    hook before passing the incoming gradients through unchanged.
    """

    @staticmethod
    def forward(ctx, hsdp_scheduler, *inputs):
        """Identity forward: remember the scheduler on ``ctx``, echo ``inputs``."""
        ctx.hsdp_scheduler = hsdp_scheduler
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        """Run the scheduler's post-backward hook, then pass grads through.

        The leading ``None`` is the (non-)gradient slot for ``hsdp_scheduler``.
        """
        # pylint: disable=W0212
        ctx.hsdp_scheduler._backward_hook()
        return (None,) + grads

    @classmethod
    def apply(cls, *args, **kwargs):
        """Override apply function to handle DTensor inputs"""
        # pylint: disable=C0415
        from hyper_parallel import DTensor

        local_args = []
        layouts = []
        for item in args:
            # Anything without a ``_layout`` attribute (including None and the
            # scheduler itself) passes through untouched; DTensors are unwrapped
            # to their local shard and their layout recorded for re-wrapping.
            # NOTE(review): membership is tested via ``_layout`` but access goes
            # through ``layout`` — presumably a property pair on DTensor; confirm.
            if item is None or not hasattr(item, "_layout"):
                layouts.append(None)
                local_args.append(item)
            else:
                layouts.append(item.layout)
                local_args.append(item.to_local())

        raw_out = super().apply(*local_args, **kwargs)

        # forward() echoes every tensor input, so there must be exactly one
        # output per non-scheduler argument.
        if len(raw_out) != len(local_args) - 1:
            raise RuntimeError("number of output should equal to number of input minus 1")

        if not isinstance(raw_out, (tuple, list)):
            return raw_out

        # Re-wrap each output with the layout recorded for the matching input;
        # layouts[0] belongs to the scheduler slot, hence the offset of one.
        wrapped = ()
        for out_item, out_layout in zip(raw_out, layouts[1:]):
            if out_layout is None:
                wrapped += (out_item,)
            else:
                wrapped += (DTensor.from_local(out_item, out_layout.mesh, out_layout.placements),)
        return wrapped