Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_chunk_view.py: 90%
41 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for ChunkView operator.
17"""
19from .parallel_ops import DistributedOp
class ChunkViewDistributedOp(DistributedOp):
    """Distributed implementation for ChunkView operator."""

    @staticmethod
    def _split_chunk_view_args(extra_args):
        """
        Unpack ``extra_args`` into ``(chunks, dim, input_shape)``.

        Accepted forms (``input_shapes``, when present, is always the last element):
            [chunks]
            [chunks, dim]
            [chunks, input_shapes]
            [chunks, dim, ..., input_shapes]

        A two-element list is disambiguated by its second element: a plain integer
        is taken as ``dim``; anything else is taken as ``input_shapes``.

        ``input_shapes`` may be either a list of shapes (the first one is used) or a
        single shape. ``input_shape`` is None when no usable shape was supplied.

        Raises:
            ValueError: If ``extra_args`` is empty or None.
        """
        if not extra_args:
            raise ValueError("chunk_view requires 'chunks' in extra_args.")

        chunks = extra_args[0]
        dim = 0
        input_shapes = None
        if len(extra_args) == 2:
            if isinstance(extra_args[1], int):
                # [chunks, dim] without shape info. Treating this int as input_shapes
                # would subscript an int and raise TypeError.
                dim = extra_args[1]
            else:
                input_shapes = extra_args[1]
        elif len(extra_args) > 2:
            dim = extra_args[1]
            input_shapes = extra_args[-1]

        input_shape = None
        if input_shapes:
            first = input_shapes[0]
            # Accept both [[d0, d1, ...], ...] and a bare [d0, d1, ...].
            input_shape = first if isinstance(first, (list, tuple)) else input_shapes
        return chunks, dim, input_shape

    @staticmethod
    def _chunk_output_count(chunks, dim_size):
        """
        Return how many tensors chunking a dimension of ``dim_size`` into ``chunks`` yields.

        Each piece has ``ceil(dim_size / chunks)`` elements (the last may be smaller),
        so the count is ``ceil(dim_size / split_size)``. This can be below ``chunks``
        (e.g. size 5, chunks 4 -> split 2 -> 3 outputs). A zero-sized dimension
        yields ``chunks`` outputs.
        """
        if dim_size == 0:
            return chunks
        split_size = (dim_size + chunks - 1) // chunks  # ceil(dim_size / chunks)
        output_num = max((dim_size + split_size - 1) // split_size, 1)
        return min(output_num, chunks)

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for ChunkView operator.

        Rules:
            1. Split dimension cannot be sharded (its tensor_map entry must be -1).
            2. Default: dim = 0 if not specified.
            3. Output count may be less than chunks: it equals
               ceil(size / ceil(size / chunks)) for a non-empty split dimension
               (e.g. size 5, chunks 4 -> 3 outputs), and chunks when size is 0
               or the input shape is unknown.

        Args:
            layouts (Layout): Layout of input tensor; layouts[0] is used.
            extra_args (list): chunks, dim, input_shapes. Expected:
                extra_args[0]: chunks (required) - number of chunks to split into
                extra_args[1]: dim (optional, default=0) - dimension along which to split
                extra_args[-1]: input_shapes (optional) - shape(s) of the input tensor;
                    a two-element list whose second item is an int is read as
                    [chunks, dim] with no shapes.

        Returns:
            tuple: The input layout repeated once per output tensor.

        Raises:
            ValueError: Missing layout/chunks, chunks < 1, dim out of range, or
                splitting along a sharded dimension.
            TypeError: chunks or dim is not an integer.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("chunk_view requires a valid input tensor layout.")
        input_layout = layouts[0]

        chunks, dim, input_shape = self._split_chunk_view_args(extra_args)

        if not isinstance(chunks, int):
            raise TypeError(f"chunks must be an integer, got {type(chunks)}")
        if chunks < 1:
            raise ValueError(f"chunks must be greater than 0, got {chunks}")
        if not isinstance(dim, int):
            raise TypeError(f"dim must be an integer, got {type(dim)}")

        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)
        if dim < 0:
            dim = input_dim + dim
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {dim}).")

        # Each output reuses the input layout unchanged, which is only valid when
        # the split dimension is replicated (tensor_map entry -1).
        if tensor_map[dim] != -1:
            raise ValueError(f"Cannot split tensor at sharded axis[{dim}], layout: {input_layout}")

        if input_shape is None:
            output_num = chunks
        else:
            output_num = self._chunk_output_count(chunks, input_shape[dim])

        return (input_layout,) * output_num