Coverage for hyper_parallel / core / shard / ops / parallel_squeeze.py: 78%
91 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Squeeze operator.
17"""
18from hyper_parallel.core.layout import Layout
19from .parallel_ops import DistributedOp
class SqueezeDistributedOp(DistributedOp):
    """Distributed implementation for Squeeze operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Squeeze.

        Args:
            layouts (tuple): Tuple containing input layout.
            extra_args: Extra arguments containing axis and input_shapes.
                Can be dict or list/tuple where last element is input_shapes.

        Returns:
            Layout: Output layout with squeezed dimensions removed.
        """
        if not layouts:
            raise ValueError(
                f"For {self.op_name}, layouts should contain at least one input layout, "
                f"but got empty layouts."
            )

        input_layout = layouts[0]
        if input_layout.mesh_shape is None:
            raise ValueError(
                f"For {self.op_name}, input layout mesh_shape should not be None, "
                f"but got None."
            )

        axis, input_shape = self._extract_args(extra_args)
        if input_shape is None:
            raise ValueError(
                f"For {self.op_name}, input_shapes should be provided in extra_args, "
                f"but got None."
            )

        return self._compute_squeeze_layout(input_layout, axis, input_shape)

    def _extract_args(self, extra_args):
        """Extract axis and input_shape from extra_args."""
        if isinstance(extra_args, dict):
            axis = extra_args.get("axis", None)
            input_shapes = extra_args.get("input_shapes", None)
        elif isinstance(extra_args, (list, tuple)) and extra_args:
            # By convention the trailing element carries input_shapes.
            input_shapes = extra_args[-1]
            if not isinstance(input_shapes, (list, tuple)):
                raise ValueError(
                    f"For {self.op_name}, input_shapes should be list or tuple, "
                    f"but got {type(input_shapes)}."
                )
            # Axis rides in front only when more than one element is present.
            axis = extra_args[0] if len(extra_args) > 1 else None
        else:
            raise ValueError(
                f"For {self.op_name}, extra_args should be dict or list/tuple, "
                f"but got {type(extra_args)}."
            )

        # A nested first element means input_shapes is a list of shapes;
        # otherwise input_shapes is itself the single shape.
        input_shape = None
        if input_shapes:
            nested = isinstance(input_shapes[0], (list, tuple))
            input_shape = input_shapes[0] if nested else input_shapes

        return axis, input_shape

    def _compute_squeeze_layout(self, x_layout, axis, input_shape):
        """Compute the squeezed layout."""
        # An empty shape denotes a scalar input.
        if not input_shape:
            return self._handle_scalar_case(x_layout, axis)

        self._validate_input_shape(x_layout, input_shape)
        squeeze_dims = self._get_dims_to_squeeze(x_layout, axis, input_shape)
        return self._create_output_layout(x_layout, squeeze_dims)

    def _handle_scalar_case(self, x_layout, axis):
        """Handle scalar input case."""
        # For a scalar, only an unset/empty axis specification is legal.
        if not (axis is None or axis == [] or axis == ()):
            raise ValueError(
                f"For {self.op_name}, axis should be None for scalar input, "
                f"but got {axis}."
            )

        # Build a fresh scalar layout on the same mesh.
        scalar_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        return scalar_layout()

    def _validate_input_shape(self, x_layout, input_shape):
        """Validate that input shape matches layout rank."""
        layout_rank = len(list(x_layout.alias_tensor_map))
        if len(input_shape) != layout_rank:
            raise ValueError(
                f"For {self.op_name}, input shape rank should match layout rank, "
                f"but got {len(input_shape)} and {layout_rank}."
            )

    def _get_dims_to_squeeze(self, x_layout, axis, input_shape):
        """Get list of dimensions to squeeze."""
        tensor_map = list(x_layout.alias_tensor_map)
        if axis is None:
            return self._get_all_squeezable_dims(tensor_map, input_shape)
        return self._get_specified_dims_to_squeeze(
            tensor_map, axis, input_shape, len(tensor_map)
        )

    def _get_all_squeezable_dims(self, x_map, input_shape):
        """Get all squeezable dimensions when axis is None."""
        # Only size-1 dimensions whose mapping is "None" (not sharded) qualify.
        return [
            dim for dim, extent in enumerate(input_shape)
            if extent == 1 and x_map[dim] == "None"
        ]

    def _get_specified_dims_to_squeeze(self, x_map, axis, input_shape, in_rank):
        """Get dimensions to squeeze when axis is specified."""
        # Accept a bare int as a one-element axis list.
        axes = [axis] if isinstance(axis, int) else axis

        # Normalize negative indices against the input rank.
        axes = [a + in_rank if a < 0 else a for a in axes]

        self._validate_axis_range(axes, in_rank)

        for a in axes:
            self._validate_axis_for_squeeze(x_map, input_shape, a)

        # Deduplicate and return in ascending order.
        return sorted(set(axes))

    def _validate_axis_range(self, axis, in_rank):
        """Validate axis values are within range."""
        for ax in axis:
            if not 0 <= ax < in_rank:
                raise ValueError(
                    f"For {self.op_name}, axis should be in range [{-in_rank}, {in_rank-1}], "
                    f"but got {ax}."
                )

    def _validate_axis_for_squeeze(self, x_map, input_shape, ax):
        """Validate a specific axis can be squeezed."""
        # A squeezed dimension must have extent exactly 1.
        if input_shape[ax] != 1:
            raise ValueError(
                f"For {self.op_name}, dimension should have size 1, "
                f"but got shape {input_shape[ax]} at dimension {ax}."
            )

        # A squeezed dimension must not be mapped onto a device axis.
        if x_map[ax] != "None":
            raise ValueError(
                f"For {self.op_name}, dimension should not be distributed, "
                f"but got dimension {ax} mapped to device axis {x_map[ax]}."
            )

    def _create_output_layout(self, x_layout, dims_to_squeeze):
        """Create output layout after squeezing dimensions."""
        drop = set(dims_to_squeeze)

        # Keep the surviving tensor-map entries in their original order.
        kept_map = [
            alias for dim, alias in enumerate(x_layout.alias_tensor_map)
            if dim not in drop
        ]

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )

        # An empty mapping denotes a scalar result.
        output_layout = output_layout(*kept_map) if kept_map else output_layout()

        # Carry partial-reduction state over from the input layout.
        self._copy_partial_operations(x_layout, output_layout, kept_map)

        return output_layout

    def _copy_partial_operations(self, x_layout, output_layout, new_map):
        """Copy partial operations from input to output layout."""
        for i, partial_op in enumerate(x_layout.partial):
            if partial_op is None:
                continue
            dev_axis_name = x_layout.alias_name[i]
            # Only device axes still referenced by the output keep their partial op.
            if dev_axis_name in new_map:
                output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)