Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/hook_function.py: 21% of 38 statements
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Add post-backward hook function."""
import torch
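

# Note: wrapping a module's inputs in PostBackwardFunction.apply inserts an
# extra node into the autograd graph upstream of the module, so during the
# backward pass that node's backward() runs only after gradients have flowed
# back through the whole wrapped module. This mirrors the post-backward
# registration trick used by PyTorch's FSDP2 to tell when a parameter group's
# gradients are ready.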
class PostBackwardFunction(torch.autograd.Function):
    """Post-backward hook function."""

    @staticmethod
    def forward(ctx, hsdp_scheduler, *inputs):
        """Save the scheduler reference and pass inputs through unchanged."""
        ctx.hsdp_scheduler = hsdp_scheduler
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        """Trigger the scheduler's backward hook and pass gradients through unchanged."""
        # pylint: disable=W0212
        ctx.hsdp_scheduler._backward_hook()
        # The leading None is the gradient for the non-tensor hsdp_scheduler
        # argument of forward(); the remaining grads pass through untouched.
        return (None,) + grads

    @classmethod
    def apply(cls, *args, **kwargs):
        """Override apply to unwrap DTensor inputs and rewrap DTensor outputs."""
        # pylint: disable=C0415
        from hyper_parallel import DTensor

        input_args = []
        input_layouts = []
        for arg in args:
            if arg is None or not hasattr(arg, "_layout"):
                # Plain tensors (and None) pass through unchanged.
                input_layouts.append(None)
                input_args.append(arg)
            else:
                # DTensor: remember its layout and hand autograd the local shard.
                input_layouts.append(arg.layout)
                input_args.append(arg.to_local())

        origin_output = super().apply(*input_args, **kwargs)

        # forward() consumes the scheduler and echoes the remaining inputs,
        # so there must be exactly one more input than output.
        if len(origin_output) != len(input_args) - 1:
            raise RuntimeError("number of outputs should equal number of inputs minus 1")

        if isinstance(origin_output, (tuple, list)):
            output = ()
            for i, output_item in enumerate(origin_output):
                # input_layouts[0] belongs to hsdp_scheduler, so output i
                # corresponds to input layout i + 1.
                item_layout = input_layouts[i + 1]
                if item_layout is None:
                    output += (output_item,)
                else:
                    output += (DTensor.from_local(output_item, item_layout.mesh, item_layout.alias_placements),)
            return output
        return origin_output
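

# Minimal usage sketch (illustrative only, not part of the original module).
# `_DummyScheduler` is a hypothetical stand-in for the HSDP scheduler; the
# only thing assumed about the real scheduler is its `_backward_hook()`
# method, which is exactly what backward() above calls. Plain tensors are
# used so the DTensor branch of apply() stays inactive.
if __name__ == "__main__":

    class _DummyScheduler:
        """Hypothetical scheduler exposing only the hook this module calls."""

        def _backward_hook(self):
            print("post-backward hook fired")

    scheduler = _DummyScheduler()
    linear = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)

    # Route the input through the hook function before the module; in the
    # backward pass the extra graph node calls _backward_hook() only after
    # linear's backward has produced the gradient flowing into it.
    (hooked_x,) = PostBackwardFunction.apply(scheduler, x)
    loss = linear(hooked_x).sum()
    loss.backward()  # prints "post-backward hook fired"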