Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_cell_backward_hook.py: 88%
16 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""CellBackwardHook distributed op."""
17from .parallel_tuple_elementwise import TupleElementWiseDistributedOp
class CellBackwardHookDistributedOp(TupleElementWiseDistributedOp):
    """Distributed op for MindSpore CellBackwardHook.

    Hook outputs may contain a mix of DTensor-backed slots and plain local
    Tensor slots. The local slots should be passed through unchanged.
    """

    def wrap_output(self, py_output, output_layouts):
        """Wrap outputs while preserving local Tensor slots.

        Args:
            py_output: Either a tuple/list of per-slot local outputs, or a
                single local output.
            output_layouts: Sequence of layouts, one per output slot. A
                ``None`` entry marks a slot that is returned unchanged; any
                other entry must expose ``mesh`` and ``alias_placements``,
                which are forwarded to ``DTensor.from_local``. When
                ``py_output`` is not a tuple/list, only ``output_layouts[0]``
                is consulted.

        Returns:
            For tuple/list input, a tuple (always a tuple, even for list
            input) whose slots are either the original item or its DTensor
            wrapping. For any other input, either ``py_output`` itself or its
            DTensor wrapping.

        Raises:
            RuntimeError: If ``py_output`` is a tuple/list whose length does
                not equal ``len(output_layouts)``.
        """
        # Deferred import (import-outside-toplevel is silenced on purpose);
        # presumably to avoid an import cycle with the dtensor module - confirm.
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.dtensor import DTensor

        def _wrap_slot(item, layout):
            """Return ``item`` unchanged if ``layout`` is None, else as a DTensor."""
            if layout is None:
                return item
            return DTensor.from_local(item, layout.mesh, layout.alias_placements)

        if isinstance(py_output, (tuple, list)):
            if len(py_output) != len(output_layouts):
                raise RuntimeError(
                    f"Output tuple size ({len(py_output)}) "
                    f"does not match layout tuple size ({len(output_layouts)})")
            # Build the result in a single pass; repeated `tuple += (x,)`
            # copies the accumulated tuple on every iteration (quadratic).
            return tuple(
                _wrap_slot(item, layout)
                for item, layout in zip(py_output, output_layouts)
            )

        # Non-sequence output: a single slot governed by the first layout.
        return _wrap_slot(py_output, output_layouts[0])