Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_cell_backward_hook.py: 88%

16 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""CellBackwardHook distributed op.""" 

16 

17from .parallel_tuple_elementwise import TupleElementWiseDistributedOp 

18 

19 

class CellBackwardHookDistributedOp(TupleElementWiseDistributedOp):
    """Distributed op for MindSpore CellBackwardHook.

    Hook outputs may contain a mix of DTensor-backed slots and plain local
    Tensor slots. The local slots should be passed through unchanged.
    """

    def wrap_output(self, py_output, output_layouts):
        """Wrap outputs while preserving local Tensor slots.

        Args:
            py_output: a single hook output, or a tuple/list of outputs.
            output_layouts: sequence of per-slot layouts aligned with the
                outputs. A ``None`` entry marks a plain local Tensor slot
                that must be passed through unchanged; any other entry is
                expected to expose ``.mesh`` and ``.alias_placements``.

        Returns:
            A tuple of wrapped outputs when ``py_output`` is a tuple/list
            (note: a list input also yields a tuple), otherwise either
            ``py_output`` itself (layout is ``None``) or a single DTensor.

        Raises:
            RuntimeError: if the number of outputs does not match the
                number of layouts.
        """
        # Lazy import — C0415 is disabled deliberately, presumably to
        # avoid a circular import with the dtensor module (TODO confirm).
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.dtensor import DTensor

        if isinstance(py_output, (tuple, list)):
            if len(py_output) != len(output_layouts):
                raise RuntimeError(
                    f"Output tuple size ({len(py_output)}) "
                    f"does not match layout tuple size ({len(output_layouts)})")
            # Build the result in one pass instead of repeated tuple
            # concatenation (`output += (item,)`), which is O(n^2) in the
            # number of output slots.
            return tuple(
                item if layout is None
                else DTensor.from_local(item, layout.mesh, layout.alias_placements)
                for item, layout in zip(py_output, output_layouts)
            )

        # Single (non-sequence) output: the first layout slot governs it.
        if output_layouts[0] is None:
            return py_output
        return DTensor.from_local(
            py_output, output_layouts[0].mesh, output_layouts[0].alias_placements
        )