Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_gather.py: 70%
211 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Gather operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from hyper_parallel.platform import get_platform
21from .parallel_ops import DistributedOp
class IndexSelectDistributedOp(DistributedOp):
    """Distributed implementation for Index Select operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for Index Select operations.

        Args:
            layouts: Layouts of the 3 inputs (input, dim placeholder, index).
            extra_args: Extra (non-tensor) arguments; must contain exactly one
                item, the gather axis ``dim``.

        Returns:
            Layout: Layout for the output tensor. If the gather axis is
            sharded, the layout is marked Partial('sum') on that mesh
            dimension so the dispatcher performs the AllReduce lazily.

        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Normalize a missing extra_args to an empty tuple so the explicit
        # ValueError below fires instead of a TypeError from len(None).
        extra_args = extra_args or ()

        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Check inputs
        if len(layouts) != 3:
            raise ValueError(f"Gather ops requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 1:
            raise ValueError(f"Gather ops requires 1 extra args, but {len(extra_args)}")

        # Parse layout info: layouts[1] belongs to the scalar `dim`, so only
        # the input (0) and index (2) layouts are relevant here.
        p_layout, i_layout = layouts[0], layouts[2]
        axis = extra_args[0]

        p_tensor_map = p_layout.alias_tensor_map
        i_tensor_map = i_layout.alias_tensor_map

        # 1. Validate the axis range before any manipulation
        if axis < -len(p_tensor_map) or axis >= len(p_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: dim value {axis} is out of valid range"
            )

        # 2. Convert negative axis to positive index to avoid Python slicing bugs
        if axis < 0:
            axis += len(p_tensor_map)

        if len(i_tensor_map) != 1:
            raise ValueError(
                f"Operation {self.op_name}: index is not a one-dimensional Tensor"
            )

        # 3. Create output layout map
        # We allow sharding on the `axis`. Since `index_select` replaces the `axis`
        # dimension with the `index` dimension, if `axis` was sharded, that mesh
        # dimension is removed from the output tensor map.
        output_tensor_map = list(p_tensor_map[:axis]) + list(i_tensor_map) + list(p_tensor_map[axis + 1 :])

        output_layout = Layout(
            mesh_shape=p_layout.mesh_shape,
            alias_name=p_layout.alias_name,
            rank_list=p_layout.rank_list,
        )
        output_layout = output_layout(*output_tensor_map)

        # 4. Implicit Communication via Partial Layout
        # If the gather axis was sharded, the local output will only be a masked partial result.
        # We set the output layout to Partial('sum') for that specific mesh dimension so the
        # OpDispatcher handles the AllReduce automatically when this tensor is used later.
        shard_mesh_dim_name = p_tensor_map[axis]
        if shard_mesh_dim_name != "None":
            # Handle possible multi-axis sharding tuple
            if isinstance(shard_mesh_dim_name, tuple):
                for dim_name in shard_mesh_dim_name:
                    if dim_name != "None":
                        output_layout.set_partial_by_dev_axis(dim_name, 'sum')
            else:
                output_layout.set_partial_by_dev_axis(shard_mesh_dim_name, 'sum')

        return output_layout

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Get the expanded execution implementation for Index Select.

        Args:
            func: The native index_select implementation to wrap.
            infer_result: The output layout produced by ``infer_layout``.
            layouts: Layouts of the 3 inputs (input, dim placeholder, index).
            extra_args: Extra arguments containing the gather axis ``dim``.

        Returns:
            Callable: ``func`` unchanged when the gather axis is replicated;
            otherwise a wrapper that performs a local, masked index_select and
            returns a partial result (the AllReduce is left to the dispatcher).
        """
        p_layout = layouts[0]
        axis = extra_args[0]
        if axis < 0:
            axis += len(p_layout.alias_tensor_map)

        shard_mesh_dim_name = p_layout.alias_tensor_map[axis]

        # If the axis is NOT sharded, fallback to standard execution
        if shard_mesh_dim_name == "None":
            return func

        # If the axis IS sharded, return a custom function with Masking ONLY.
        # The explicit AllReduce is completely removed.
        def expand_impl(input_tensor, dim, index, **kwargs):
            platform = get_platform()
            mesh = p_layout.mesh

            # Fetch the communication group for the sharded mesh dimension.
            # For a multi-axis tuple, the first non-"None" axis is used.
            if isinstance(shard_mesh_dim_name, tuple):
                target_dim_name = next(d for d in shard_mesh_dim_name if d != "None")
            else:
                target_dim_name = shard_mesh_dim_name

            comm_group_info = mesh.get_comm_group_by_axis(target_dim_name)
            group = comm_group_info.group if hasattr(comm_group_info, 'group') else comm_group_info

            # Get the rank of the current device within this specific communication group
            group_rank = platform.get_group_local_rank(group=group)

            # Calculate global index boundaries for the local chunk.
            # NOTE(review): assumes the sharded dim is evenly partitioned so
            # every rank holds exactly `local_dim_size` rows — confirm upstream.
            local_dim_size = input_tensor.shape[dim]
            start_idx = group_rank * local_dim_size
            end_idx = start_idx + local_dim_size

            # 1. Compute mask: True for indices that belong to the current rank
            mask = (index >= start_idx) & (index < end_idx)

            # 2. Shift global indices to local indices
            safe_index = index - start_idx

            # Clamp safe_index to valid local ranges to prevent CUDA out-of-bounds
            # errors during the local index_select (invalid ones will be masked out anyway).
            safe_index = safe_index.clamp(min=0, max=local_dim_size - 1)

            # 3. Perform local index_select using tensor's built-in method
            local_out = input_tensor.index_select(dim, safe_index, **kwargs)

            # 4. Mask out the invalid indices (set them to 0)
            # Reshape the 1D mask to broadcast against the output shape
            mask_shape = [1] * local_out.ndim
            mask_shape[dim] = -1
            mask_reshaped = mask.reshape(mask_shape).to(local_out.dtype)

            local_out = local_out * mask_reshaped

            # Return the partial local tensor directly. The framework's layout engine
            # and OpDispatcher will trigger the AllReduce when this Partial tensor
            # is redistributed to a non-partial layout.
            return local_out

        return expand_impl
class GatherDDistributedOp(DistributedOp):
    """Distributed implementation for GatherD operator.

    GatherD gathers values along a specified axis from the input tensor using the index tensor.

    Signature: GatherD(input, dim, index) -> output

    Key constraints:
    - Input and index must have the same number of dimensions
    - Output inherits the sharding pattern of the input tensor
    """

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for GatherD operations.

        Args:
            layouts: Layouts of input tensors [input_layout, dim_layout, index_layout]
            extra_args: Extra arguments containing [dim]

        Returns:
            Layout: Layout for output tensor. Inherits the index layout; when
            the `dim` axis of the input is sharded, it is additionally marked
            Partial('sum') on the corresponding device axis.

        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)
        # Validate input count
        if len(layouts) != 3:
            raise ValueError(
                f"Operation {self.op_name}: requires 3 layouts (input, dim, index), "
                f"but got {len(layouts)}"
            )
        # Validate extra_args (should contain dim)
        if len(extra_args) != 1:
            raise ValueError(
                f"Operation {self.op_name}: requires 1 extra arg (dim), "
                f"but got {len(extra_args)}"
            )
        # Parse layouts: [input, dim (non-tensor), index]
        # Note: dim is a scalar, so layouts[1] should be None
        input_layout = layouts[0]
        index_layout = layouts[2]
        dim = extra_args[0]
        # Validate layouts exist
        if input_layout is None or not hasattr(input_layout, "tensor_map"):
            raise ValueError(f"Operation {self.op_name}: input layout cannot be None")
        if index_layout is None or not hasattr(index_layout, "tensor_map"):
            raise ValueError(f"Operation {self.op_name}: index layout cannot be None")
        input_tensor_map = input_layout.tensor_map
        index_tensor_map = index_layout.tensor_map
        # Validate same rank
        if len(input_tensor_map) != len(index_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: input and index must have the same number of dimensions. "
                f"Got input rank={len(input_tensor_map)}, index rank={len(index_tensor_map)}"
            )
        # Validate dim is in valid range
        rank = len(input_tensor_map)
        if dim < -rank or dim >= rank:
            raise ValueError(
                f"Operation {self.op_name}: dim value {dim} is out of valid range [{-rank}, {rank-1}]"
            )
        # Normalize negative dim
        if dim < 0:
            dim = dim + rank
        # Non-dim axes are gathered element-wise, so input and index must be
        # partitioned identically on every axis except `dim`.
        for axis, (input_axis_map, index_axis_map) in enumerate(zip(input_tensor_map, index_tensor_map)):
            if axis == dim:
                continue
            if input_axis_map != index_axis_map:
                raise ValueError(
                    f"Operation {self.op_name}: input and index must use the same sharding on non-dim axis {axis}. "
                    f"Got input tensor_map={input_tensor_map}, index tensor_map={index_tensor_map}, dim={dim}"
                )
        # Output inherits index layout (GatherD output has the index's shape)
        output_layout = Layout(
            mesh_shape=index_layout.mesh_shape,
            alias_name=index_layout.alias_name,
            rank_list=index_layout.rank_list,
        )
        output_layout.set_tensor_map(index_layout.tensor_map)
        # tensor_map value -1 means replicated; any other value means `dim`
        # is sharded, so each rank's local gather is a masked partial result
        # that must be summed across the sharding group.
        if input_tensor_map[dim] != -1:
            # pylint: disable=protected-access
            # Inherit current partial state from index layout
            output_layout._partial = list(index_layout.partial)
            # Calculate the device axis name for the dim dimension
            # tensor_map uses reverse indexing: tensor_map[i] = len(alias_name) - 1 - device_axis
            device_axis_idx = len(index_layout.alias_name) - 1 - input_tensor_map[dim]
            dim_axis_name = index_layout.alias_name[device_axis_idx]
            output_layout.set_partial_by_dev_axis(dim_axis_name, 'sum')
        # pylint: disable=protected-access
        # Rebuild readable alias tensor map
        output_layout._alias_tensor_map = output_layout._build_readable_tensor_map()
        # pylint: disable=protected-access
        # Sync tensor_map to placement representation
        output_layout.tensor_map_to_placement()
        # Update compact string description
        output_layout.update_compact_str()
        return output_layout

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Returns the execution implementation wrapper for distributed GatherD.

        When the dim axis is sharded, each rank gathers from its local slice of the input tensor.
        The indices need to be adjusted to account for the local partition offset.

        Args:
            func: The original GatherD function to wrap
            infer_result: The inferred output layout
            layouts: Layouts of input tensors [input_layout, dim_layout, index_layout]
            extra_args: Extra arguments containing [dim]

        Returns:
            Callable: Distributed implementation wrapper, or None if no sharding
        """
        input_layout = layouts[0]
        dim = extra_args[0]
        # Get tensor maps
        input_tensor_map = input_layout.tensor_map
        # Check if dim axis is sharded (enhanced MP)
        # tensor_map[dim] == -1 means replicated, otherwise sharded
        if input_tensor_map[dim] == -1:  # native sharding, no need for custom implementation
            return None

        def distributed_gatherd_impl(*args, **kwargs):
            """
            Distributed GatherD implementation for sharded dim axis.

            Each rank gathers from its local slice of input tensor.
            Indices are adjusted by subtracting the local partition offset.
            """
            # Positional convention mirrors GatherD(input, dim, index).
            input_tensor = args[0]
            index_tensor = args[2]
            # Calculate local partition offset for the dim axis
            mesh = input_layout.mesh
            # Convert tensor_map index to mesh axis index (reverse order)
            mesh_dim_idx = len(mesh.mesh_shape) - 1 - input_tensor_map[dim]
            # Get the coordinate of current rank along the mesh dimension
            dim_coord = mesh.get_local_rank(mesh_dim_idx)
            # Calculate the size of input tensor's dim dimension per partition.
            # NOTE(review): assumes an even split, so every rank holds the
            # same local dim size — confirm against the sharding pipeline.
            input_dim_size = input_tensor.shape[dim]
            # Calculate the starting index of local partition
            local_start_index = int(dim_coord * input_dim_size)
            local_end_index = int(local_start_index + input_dim_size)
            # Adjust indices: subtract local_start_index to map global indices to local range
            # This is similar to how Embedding shifts indices for Row Parallelism
            adjusted_index = index_tensor - local_start_index
            # Create mask to identify out-of-bounds indices
            # Indices outside [0, local_dim_size) belong to other partitions
            mask = (index_tensor >= local_start_index) & (index_tensor < local_end_index)
            # Cross-platform cast to matching int dtype
            mask_int = mask.to(index_tensor.dtype)
            # Zero out invalid indices to prevent out-of-bounds access
            # (index 0 is always valid locally; its bogus result is masked below)
            safe_index = adjusted_index * mask_int
            # Replace original index tensor with adjusted index
            new_args = list(args)
            new_args[2] = safe_index
            # Execute native GatherD with adjusted indices
            output = func(*new_args, **kwargs)
            # Zero out outputs corresponding to invalid indices
            mask_int = mask_int.to(output.dtype)
            output = output * mask_int
            return output
        return distributed_gatherd_impl
class GatherNdDistributedOp(DistributedOp):
    """Distributed implementation for GatherNd operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for GatherNd.

        GatherNd produces out.shape = indices.shape[:-1] + input_x.shape[K:]
        with K = indices.shape[-1], so the output sharding is assembled from
        the two input layouts:

        - indices[:-1] sharding is inherited for the batch dims.
        - input_x[K:] sharding is inherited for the trailing dims.
        - input_x[:K] must be replicated ("None") when an input layout exists.
        - indices[-1] (the K dim) must be replicated ("None").

        When input_layout is None, the trailing dims are treated as replicated.
        """
        x_layout, idx_layout = self._parse_input_layouts(layouts)
        x_shape, idx_shape = self._get_input_shapes(extra_args)
        k_dim, tail = self._get_k_and_trailing_rank(x_shape, idx_shape)
        x_map, idx_map = self._validate_tensor_maps(x_layout, idx_layout, k_dim)

        # Batch dims come from indices; trailing dims come from the input
        # (or are replicated when no input layout was provided).
        tail_part = ("None",) * tail if x_map is None else tuple(x_map[k_dim:])
        out_map = tuple(idx_map[:-1]) + tail_part

        result = Layout(
            mesh_shape=idx_layout.mesh_shape,
            alias_name=idx_layout.alias_name,
            rank_list=idx_layout.rank_list,
        )
        # A rank-0 output (1-D indices with K == input rank) still needs a
        # placeholder axis to instantiate the Layout.
        return result(*out_map) if out_map else result("None")

    def _parse_input_layouts(self, layouts):
        """Split `layouts` into (input_layout, indices_layout), validating arity."""
        if len(layouts) < 2:
            raise ValueError(
                f"Operation {self.op_name} requires at least 2 input layouts, but got {len(layouts)}"
            )

        # Trailing entries are tolerated only when they carry no layout,
        # i.e. they correspond to non-tensor arguments.
        for extra_layout in layouts[2:]:
            if extra_layout is None:
                continue
            raise ValueError(
                f"Operation {self.op_name} only supports 2 tensor inputs, but got extra tensor layout: "
                f"{extra_layout}"
            )

        first, second = layouts[0], layouts[1]
        # The input layout may be absent (treated as fully replicated); the
        # indices layout is mandatory.
        if second is None or not hasattr(second, "alias_tensor_map"):
            raise ValueError(f"Operation {self.op_name}: Indices layout cannot be None")

        return first, second

    def _validate_tensor_maps(self, input_layout, indices_layout, k):
        """Check GatherNd sharding constraints; return (input_map or None, indices_map)."""
        indices_tensor_map = indices_layout.alias_tensor_map

        if not indices_tensor_map:
            raise ValueError(f"Operation {self.op_name}: indices tensor_map cannot be empty")

        # The last indices dim enumerates coordinates of a single lookup;
        # splitting it would tear individual index tuples across devices.
        last_axis = indices_tensor_map[-1]
        if not self._is_none_axis(last_axis):
            raise ValueError(
                f"Operation {self.op_name}: The last dimension of indices cannot be split. "
                f"Got indices[-1] = {last_axis}"
            )

        # Without an input layout there is nothing further to validate.
        if input_layout is None:
            return None, indices_tensor_map

        input_tensor_map = input_layout.alias_tensor_map
        if k > len(input_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank "
                f"({len(input_tensor_map)})"
            )

        # The first K input dims are addressed by index tuples and must stay whole.
        if not all(self._is_none_axis(axis_name) for axis_name in input_tensor_map[:k]):
            raise ValueError(
                f"Operation {self.op_name}: input_x cannot be split on indexed dims [0:{k}). "
                f"These dims must be 'None', but got tensor_map: {input_tensor_map}"
            )

        return input_tensor_map, indices_tensor_map

    def _get_input_shapes(self, extra_args):
        """Extract (input_shape, indices_shape) from extra_args (WithShape suffix required)."""
        candidate = extra_args[-1] if extra_args else None
        input_shapes = candidate if hasattr(candidate, "__len__") and len(candidate) >= 2 else None

        if input_shapes is None:
            raise ValueError(
                f"Operation {self.op_name}: missing input_shapes in extra_args. "
                f"Please configure yaml with infer_layout_suffix: WithShape."
            )

        input_shape, indices_shape = input_shapes[0], input_shapes[1]
        if input_shape is None or indices_shape is None:
            raise ValueError(f"Operation {self.op_name}: input_shapes contains None: {input_shapes}")

        input_shape = self._normalize_shape(input_shape, "input")
        indices_shape = self._normalize_shape(indices_shape, "indices")

        # indices must have at least one dim (the K dim).
        if not indices_shape:
            raise ValueError(f"Operation {self.op_name}: indices shape invalid: {indices_shape}")

        return input_shape, indices_shape

    def _normalize_shape(self, shape, name):
        """Coerce a shape-like object into a tuple of ints, with clear errors."""
        try:
            norm = tuple(shape)
        except TypeError as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape is not iterable: {shape}") from err

        try:
            return tuple(int(dim) for dim in norm)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape contains non-integer dims: {norm}") from err

    def _get_k_and_trailing_rank(self, input_shape, indices_shape):
        """Compute (K, trailing rank) where K = indices_shape[-1] and trailing = input rank - K."""
        k = indices_shape[-1]
        try:
            k = int(k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) is invalid: {k}") from err

        if k <= 0:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) must be positive, but got {k}")

        input_rank = len(input_shape)
        trail_rank = input_rank - k
        if trail_rank < 0:
            raise ValueError(
                f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank ({len(input_shape)})"
            )

        return k, trail_rank

    def _is_none_axis(self, axis_name):
        """Return True when an axis carries no sharding ("None", or a tuple of all "None")."""
        if isinstance(axis_name, tuple):
            return all(name == "None" for name in axis_name)
        return axis_name == "None"