Coverage for hyper_parallel / core / shard / ops / parallel_gather.py: 50%
131 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Gather operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class IndexSelectDistributedOp(DistributedOp):
    """Distributed implementation for Index Select operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Index Select operations.

        Args:
            layouts: Layouts of input tensors (params, axis placeholder, index).
            extra_args: extra_args of input tensors; extra_args[0] is the axis.

        Returns:
            Layout: Layout for the output tensor, built on the index layout's
            mesh/rank information.

        Raises:
            ValueError: If input layouts are not compatible, have partial
                status, the axis is out of range, the params are sharded along
                the gathered axis, or the index is not one-dimensional.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Check argument counts
        if len(layouts) != 3:
            raise ValueError(f"Gather ops requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 1:
            raise ValueError(f"Gather ops requires 1 extra args, but {len(extra_args)}")

        # Parse layout info
        p_layout, i_layout = layouts[0], layouts[2]
        axis, batch_dims = extra_args[0], 0

        p_tensor_map = p_layout.alias_tensor_map
        i_tensor_map = i_layout.alias_tensor_map

        # Validate the axis BEFORE using it as an index so an out-of-range
        # value raises the documented ValueError instead of an IndexError.
        if axis < -len(p_tensor_map) or axis >= len(p_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: dim value is out of valid range"
            )
        # Normalize a negative axis: the slice arithmetic below is wrong for
        # negative values (e.g. axis == -1 makes p_tensor_map[axis + 1:] the
        # whole map instead of the dims after the gathered one).
        if axis < 0:
            axis += len(p_tensor_map)

        # The gathered axis of params must be replicated.
        if p_tensor_map[axis] != "None":
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the axis"
            )

        if len(i_tensor_map) != 1:
            raise ValueError(
                f"Operation {self.op_name}: index is not a one-dimensional Tensor"
            )

        # Output map: params dims before axis + index dims + params dims after axis.
        output_tensor_map = (
            p_tensor_map[:axis] + i_tensor_map[batch_dims:] + p_tensor_map[axis + 1:]
        )
        # Rebuild a fresh Layout from the index layout's device mesh info.
        output_layout = Layout(
            mesh_shape=i_layout.mesh_shape,
            alias_name=i_layout.alias_name,
            rank_list=i_layout.rank_list,
        )
        output_layout = output_layout(*output_tensor_map)
        return output_layout
class GatherDistributedOp(DistributedOp):
    """Distributed implementation for Gather operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Gather operations.

        Args:
            layouts: Layouts of input tensors (params, axis placeholder, index).
            extra_args: extra_args of input tensors; extra_args[0] is the axis.

        Returns:
            Layout: Layout for the output tensor; the output inherits the
            index layout's tensor map and mesh/rank information.

        Raises:
            ValueError: If input layouts are not compatible, have partial
                status, the axis is out of range, the params are sharded along
                the gathered axis, or input/index ranks differ.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Check argument counts
        if len(layouts) != 3:
            raise ValueError(f"Gather ops requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 1:
            raise ValueError(f"Gather ops requires 1 extra args, but {len(extra_args)}")

        # Parse layout info
        p_layout, i_layout = layouts[0], layouts[2]
        axis = extra_args[0]

        p_tensor_map = p_layout.alias_tensor_map
        i_tensor_map = i_layout.alias_tensor_map

        # Validate the axis BEFORE using it as an index so an out-of-range
        # value raises the documented ValueError instead of an IndexError.
        if axis < -len(p_tensor_map) or axis >= len(p_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: dim value is out of valid range"
            )

        # The gathered axis of params must be replicated.
        if p_tensor_map[axis] != "None":
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the axis"
            )

        if len(p_tensor_map) != len(i_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: input and index must have the same number of dimensions"
            )

        # Output inherits the index sharding on a fresh Layout built from the
        # index layout's device mesh info.
        output_tensor_map = i_tensor_map
        output_layout = Layout(
            mesh_shape=i_layout.mesh_shape,
            alias_name=i_layout.alias_name,
            rank_list=i_layout.rank_list,
        )
        output_layout = output_layout(*output_tensor_map)
        return output_layout
class GatherNdDistributedOp(DistributedOp):
    """Distributed implementation for GatherNd operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for GatherNd.

        For GatherNd: out.shape = indices.shape[:-1] + input_x.shape[K:], where K = indices.shape[-1].

        This implementation:
        - Inherits sharding from indices[:-1].
        - Allows sharding on input_x trailing dims input_x[K:].
        - Requires input_x[:K] to be replicated ("None") if input_layout is provided.
        - Requires indices[-1] (K dim) to be replicated ("None").

        Output Layout:
            output_tensor_map = indices_tensor_map[:-1] + input_tensor_map[K:]
            If input_layout is None, input trailing dims are treated as replicated ("None").

        Args:
            layouts: Input layouts; layouts[0] is input_x (may be None, meaning
                fully replicated), layouts[1] is indices (required).
            extra_args: Must end with an (input_shape, indices_shape) pair
                supplied by the WithShape infer_layout suffix.

        Returns:
            Layout: Layout for the output tensor, built on the indices
            layout's mesh/rank information.

        Raises:
            ValueError: If layouts/shapes are missing or invalid, or the
                sharding constraints above are violated.
        """
        input_layout, indices_layout = self._parse_input_layouts(layouts)

        # K (last dim of indices) and the number of trailing input dims that
        # survive into the output are derived from the concrete shapes.
        input_shape, indices_shape = self._get_input_shapes(extra_args)
        k, trail_rank = self._get_k_and_trailing_rank(input_shape, indices_shape)

        input_tensor_map, indices_tensor_map = self._validate_tensor_maps(
            input_layout, indices_layout, k
        )

        # Output sharding: inherit indices[:-1] + input_x[K:].
        if input_tensor_map is None:
            # No input layout: treat all trailing input dims as replicated.
            output_tensor_map = tuple(indices_tensor_map[:-1]) + ("None",) * trail_rank
        else:
            output_tensor_map = tuple(indices_tensor_map[:-1]) + tuple(input_tensor_map[k:])

        # Build the output layout on the indices layout's device mesh.
        output_layout = Layout(
            mesh_shape=indices_layout.mesh_shape,
            alias_name=indices_layout.alias_name,
            rank_list=indices_layout.rank_list,
        )

        if output_tensor_map:
            output_layout = output_layout(*output_tensor_map)
        else:
            # Scalar output (empty map): use a single replicated dim.
            output_layout = output_layout("None")

        return output_layout

    def _parse_input_layouts(self, layouts):
        """Parse and validate input layouts.

        Returns:
            tuple: (input_layout, indices_layout); input_layout may be None.

        Raises:
            ValueError: If fewer than 2 layouts are given, an extra tensor
                layout is present, or the indices layout is missing/invalid.
        """
        if len(layouts) < 2:
            raise ValueError(
                f"Operation {self.op_name} requires at least 2 input layouts, but got {len(layouts)}"
            )

        input_layout, indices_layout = layouts[0], layouts[1]

        # Extra inputs are allowed only when they are non-tensor args (layout is None).
        for extra_layout in layouts[2:]:
            if extra_layout is not None:
                raise ValueError(
                    f"Operation {self.op_name} only supports 2 tensor inputs, but got extra tensor layout: "
                    f"{extra_layout}"
                )

        # For GatherNd: input_layout can be None (treated as fully replicated), but indices_layout must exist.
        if indices_layout is None or not hasattr(indices_layout, "alias_tensor_map"):
            raise ValueError(f"Operation {self.op_name}: Indices layout cannot be None")

        return input_layout, indices_layout

    def _validate_tensor_maps(self, input_layout, indices_layout, k):
        """Validate tensor maps constraints for GatherNd.

        Args:
            input_layout: Layout of input_x, or None for fully replicated.
            indices_layout: Layout of indices (must not be None here).
            k: Size of the last indices dim (number of indexed input dims).

        Returns:
            tuple: (input_tensor_map or None, indices_tensor_map).

        Raises:
            ValueError: If the indices map is empty, the last indices dim is
                split, k exceeds the input rank, or any indexed input dim
                [0:K) is split.
        """
        indices_tensor_map = indices_layout.alias_tensor_map

        # Validate: indices tensor_map must exist and last dimension cannot be split.
        if not indices_tensor_map:
            raise ValueError(f"Operation {self.op_name}: indices tensor_map cannot be empty")

        last_axis = indices_tensor_map[-1]
        if not self._is_none_axis(last_axis):
            raise ValueError(
                f"Operation {self.op_name}: The last dimension of indices cannot be split. "
                f"Got indices[-1] = {last_axis}"
            )

        # Validate input only when layout is provided.
        input_tensor_map = None
        if input_layout is not None:
            input_tensor_map = input_layout.alias_tensor_map

            if k > len(input_tensor_map):
                raise ValueError(
                    f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank "
                    f"({len(input_tensor_map)})"
                )

            # Indexed dims [0:K) must be replicated.
            for axis_name in input_tensor_map[:k]:
                if not self._is_none_axis(axis_name):
                    raise ValueError(
                        f"Operation {self.op_name}: input_x cannot be split on indexed dims [0:{k}). "
                        f"These dims must be 'None', but got tensor_map: {input_tensor_map}"
                    )

        return input_tensor_map, indices_tensor_map

    def _get_input_shapes(self, extra_args):
        """Get input and indices shapes from extra_args (WithShape suffix required).

        Returns:
            tuple: (input_shape, indices_shape) as tuples of int.

        Raises:
            ValueError: If shapes are missing, contain None, or are invalid.
        """
        # The WithShape suffix appends a shapes container as the last extra arg.
        input_shapes = None
        if extra_args and hasattr(extra_args[-1], "__len__") and len(extra_args[-1]) >= 2:
            input_shapes = extra_args[-1]

        if input_shapes is None:
            raise ValueError(
                f"Operation {self.op_name}: missing input_shapes in extra_args. "
                f"Please configure yaml with infer_layout_suffix: WithShape."
            )

        input_shape = input_shapes[0]
        indices_shape = input_shapes[1]
        if input_shape is None or indices_shape is None:
            raise ValueError(f"Operation {self.op_name}: input_shapes contains None: {input_shapes}")

        input_shape = self._normalize_shape(input_shape, "input")
        indices_shape = self._normalize_shape(indices_shape, "indices")

        # indices must have at least one dim (the K dim read in
        # _get_k_and_trailing_rank).
        if len(indices_shape) < 1:
            raise ValueError(f"Operation {self.op_name}: indices shape invalid: {indices_shape}")

        return input_shape, indices_shape

    def _normalize_shape(self, shape, name):
        """Normalize shape-like object to tuple of int.

        Args:
            shape: Any iterable of dim sizes (e.g. list, tuple, tensor shape).
            name: Label ("input"/"indices") used in error messages.

        Raises:
            ValueError: If shape is not iterable or has non-integer dims.
        """
        try:
            norm = tuple(shape)
        except TypeError as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape is not iterable: {shape}") from err

        try:
            norm = tuple(int(dim) for dim in norm)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape contains non-integer dims: {norm}") from err

        return norm

    def _get_k_and_trailing_rank(self, input_shape, indices_shape):
        """Compute K and trailing rank = len(input_shape) - K, where K is indices_shape[-1].

        Raises:
            ValueError: If K is not a positive int or exceeds the input rank.
        """
        k = indices_shape[-1]
        try:
            k = int(k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) is invalid: {k}") from err

        if k <= 0:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) must be positive, but got {k}")

        trail_rank = len(input_shape) - k
        if trail_rank < 0:
            raise ValueError(
                f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank ({len(input_shape)})"
            )

        return k, trail_rank

    def _is_none_axis(self, axis_name):
        """
        Check if an axis name represents no sharding.

        An axis is unsharded when it is the string "None", or a tuple whose
        every element is "None" (interleaved-mesh alias form).
        """
        if axis_name == "None":
            return True

        if isinstance(axis_name, tuple):
            return all(name == "None" for name in axis_name)

        return False