Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / torch / fully_shard / hook_function.py: 21%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

15"""add post backward hook function""" 

16import torch 

17 

18 

class PostBackwardFunction(torch.autograd.Function):
    """Identity autograd function that fires an HSDP post-backward hook.

    ``forward`` passes its tensor inputs through unchanged while stashing a
    reference to the scheduler; ``backward`` calls the scheduler's
    ``_backward_hook`` and passes the gradients through unchanged.  The
    overridden ``apply`` additionally unwraps DTensor arguments to local
    tensors on the way in and re-wraps the outputs on the way out.
    """

    @staticmethod
    def forward(ctx, hsdp_scheduler, *inputs):
        """Save the scheduler reference and pass inputs through unchanged."""
        ctx.hsdp_scheduler = hsdp_scheduler
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        """Trigger the scheduler's backward hook and pass gradients through unchanged."""
        # pylint: disable=W0212
        ctx.hsdp_scheduler._backward_hook()
        # Leading None is the gradient slot for the non-tensor
        # ``hsdp_scheduler`` argument of forward().
        return (None,) + grads

    @classmethod
    def apply(cls, *args, **kwargs):
        """Override apply function to handle DTensor inputs.

        DTensor-like arguments (those carrying a ``_layout`` attribute) are
        converted to local tensors before the real autograd ``apply``, and the
        corresponding outputs are re-wrapped with the recorded layout.

        Raises:
            RuntimeError: if the number of outputs does not equal the number
                of inputs minus one (the scheduler argument has no output).
        """
        # pylint: disable=C0415
        from hyper_parallel import DTensor

        input_args = []
        input_layouts = []
        for arg in args:
            # None and plain (non-DTensor) args pass through untouched;
            # hasattr(None, "_layout") is False, so one guard covers both.
            if arg is None or not hasattr(arg, "_layout"):
                input_layouts.append(None)
                input_args.append(arg)
            else:
                layout = arg.layout
                input_layouts.append(layout)
                input_args.append(arg.to_local())

        origin_output = super().apply(*input_args, **kwargs)

        # forward() echoes every input except the leading scheduler argument.
        if len(origin_output) != len(input_args) - 1:
            raise RuntimeError("number of output should equal to number of input minus 1")

        if isinstance(origin_output, (tuple, list)):
            output = ()
            for i, output_item in enumerate(origin_output):
                # Index shift: input_layouts[0] belongs to the scheduler
                # argument, which produces no output.
                item_layout = input_layouts[i + 1]
                if item_layout is None:
                    output += (output_item,)
                else:
                    output += (DTensor.from_local(output_item, item_layout.mesh, item_layout.alias_placements),)
            return output
        return origin_output