Coverage for hyper_parallel / core / shard / ops / parallel_expand_dims.py: 77%

26 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for ExpandDims operator. 

17""" 

18from hyper_parallel.core.layout import Layout 

19from .parallel_ops import DistributedOp 

20 

21 

class ExpandDimsDistributedOp(DistributedOp):
    """Distributed implementation for ExpandDims operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for ExpandDims.

        The output layout is the input's tensor map with a ``"None"`` entry
        inserted at ``axis``, built on the same mesh shape, alias names and
        rank list as the input. Any non-None partial operation on the input
        is re-applied to the output on the same device axis.

        Args:
            layouts (tuple): Tuple containing input layout; only ``layouts[0]``
                is used.
            extra_args: axis parameter, either a bare int or a list/tuple whose
                first element is the int axis. Negative values are counted from
                the end of the *output* rank, so the valid range is
                ``[-in_rank - 1, in_rank]``.

        Returns:
            Layout: Output layout with inserted dimension.

        Raises:
            ValueError: If no input layout is given, the input layout or its
                mesh shape is None, the axis is missing, or the axis is out of
                range.
        """
        if not layouts:
            raise ValueError(f"{self.__class__.__name__} requires at least one input layout")

        x_layout = layouts[0]

        # Check the layout object itself before reading attributes from it;
        # otherwise a None layout surfaces as an AttributeError.
        if x_layout is None or x_layout.mesh_shape is None:
            raise ValueError("Input layout cannot be None.")

        in_rank = len(x_layout.alias_tensor_map)
        axis = self._normalize_axis(self._extract_axis(extra_args), in_rank)

        x_map = list(x_layout.alias_tensor_map)
        # NOTE(review): "None" is presumably Layout's marker for a dimension
        # that is not mapped to any device axis (replicated) — confirm.
        x_map.insert(axis, "None")

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list,
        )
        output_layout = output_layout(*x_map)

        # ``partial`` is indexed in parallel with ``alias_name``: entry i holds
        # the partial op (or None) for device axis ``alias_name[i]``.
        for i, partial_op in enumerate(x_layout.partial):
            if partial_op is not None:
                dev_axis_name = x_layout.alias_name[i]
                output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)

        return output_layout

    def _extract_axis(self, extra_args):
        """Return the axis carried by ``extra_args``, raising if it is absent."""
        # Compare against None / emptiness instead of truthiness: a bare
        # ``axis=0`` is falsy but perfectly valid.
        if isinstance(extra_args, (tuple, list)):
            axis = extra_args[0] if extra_args else None
        else:
            axis = extra_args
        if axis is None:
            raise ValueError(f"{self.op_name}: axis parameter is required")
        return axis

    def _normalize_axis(self, axis, in_rank):
        """Map ``axis`` into ``[0, in_rank]``, accepting negative indices."""
        # The output has ``in_rank + 1`` dims, so negative axes wrap by that.
        normalized = axis + in_rank + 1 if axis < 0 else axis
        if not 0 <= normalized <= in_rank:
            # Report the caller's original axis, not the wrapped value.
            raise ValueError(
                f"{self.op_name}: axis {axis} out of range for input rank {in_rank}. "
                f"Valid range is [{-in_rank-1}, {in_rank}]"
            )
        return normalized