Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_stack.py: 98% (45 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Stack operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


# pylint: disable=unused-argument
def _normalize_stack_args(tensors, dim=0, **kwargs):
    """
    Normalize arguments for torch.stack.
    """
    return (tensors,), {'dim': dim}
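

# Illustrative note (a and b stand for arbitrary input tensors): both the
# positional and keyword call forms normalize to the same canonical shape,
#   _normalize_stack_args((a, b), 1)      -> ((a, b),), {'dim': 1}
#   _normalize_stack_args((a, b), dim=1)  -> ((a, b),), {'dim': 1}
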

class StackDistributedOp(DistributedOp):
    """Distributed implementation for Stack operator."""

    def preprocess(self, args, kwargs):
        """
        Preprocess input arguments and extract local components.

        Normalizes args, explicitly extracts parameters, and prepares
        local tensors and cache values without validation logic.
        """
        args, kwargs = _normalize_stack_args(*args, **kwargs)

        # Explicit parameter extraction
        tensors = args[0]
        dim = kwargs['dim']

        # Extract local tensors and layouts
        local_tensors = tuple(t.to_local() if hasattr(t, "to_local") else t for t in tensors)
        layouts = [getattr(t, "layout", None) for t in tensors]

        # Construct local args and kwargs for the inner op execution
        local_args = (local_tensors,)
        local_kwargs = {'dim': dim}

        # Flatten layouts and append dim for caching and inference
        cache_values = layouts + [dim]

        return local_args, local_kwargs, cache_values
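
    # Note on the cache format built above: cache_values is the flat list
    # [layout_0, ..., layout_{n-1}, dim], which infer_layout below splits back
    # into the per-input layouts and the stack dimension. Illustratively, two
    # DTensor inputs stacked along dim=1 produce [layout_a, layout_b, 1].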

    def infer_layout(self, cache_values):
        """
        Infer output layout based on cache values for torch.stack.

        All validation logic (e.g., empty checks, layout consistency,
        and dimension bounds) is handled here.
        """
        layouts = cache_values[:-1]
        dim = cache_values[-1]

        # 1. Validation Logic
        if not layouts:
            raise ValueError(f"Operation {self.op_name}: stack requires at least one input tensor.")

        valid_layouts = [lyt for lyt in layouts if lyt is not None]

        if not valid_layouts:
            raise ValueError(f"Operation {self.op_name}: stack requires at least one input DTensor.")

        # Reference layout to validate consistency across all input tensors
        base_layout = valid_layouts[0]
        for layout in valid_layouts[1:]:
            if layout != base_layout:
                raise ValueError(
                    f"Operation {self.op_name}: All input tensors must have the same layout. "
                    f"Expected layout: {base_layout}, Mismatched layout: {layout}"
                )

        ndim = len(base_layout.tensor_map)

        # Normalize and validate the dimension. For stack, valid range is [-ndim - 1, ndim]
        actual_dim = dim if dim >= 0 else dim + ndim + 1
        if actual_dim < 0 or actual_dim > ndim:
            raise ValueError(
                f"Operation {self.op_name}: Dimension out of range (expected to be in range of "
                f"[{-ndim - 1}, {ndim}], but got {dim})"
            )
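
        # Worked example of the normalization above (illustrative): for 3-D inputs
        # (ndim == 3), dim == -1 becomes actual_dim == -1 + 3 + 1 == 3, i.e. the
        # new stack axis is appended after the last existing dimension.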

        # 2. Layout Inference Logic
        in_tensor_map = base_layout.tensor_map

        # Insert an unsharded mapping (-1) at the newly created dimension
        output_tensor_map = list(in_tensor_map)
        output_tensor_map.insert(actual_dim, -1)
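        # Illustratively, an input tensor_map of (1, 0) stacked at actual_dim == 0
        # becomes (-1, 1, 0): the new axis is replicated and the existing sharded
        # dimensions keep their mesh-axis assignments.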

        mesh_shape = base_layout.mesh_shape
        alias_name = base_layout.alias_name
        rank_list = base_layout.rank_list

        def idx_to_alias(idx, aliases):
            """Map a tensor_map index back to its alias string."""
            if idx == -1:
                return "None"
            # Aliases are stored in reverse order relative to tensor_map indices,
            # matching the framework's layout design.
            return aliases[len(aliases) - idx - 1]
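
        # Illustrative mapping, assuming alias_name == ("dp", "mp"): tensor_map
        # index 1 resolves to aliases[0] == "dp", index 0 resolves to
        # aliases[1] == "mp", and -1 (replicated) resolves to the string "None".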

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_tensor_map)

        # Reconstruct the output layout
        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )

        # Apply the placement mappings via the __call__ method
        output_layout = output_layout(*output_alias_map)

        # Returns a tuple of output layouts and None for extra_args
        return ((output_layout,), None)
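

# End-to-end sketch (illustrative; the mesh below is assumed, not taken from this
# module). For a 2 x 4 mesh with alias_name ("dp", "mp") and an input layout whose
# tensor_map is (1, 0), e.g.
#
#     base = Layout(mesh_shape=(2, 4), alias_name=("dp", "mp"), rank_list=list(range(8)))
#     base = base("dp", "mp")   # assumed placement call yielding tensor_map (1, 0)
#
# stacking along dim=0 inserts -1 to give output_tensor_map (-1, 1, 0), which
# idx_to_alias translates to the placement ("None", "dp", "mp") that is finally
# applied via output_layout(*output_alias_map).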