Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_reshape.py: 82%
166 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Reshape operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from hyper_parallel.platform import get_platform
from .parallel_ops import DistributedOp

platform = get_platform()
Tensor = platform.Tensor


def _filter_none_split_tensor_map(tensor_map, mesh_shape):
    """
    Filter out the entries of tensor_map whose corresponding dimension in mesh_shape has size 1.

    Args:
        tensor_map (list): A list of tensor mappings, which may contain integers or tuples.
        mesh_shape (list): The device mesh shape describing the device distribution along each dimension.

    Returns:
        list: The filtered list of tensor mappings, where invalid mappings are replaced with -1 and valid
            mappings are retained.
    """
    filtered_tensor_map = []
    for item in tensor_map:
        if isinstance(item, tuple):
            filtered = []
            for i in item:
                if mesh_shape[-1 - i] != 1:
                    filtered.append(i)
            if len(filtered) == 0:
                filtered_tensor_map.append(-1)
            elif len(filtered) == 1:
                filtered_tensor_map.append(filtered[0])
            else:
                filtered_tensor_map.append(tuple(filtered))
        else:
            filtered_tensor_map.append(item if mesh_shape[-1 - item] != 1 else -1)
    return filtered_tensor_map
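
# Illustrative example (hypothetical values): with mesh_shape = [4, 1, 2] and
# tensor_map = [2, 1, (1, 0)], mesh dimension 1 has size 1, so every reference to
# device axis 1 is dropped and the filtered tensor map becomes [2, -1, 0].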


class ReshapeDistributedOp(DistributedOp):
    """Distributed implementation for Reshape operator."""

    def __init__(self, op_name):
        super().__init__(op_name)
        self._allow_partial_inputs = True

    def _get_dynamic_shape_info(self, shape):
        """Return (is_dynamic, dynamic_axis, total_size), where total_size is the signed product of all sizes."""
        total_size = 1
        dynamic_axis = -1
        for axis, s in enumerate(shape):
            total_size *= s
            if s < 0:
                dynamic_axis = axis
        return total_size < 0, dynamic_axis, total_size
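
    # For example, _get_dynamic_shape_info([2, -1, 8]) returns (True, 1, -16): the shape is
    # dynamic, axis 1 is the unknown axis, and the signed product of the sizes is -16.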

    def _handle_dynamic_shape(self, input_shape, output_shape):
        """
        Check dynamic shapes. If only one of the input and output shapes is dynamic, resolve its unknown
        axis from the other shape. If both are dynamic, record the relative multiple instead, e.g.
        [2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8]
        """
        input_shape = list(input_shape)
        output_shape = list(output_shape)
        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)
        dynamic_can_shard = False
        if not is_input_dynamic and not is_output_dynamic:
            if input_total_size != output_total_size:
                raise ValueError(f"The total numbers of elements in input shape {input_shape} and output shape "
                                 f"{output_shape} are different.")
            return input_shape, output_shape, dynamic_can_shard

        if not is_input_dynamic:
            accurate_output_shape = output_shape
            accurate_output_shape[output_dynamic_axis] = -input_total_size // output_total_size
            return input_shape, accurate_output_shape, dynamic_can_shard

        if not is_output_dynamic:
            accurate_input_shape = input_shape
            accurate_input_shape[input_dynamic_axis] = -output_total_size // input_total_size
            return accurate_input_shape, output_shape, dynamic_can_shard

        if output_total_size >= input_total_size:
            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
            dynamic_can_shard = True
        else:
            input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
        return input_shape, output_shape, dynamic_can_shard
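
    # For example, with a fully known input shape [2, 4, 8] and a partially known output shape
    # [4, -1, 8], the unknown output axis resolves to 64 // 32 = 2 and the output becomes [4, 2, 8].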

    def _merge_unshared_axis(self, global_shape, tensor_map):
        """
        Merge axes that are not sharded into the nearest higher (sharded) axis; any leading unsharded
        axes are collapsed into a single unsharded axis, e.g.
        shape [4, 2, 6, 8], tensor map [-1, -1, 0, -1] -> merged shape [8, 48]
        """
        merged_size = 1
        merged_shape = []
        merged_tensor_map = []
        for axis in range(len(global_shape) - 1, -1, -1):
            merged_size *= global_shape[axis]
            if tensor_map[axis] != -1:
                merged_shape.insert(0, merged_size)
                merged_tensor_map.insert(0, tensor_map[axis])
                merged_size = 1
        if tensor_map[0] == -1:
            merged_shape.insert(0, merged_size)
            merged_tensor_map.insert(0, -1)
        return merged_shape, merged_tensor_map
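
    # Continuing the docstring example: shape [4, 2, 6, 8] with tensor map [-1, -1, 0, -1] yields
    # merged shape [8, 48] and merged tensor map [-1, 0]: the trailing unsharded axis (size 8) is
    # folded into the sharded axis of size 6, and the leading unsharded axes (4 * 2) collapse into
    # a single unsharded axis of size 8.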

    def _cal_output_layout_and_dst_shape(self, output_tensor_map, dst_shape, x_dict):
        """
        Calculate the output layout tensor map and the local destination shape.
        """
        x_mesh_shape = x_dict["mesh_shape"]
        output_map = []
        local_dst_shape = []
        for idx, map_id in enumerate(output_tensor_map):
            if isinstance(map_id, tuple):
                shard_size = 1
                map_idx = []
                for shard_id in map_id:
                    map_idx.append(x_dict["alias_name"][-1 - shard_id])
                    shard_size *= x_mesh_shape[-1 - shard_id]
                output_map.append(tuple(map_idx))
                local_dst_shape.append(dst_shape[idx] // shard_size if dst_shape[idx] > 0 else -1)
                continue
            if map_id < 0:
                output_map.append("None")
                local_dst_shape.append(dst_shape[idx] if dst_shape[idx] > 0 else -1)
            else:
                output_map.append(x_dict["alias_name"][-1 - map_id])
                local_dst_shape.append(dst_shape[idx] // x_mesh_shape[-1 - map_id] if dst_shape[idx] > 0 else -1)
        return output_map, local_dst_shape
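
    # Illustrative example (hypothetical mesh and alias names): with mesh_shape = [2, 4],
    # alias_name = ("dp", "mp"), output_tensor_map = [1, -1, 0] and dst_shape = [8, 3, 16],
    # the output map becomes ["dp", "None", "mp"] and the local destination shape [4, 3, 4].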

    def _parse_shape_args(self, extra_args):
        """Parse shape arguments from extra_args.

        Args:
            extra_args: Extra arguments containing shape info.

        Returns:
            tuple: (dst_shape, input_shape)
        """
        if self.op_name in ["reshape", "view"]:
            return self._parse_torch_shape_args(extra_args)
        return self._parse_mindspore_shape_args(extra_args)

    def _parse_torch_shape_args(self, extra_args):
        """Parse PyTorch-style shape arguments."""
        if len(extra_args) < 2:
            raise ValueError(f"{self.op_name} requires output shape and input shape.")

        input_shape = extra_args[-1]
        shape_args = extra_args[:-1]

        if len(shape_args) == 1:
            first_arg = shape_args[0]
            if isinstance(first_arg, (list, tuple)):
                dst_shape = first_arg
            elif isinstance(first_arg, Tensor):
                dst_shape = first_arg.tolist()
            else:
                dst_shape = shape_args
        else:
            dst_shape = shape_args

        return dst_shape, input_shape
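
    # PyTorch accepts both x.view(2, 3, 4) and x.view((2, 3, 4)): in the first form extra_args
    # carries the sizes as separate leading arguments, in the second a single tuple, and the last
    # element of extra_args is always the original input shape.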

    def _parse_mindspore_shape_args(self, extra_args):
        """Parse MindSpore-style shape arguments."""
        if len(extra_args) != 2:
            raise ValueError("Reshape requires output shape and input shape.")

        return extra_args[0], extra_args[1]

    def _normalize_shape(self, dst_shape):
        """Normalize dst_shape to a list or tuple."""
        if isinstance(dst_shape, Tensor):
            dst_shape = dst_shape.tolist()
        if not isinstance(dst_shape, (list, tuple)):
            raise ValueError("Shape should be a tensor, a tuple or a list.")
        return dst_shape

    def _compute_output_tensor_map(self, merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard,
                                   input_shape, x_map):
        """Compute the output tensor_map from the merged information.

        Args:
            merged_shape: Merged shape from _merge_unshared_axis.
            merge_tensor_map: Merged tensor_map from _merge_unshared_axis.
            dst_shape: Target shape.
            x_mesh_shape: Mesh shape.
            dynamic_can_shard: Whether a dynamic shape can be sharded.
            input_shape: Original input shape.
            x_map: Input tensor_map.

        Returns:
            list: Output tensor_map.
        """
        output_tensor_map = []
        cur_axis = len(merged_shape) - 1
        cur_size = merged_shape[cur_axis]

        for shape in reversed(dst_shape):
            if cur_size % shape != 0:
                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
            cur_size = cur_size // shape

            if cur_size == 1:
                map_val = self._handle_sharded_axis(
                    merge_tensor_map, cur_axis, x_mesh_shape, shape, dynamic_can_shard, input_shape, x_map, dst_shape
                )
                output_tensor_map.insert(0, map_val)
                cur_axis -= 1
                cur_size = merged_shape[cur_axis]
            else:
                output_tensor_map.insert(0, -1)

        return output_tensor_map
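
    # Walking the running example: merged shape [8, 48] with merged tensor map [-1, 0],
    # mesh shape [4, 2] and destination shape [8, 6, 8] consumes the 48 as 6 * 8 and the 8 as is,
    # producing the output tensor map [-1, 0, -1] (only the size-6 axis stays sharded).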

    def _handle_sharded_axis(self, merge_tensor_map, cur_axis, x_mesh_shape, shape, dynamic_can_shard,
                             input_shape, x_map, dst_shape):
        """Handle sharded axis in tensor_map computation."""
        map_val = merge_tensor_map[cur_axis]

        if isinstance(map_val, tuple):
            shard_size = 1
            for axis in map_val:
                shard_size *= x_mesh_shape[-axis - 1]
        else:
            shard_size = x_mesh_shape[-map_val - 1]

        if shape < 0:
            if not dynamic_can_shard:
                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
        elif shard_size > shape or shape % shard_size != 0:
            raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")

        return map_val
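
    # For instance, if the merged axis is sharded across a mesh dimension of size 4 and the matching
    # destination dimension is 6, then 6 % 4 != 0 and a ValueError is raised, because the reshape
    # would have to split data that already lives on different devices.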

    def _apply_partial_status(self, x_layout, out_layout):
        """Apply partial status from the input layout to the output layout."""
        if x_layout.is_partial():
            input_partial = x_layout.partial
            for i, partial_op in enumerate(input_partial):
                if partial_op is not None and i < len(out_layout.alias_name):
                    out_layout.set_partial_by_dev_axis(out_layout.alias_name[i], partial_op)

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for the reshape operator.

        For reshape operations, the data slice on each device after the reshape should be the same as the
        data slice before the reshape.

        Args:
            layouts (Layout): Layout of input x.
            extra_args:
                For MindSpore Reshape: (destination shape, original shape)
                For PyTorch reshape/view: (shape_arg1, shape_arg2, ..., original shape) or
                (shape_tuple, original shape)

        Returns:
            tuple: Layout for the output tensor and the local destination shape.
        """
        x_layout = layouts[0]
        x_dict = x_layout.to_dict()

        dst_shape, input_shape = self._parse_shape_args(extra_args)
        dst_shape = self._normalize_shape(dst_shape)

        x_map = _filter_none_split_tensor_map(x_dict["tensor_map"], x_dict["mesh_shape"])
        x_mesh_shape = x_dict["mesh_shape"]

        input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape)
        merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map)

        output_tensor_map = self._compute_output_tensor_map(
            merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard, input_shape, x_map
        )

        output_layout = Layout(
            mesh_shape=x_mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map, local_dst_shape = self._cal_output_layout_and_dst_shape(output_tensor_map, dst_shape, x_dict)
        out_layout = output_layout(*output_map)

        self._apply_partial_status(x_layout, out_layout)

        return out_layout, local_dst_shape
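

# Minimal usage sketch (hypothetical values, mirroring the Layout construction performed inside
# infer_layout above; the exact Layout/DistributedOp call signatures may differ):
#   layout = Layout(mesh_shape=[2, 4], alias_name=("dp", "mp"), rank_list=list(range(8)))
#   op = ReshapeDistributedOp("reshape")
#   out_layout, local_dst_shape = op.infer_layout(
#       (layout("dp", "None", "mp"),),            # layout of the input tensor
#       extra_args=((12, 8, 8), (16, 6, 8)))      # (destination shape, original input shape)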