Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_chunk_view.py: 90%

41 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for ChunkView operator.
"""

from .parallel_ops import DistributedOp


class ChunkViewDistributedOp(DistributedOp):
    """Distributed implementation for ChunkView operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for ChunkView operator.

        Rules:
            1. Split dimension cannot be sharded.
            2. Default: dim = 0 if not specified.
            3. Output count may be fewer than chunks when the dimension size
               cannot fill all chunks (e.g. when dimension size < chunks);
               see the example below.
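
        For example (assuming torch.chunk-style semantics, which the
        arithmetic in this method follows): a dimension of size 5 split
        into chunks = 4 uses split_size = ceil(5 / 4) = 2 elements per
        chunk, so only ceil(5 / 2) = 3 outputs are produced.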

        Args:
            layouts (list): Layouts of the input tensors; layouts[0] is the
                layout of the tensor to be chunked.
            extra_args (list): [chunks, dim, input_shapes]. Expected:
                extra_args[0]: chunks (required) - number of chunks to split into
                extra_args[1]: dim (optional, default=0) - dimension along which to split
                extra_args[-1]: input_shapes (optional) - shapes of the input
                    tensors; input_shapes[0] is the shape of the input tensor

        Returns:
            tuple: Layouts for output tensors
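
        Raises:
            ValueError: If no valid input layout is given, chunks is missing
                or non-positive, dim is out of range, or the split dimension
                is sharded.
            TypeError: If chunks or dim is not an integer.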

        """

        if not layouts or layouts[0] is None:
            raise ValueError("chunk_view requires a valid input tensor layout.")

        input_layout = layouts[0]

        if len(extra_args) < 1:
            raise ValueError("chunk_view requires 'chunks' in extra_args.")

        chunks = extra_args[0]

        # Convention implemented below: the last element carries input_shapes
        # whenever more than one element is given; dim is only read when all
        # three elements are present.
        input_shapes = extra_args[-1] if len(extra_args) > 1 else None
        dim = extra_args[1] if len(extra_args) > 2 else 0

        # input_shapes may be a list of shapes or a bare shape; normalize to
        # a single shape (or None).
        if input_shapes:
            input_shape = input_shapes[0] if isinstance(input_shapes[0], (list, tuple)) else input_shapes
        else:
            input_shape = None

        if not isinstance(chunks, int):
            raise TypeError(f"chunks must be an integer, got {type(chunks)}")

        if chunks < 1:
            raise ValueError(f"chunks must be greater than 0, got {chunks}")
        if not isinstance(dim, int):
            raise TypeError(f"dim must be an integer, got {type(dim)}")

        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)

        # Normalize a negative dim (e.g. dim=-1 on a 2-D input becomes dim=1).
        if dim < 0:
            dim = input_dim + dim

        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {dim}).")

        # A tensor_map entry of -1 marks a replicated (unsharded) axis; only
        # such an axis may be split.
        if tensor_map[dim] != -1:
            raise ValueError(f"Cannot split tensor at sharded axis[{dim}], layout: {input_layout}")

        if input_shape is not None:
            dim_size = input_shape[dim]
            if dim_size == 0:
                output_num = chunks
            else:
                # torch.chunk-style counting: ceil-divide for the per-chunk
                # size, then count how many chunks the dimension fills,
                # e.g. dim_size=5, chunks=4 -> split_size=2 -> output_num=3.
                split_size = (dim_size + chunks - 1) // chunks
                output_num = max((dim_size + split_size - 1) // split_size, 1)
                output_num = min(output_num, chunks)
        else:
            output_num = chunks

        # The split axis is replicated, so every output keeps the input layout.
        output_layouts = (input_layout,) * output_num
        return output_layouts
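

# A minimal usage sketch (illustrative, not part of the module): _FakeLayout
# is a hypothetical stand-in for the real Layout class, of which infer_layout
# only reads the tensor_map attribute (-1 meaning "not sharded"); whether
# ChunkViewDistributedOp() can be constructed without arguments depends on
# DistributedOp and is assumed here.
#
#     from collections import namedtuple
#
#     _FakeLayout = namedtuple("_FakeLayout", ["tensor_map"])
#     layout = _FakeLayout(tensor_map=(0, -1))  # dim 0 sharded, dim 1 free
#     op = ChunkViewDistributedOp()
#
#     # chunks=4 along dim=1 of an (8, 5) tensor -> only 3 output layouts
#     outs = op.infer_layout([layout], [4, 1, [(8, 5)]])
#     assert len(outs) == 3 and all(o is layout for o in outs)
#
#     # With no shape info, all `chunks` outputs are assumed to exist:
#     layout2 = _FakeLayout(tensor_map=(-1, 0))  # dim 0 free, dim 1 sharded
#     assert len(op.infer_layout([layout2], [3])) == 3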