Coverage for hyper_parallel / core / shard / ops / parallel_split.py: 52%
81 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementations for Split operators.
17"""
19import math
20from .parallel_ops import DistributedOp
class SplitWithSizeDistributedOp(DistributedOp):
    """Distributed implementation for SplitWithSize operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitWithSize operator.

        Rules:
            1. A sharded axis can not be split.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape

        Returns:
            tuple: Layouts for output tensors

        Raises:
            ValueError: If the requested split axis is sharded.
        """
        input_layout, split_axis = layouts[0], extra_args[1]
        # A tensor_map entry other than -1 means that dimension is mapped to a
        # device axis (i.e. sharded), so it can not serve as the split axis.
        if input_layout.tensor_map[split_axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{split_axis}], layout: {input_layout}")
        # Each output piece keeps the input layout: one piece per section.
        return (input_layout,) * len(extra_args[0])
class SplitWithSizeViewDistributedOp(DistributedOp):
    """Distributed implementation for SplitWithSizeView operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitWithSizeView operator.

        Rules:
            1. A sharded axis can not be split.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape

        Returns:
            tuple: Layouts for output tensors

        Raises:
            ValueError: If the requested split axis is sharded.
        """
        input_layout = layouts[0]
        split_axis = extra_args[1]
        # Sharded dimensions carry a device-axis index (!= -1) in the tensor
        # map; splitting along one of those is not supported.
        if input_layout.tensor_map[split_axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{split_axis}], layout: {input_layout}")
        sections = extra_args[0]
        # One output per section, each inheriting the input layout unchanged.
        return tuple(input_layout for _ in sections)
class SplitDistributedOp(DistributedOp):
    """Distributed implementation for Split operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Split operator.

        Rules:
            1. Shared axis can not be split.
            2. Default: dim = 0 if not specified.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape. Expected:
                extra_args[0]: split_size (required)
                extra_args[1]: axis (optional)
                extra_args[-1][0]: input_shape

        Returns:
            tuple: Layouts for output tensors

        Raises:
            ValueError: If the input layout is missing, extra_args has an
                unexpected length, the axis is out of range, or the split
                axis is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split requires a valid input tensor layout.")
        input_layout = layouts[0]

        if len(extra_args) == 2:
            split_size = extra_args[0]
            axis = 0  # default axis when the caller omits it
            input_shape = extra_args[1][0]
        elif len(extra_args) == 3:
            split_size = extra_args[0]
            axis = extra_args[1]
            input_shape = extra_args[2][0]
        else:
            raise ValueError("Split ops extra_args requires 'axis' and contains 'output_num' optionally.")

        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)
        # Normalize a negative axis, then range-check it.
        if axis < 0:
            axis = input_dim + axis
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # Check shared axis can not be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        output_num = 1
        if isinstance(split_size, int):
            # Exact integer ceiling division. math.ceil(a / b) routes through
            # float division and can round incorrectly once the dimension
            # exceeds 2**53; -(-a // b) stays in integer arithmetic.
            output_num = -(-input_shape[axis] // split_size)
        elif isinstance(split_size, (list, tuple)):
            output_num = len(split_size)

        # Every output chunk keeps the input layout unchanged.
        return (input_layout,) * output_num
class SplitTensorDistributedOp(DistributedOp):
    """Distributed implementation for SplitTensor operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitTensor operator.

        Rules:
            1. A sharded axis can not be split.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape

        Returns:
            tuple: Layouts for output tensors

        Raises:
            ValueError: If the requested split axis is sharded.
        """
        input_layout = layouts[0]
        split_axis = extra_args[1]
        # A tensor_map entry != -1 marks the dimension as sharded across a
        # device axis, which rules it out as a split axis.
        if input_layout.tensor_map[split_axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{split_axis}], layout: {input_layout}")

        chunk_size = extra_args[0]
        dim_size = extra_args[2][0][split_axis]
        # Ceiling division via divmod: a trailing partial chunk adds one more
        # output piece.
        full_chunks, remainder = divmod(dim_size, chunk_size)
        output_num = full_chunks + 1 if remainder else full_chunks
        return (input_layout,) * output_num
class SplitTensorViewDistributedOp(DistributedOp):
    """Distributed implementation for SplitTensorView operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitTensorView operator.

        Rules:
            1. A sharded axis can not be split.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape

        Returns:
            tuple: Layouts for output tensors

        Raises:
            ValueError: If the requested split axis is sharded.
        """
        input_layout = layouts[0]
        split_axis = extra_args[1]
        # Only a replicated dimension (tensor_map entry of -1) may be split;
        # anything else is sharded across devices.
        if input_layout.tensor_map[split_axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{split_axis}], layout: {input_layout}")

        chunk_size = extra_args[0]
        axis_length = extra_args[2][0][split_axis]
        # Count full chunks, then add one output for any leftover remainder.
        quotient, leftover = divmod(axis_length, chunk_size)
        if leftover:
            quotient += 1
        return tuple([input_layout] * quotient)