Coverage for hyper_parallel / core / shard / ops / parallel_unbind.py: 96%
23 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Unbind operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class UnbindDistributedOp(DistributedOp):
    """Distributed implementation for Unbind operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Derive the output layouts produced by unbinding a distributed tensor.

        Args:
            layouts (tuple): Layouts of the input tensor; only the first entry is used.
            extra_args (list): Scalar arguments followed by the input shapes.
                When configured with the 'WithShape' suffix, the last element is
                input_shapes (list of tuples); any preceding elements are scalar
                arguments such as ``dim``.

        Returns:
            tuple: One Layout per output tensor (one per slice along ``dim``).

        Raises:
            ValueError: If ``dim`` is out of range, or if the unbound dimension
                is sharded across a mesh axis.
        """
        # _with_layout_infer_with_shape appends input_shapes as the final element.
        *scalar_args, input_shapes = extra_args

        # Mirrors the PyTorch signature: unbind(input, dim=0).
        dim = scalar_args[0] if scalar_args else 0

        in_layout = layouts[0]
        in_shape = input_shapes[0]
        dim_map = in_layout.tensor_map
        alias_map = in_layout.alias_tensor_map
        rank = len(in_shape)

        # Normalize a negative dimension index to its positive equivalent.
        if dim < 0:
            dim += rank

        if dim < 0 or dim >= rank:
            raise ValueError(f"Dimension out of range (expected to be in range of [0, {rank-1}], but got {dim})")

        # A tensor_map entry other than -1 means that axis is split across a
        # mesh dimension. Unbinding such an axis would yield different tensor
        # lists on different ranks, so it requires an explicit redistribution
        # (Gather) first — reject it here.
        if dim_map[dim] != -1:
            raise ValueError(
                f"For 'unbind', the dimension {dim} is sharded (mapped to mesh axis {dim_map[dim]}). "
                f"Unbinding a sharded dimension is not supported. "
                f"Please redistribute the tensor to replicate this dimension first."
            )

        # Build the output layout by dropping the alias entry of the unbound
        # axis; going through Layout's own initialization keeps validation
        # centralized there.
        surviving_aliases = alias_map[:dim] + alias_map[dim + 1:]
        template = Layout(
            mesh_shape=in_layout.mesh_shape,
            alias_name=in_layout.alias_name,
            rank_list=in_layout.rank_list
        )
        out_layout = template(*surviving_aliases)

        # unbind returns as many tensors as the size of the unbound dimension.
        return tuple(out_layout for _ in range(in_shape[dim]))