Coverage for hyper_parallel / core / shard / ops / parallel_reshape.py: 71%
141 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Reshape operator.
17"""
19from hyper_parallel.core.layout import Layout
20from hyper_parallel.platform import get_platform
21from .parallel_ops import DistributedOp
# Resolve the active backend once at import time; ``Tensor`` is the backend's
# tensor class, used below for isinstance checks on shape arguments
# (per infer_layout, backends include MindSpore and PyTorch).
platform = get_platform()
Tensor = platform.Tensor
26def _filter_none_split_tensor_map(tensor_map, mesh_shape):
27 """
28 Filter out the elements in tensor_map where the size of the corresponding dimension in device_matrix is 1.
30 Args:
31 tensor_map (list): A list of tensor mappings, which may contain integers or tuples.
32 device_matrix (list): A device matrix representing the device distribution across each dimension.
34 Returns:
35 list: The filtered list of tensor mappings, where invalid mappings are replaced with -1 or valid mappings are
36 retained.
37 """
38 filtered_tensor_map = []
39 for item in tensor_map:
40 if isinstance(item, tuple):
41 filtered = []
42 for i in item:
43 if mesh_shape[-1 - i] != 1:
44 filtered.append(i)
45 if len(filtered) == 0:
46 filtered_tensor_map.append(-1)
47 elif len(filtered) == 1:
48 filtered_tensor_map.append(filtered[0])
49 else:
50 filtered_tensor_map.append(tuple(filtered))
51 else:
52 filtered_tensor_map.append(item if mesh_shape[-1 - item] != 1 else -1)
53 return filtered_tensor_map
class ReshapeDistributedOp(DistributedOp):
    """Distributed implementation for Reshape operator.

    Infers the output sharding layout of a reshape so that the data slice held
    by each device after the reshape is identical to its slice before it.
    """

    def _get_dynamic_shape_info(self, shape):
        """
        Scan ``shape`` for a dynamic (negative) dimension.

        Args:
            shape (list): Shape whose entries may include a negative value
                marking a dynamic axis.

        Returns:
            tuple: (is_dynamic, dynamic_axis, total_size). ``is_dynamic`` is
            True when the signed product of all dims is negative (i.e. an odd
            number of negative entries — normally exactly one),
            ``dynamic_axis`` is the index of the last negative dim (-1 when
            none), and ``total_size`` is the signed product of all dims.
        """
        total_size = 1
        dynamic_axis = -1
        for axis, s in enumerate(shape):
            total_size *= s
            if s < 0:
                dynamic_axis = axis
        return total_size < 0, dynamic_axis, total_size

    def _handle_dynamic_shape(self, input_shape, output_shape):
        """
        Check dynamic shape. Calculate unknown axis if one of input and output shape is known. If both are unknown,
        calculate the relative multiple.
        [2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8]

        Returns:
            tuple: (input_shape, output_shape, dynamic_can_shard). Shapes are
            returned as lists; a remaining negative entry encodes either the
            unknown axis (-1) or the relative multiple between the two dynamic
            axes. ``dynamic_can_shard`` is True only when both shapes are
            dynamic and the output side carries the multiple.
        """
        input_shape = list(input_shape)
        output_shape = list(output_shape)
        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)
        dynamic_can_shard = False
        if not is_input_dynamic and not is_output_dynamic:
            # Fully static: only a sanity check on the element count is needed.
            if input_total_size != output_total_size:
                raise ValueError(f"The total elements number of input shape {input_shape} and output shape "
                                 f"{output_shape} are different.")
            return input_shape, output_shape, dynamic_can_shard

        if not is_input_dynamic:
            # Static input, dynamic output: output_total_size is negative, so
            # (-input_total_size) // output_total_size is the positive size of
            # the dynamic axis.
            # NOTE(review): there is no divisibility check here; a
            # non-divisible pair silently floors — confirm callers guarantee
            # divisibility.
            accurate_output_shape = output_shape
            accurate_output_shape[output_dynamic_axis] = -input_total_size // output_total_size
            return input_shape, accurate_output_shape, dynamic_can_shard

        if not is_output_dynamic:
            # Dynamic input, static output: mirror of the branch above.
            accurate_input_shape = input_shape
            accurate_input_shape[input_dynamic_axis] = -output_total_size // input_total_size
            return accurate_input_shape, output_shape, dynamic_can_shard

        # Both dynamic: the side whose known dims have the smaller product
        # stores the (negative) relative multiple; the other side keeps -1,
        # e.g. [2, -1, 8] / [4, -1, 8] -> [2, -2, 8] / [4, -1, 8].
        if output_total_size >= input_total_size:
            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
            dynamic_can_shard = True
        else:
            # NOTE(review): dynamic_can_shard stays False on this branch —
            # confirm sharding the dynamic axis is intentionally disallowed
            # when the input side carries the multiple.
            input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
        return input_shape, output_shape, dynamic_can_shard

    def _merge_unshared_axis(self, global_shape, tensor_map):
        """
        Merge those axes that are not sharded to the high dimension which is shared.
        shape[4, 2, 6, 8], tensor map[-1, -1, 0, -1] -> merged shape[8, 48]

        Walks the shape right-to-left, multiplying sizes until a sharded axis
        is reached; that sharded axis absorbs the accumulated unsharded sizes
        to its right. Leading unsharded axes (if any) form one extra merged
        axis mapped to -1. A dynamic size (negative) propagates its sign into
        the merged size.
        """
        merged_size = 1
        merged_shape = []
        merged_tensor_map = []
        for axis in range(len(global_shape) - 1, -1, -1):
            merged_size *= global_shape[axis]
            if tensor_map[axis] != -1:
                merged_shape.insert(0, merged_size)
                merged_tensor_map.insert(0, tensor_map[axis])
                merged_size = 1
        if tensor_map[0] == -1:
            merged_shape.insert(0, merged_size)
            merged_tensor_map.insert(0, -1)
        return merged_shape, merged_tensor_map

    def _cal_output_layout_and_dst_shape(self, output_tensor_map, dst_shape, x_dict):
        """
        calculate output layout tensor map and local dst shape.

        Args:
            output_tensor_map (list): Per-output-axis mesh-dim indices (int,
                tuple of ints, or -1 for an unsharded axis).
            dst_shape (list): Global destination shape; negative entries mark
                dynamic axes.
            x_dict (dict): Input layout description providing "mesh_shape" and
                "alias_name".

        Returns:
            tuple: (output_map, local_dst_shape). ``output_map`` holds mesh-dim
            alias names ("None" marks an unsharded axis) to be passed to the
            Layout call; ``local_dst_shape`` is the per-device shape, with
            dynamic axes kept as -1.
        """
        x_mesh_shape = x_dict["mesh_shape"]
        output_map = []
        local_dst_shape = []
        for idx, map_id in enumerate(output_tensor_map):
            if isinstance(map_id, tuple):
                # Axis sharded over several mesh dims: divide by their product.
                shard_size = 1
                map_idx = []
                for shard_id in map_id:
                    map_idx.append(x_dict["alias_name"][-1 - shard_id])
                    shard_size *= x_mesh_shape[-1 - shard_id]
                output_map.append(tuple(map_idx))
                local_dst_shape.append(dst_shape[idx] // shard_size if dst_shape[idx] > 0 else -1)
                continue
            if map_id < 0:
                # "None" is the alias used for an unsharded axis.
                output_map.append("None")
                local_dst_shape.append(dst_shape[idx] if dst_shape[idx] > 0 else -1)
            else:
                output_map.append(x_dict["alias_name"][-1 - map_id])
                local_dst_shape.append(dst_shape[idx] // x_mesh_shape[-1 - map_id] if dst_shape[idx] > 0 else -1)
        return output_map, local_dst_shape

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for reshape operator.

        For reshape operations, data slice on each device after reshape should be same as data slice before reshape.

        Args:
            layouts (Layout): Layout of input x
            extra_args:
                For MindSpore Reshape: (destination shape, original shape)
                For PyTorch reshape/view: (shape_arg1, shape_arg2, ..., original shape) or (shape_tuple, original shape)

        Returns:
            tuple: Layout for output tensor
        """
        x_layout = layouts[0]
        x_dict = x_layout.to_dict()

        dst_shape = None
        input_shape = None

        if self.op_name in ["reshape", "view"]:
            # PyTorch style: extra_args contains shape args + input_shape (appended by system)
            if len(extra_args) < 2:
                raise ValueError(f"{self.op_name} requires output shape and input shape.")

            input_shape = extra_args[-1]
            shape_args = extra_args[:-1]

            # Handle variable arguments vs tuple/list argument
            if len(shape_args) == 1:
                if isinstance(shape_args[0], (list, tuple)):
                    dst_shape = shape_args[0]
                elif isinstance(shape_args[0], Tensor):
                    dst_shape = shape_args[0].tolist()
                else:
                    # Single int arg (e.g. flatten to 1D)
                    dst_shape = shape_args
            else:
                dst_shape = shape_args
        else:
            # MindSpore Reshape style
            if len(extra_args) != 2:
                raise ValueError("Reshape requires output shape and input shape.")

            dst_shape = extra_args[0]
            input_shape = extra_args[1]

        # Common processing
        if isinstance(dst_shape, Tensor):
            dst_shape = dst_shape.tolist()
        if not isinstance(dst_shape, list) and not isinstance(dst_shape, tuple):
            raise ValueError("Shape should be a tensor or a tuple or a list.")

        # Mesh dims of size 1 shard nothing; normalize them away first.
        x_map = _filter_none_split_tensor_map(x_dict["tensor_map"], x_dict["mesh_shape"])
        x_mesh_shape = x_dict["mesh_shape"]

        input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape)
        merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map)

        # Walk the destination dims right-to-left, dividing each out of the
        # current merged input axis. When a merged axis is fully consumed
        # (cur_size == 1), the destination dim at that position inherits the
        # merged axis's sharding, provided the shard size divides it.
        output_tensor_map = []
        cur_axis = len(merged_shape) - 1
        cur_size = merged_shape[cur_axis]
        for shape in reversed(dst_shape):
            if cur_size % shape != 0:
                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
            cur_size = cur_size // shape
            if cur_size == 1:
                # Number of devices the consumed merged axis is split across.
                # An unsharded merged axis (map -1) imposes no sharding
                # constraint, so its shard size is 1 — indexing x_mesh_shape
                # with -1 here would read x_mesh_shape[0] and could wrongly
                # reject valid reshapes of unsharded axes.
                if isinstance(merge_tensor_map[cur_axis], tuple):
                    shard_size = 1
                    for axis in merge_tensor_map[cur_axis]:
                        shard_size *= x_mesh_shape[-axis - 1]
                elif merge_tensor_map[cur_axis] == -1:
                    shard_size = 1
                else:
                    shard_size = x_mesh_shape[-merge_tensor_map[cur_axis] - 1]
                if shape < 0:
                    # Dynamic destination dim: shardable only when
                    # _handle_dynamic_shape allowed it.
                    if not dynamic_can_shard:
                        raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
                elif shard_size > shape or shape % shard_size != 0:
                    raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
                output_tensor_map.insert(0, merge_tensor_map[cur_axis])
                cur_axis -= 1
                # NOTE: on the final iteration cur_axis may reach -1 and this
                # reads merged_shape[-1]; the value is never used afterwards.
                cur_size = merged_shape[cur_axis]
            else:
                output_tensor_map.insert(0, -1)

        # Rebuild a layout on the same mesh, then translate mesh-dim indices
        # into the alias names the Layout call expects.
        output_layout = Layout(
            mesh_shape=x_mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map, local_dst_shape = self._cal_output_layout_and_dst_shape(output_tensor_map, dst_shape, x_dict)
        out_layout = output_layout(*output_map)
        return out_layout, local_dst_shape