# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the Reshape operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from hyper_parallel.platform import get_platform
from .parallel_ops import DistributedOp

platform = get_platform()
Tensor = platform.Tensor


def _filter_none_split_tensor_map(tensor_map, mesh_shape):
    """
    Filter out the entries in tensor_map whose corresponding mesh dimension has size 1.

    Args:
        tensor_map (list): A list of tensor mappings, which may contain integers or tuples.
        mesh_shape (list): The device mesh shape, giving the number of devices along each mesh dimension.

    Returns:
        list: The filtered tensor map, where mappings onto size-1 mesh dimensions are replaced with -1 and all
            other mappings are retained.
    """
    filtered_tensor_map = []
    for item in tensor_map:
        if isinstance(item, tuple):
            filtered = []
            for i in item:
                if mesh_shape[-1 - i] != 1:
                    filtered.append(i)
            if len(filtered) == 0:
                filtered_tensor_map.append(-1)
            elif len(filtered) == 1:
                filtered_tensor_map.append(filtered[0])
            else:
                filtered_tensor_map.append(tuple(filtered))
        else:
            filtered_tensor_map.append(item if mesh_shape[-1 - item] != 1 else -1)
    return filtered_tensor_map
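
# Illustrative usage (a sketch with hypothetical values, not from the original
# module): tensor-map entries index mesh dims from the right, so with
# mesh_shape [4, 1, 2] the entry 1 refers to the middle mesh dim of size 1 and
# is treated as unsharded:
# >>> _filter_none_split_tensor_map([2, 1, (0, 2)], [4, 1, 2])
# [2, -1, (0, 2)]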


class ReshapeDistributedOp(DistributedOp):
    """Distributed implementation for the Reshape operator."""

    def __init__(self, op_name):
        super().__init__(op_name)
        self._allow_partial_inputs = True

    def _get_dynamic_shape_info(self, shape):
        """Return (is_dynamic, dynamic_axis, total_size) for a shape that may contain a negative dynamic axis."""
        total_size = 1
        dynamic_axis = -1
        for axis, s in enumerate(shape):
            total_size *= s
            if s < 0:
                dynamic_axis = axis
        return total_size < 0, dynamic_axis, total_size
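
    # Illustrative values (a sketch; `op` stands for a ReshapeDistributedOp
    # instance): a dynamic axis makes the product negative, so the sign of the
    # total size doubles as the dynamic flag:
    # >>> op._get_dynamic_shape_info([2, -1, 8])
    # (True, 1, -16)
    # >>> op._get_dynamic_shape_info([2, 4, 8])
    # (False, -1, 64)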

    def _handle_dynamic_shape(self, input_shape, output_shape):
        """
        Handle dynamic (-1) axes. If exactly one of the input and output shapes is dynamic, resolve its unknown
        axis from the other, fully known shape. If both are dynamic, encode the relative multiple between the two
        dynamic axes as a negative entry, e.g. [2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8].
        """
        input_shape = list(input_shape)
        output_shape = list(output_shape)
        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)
        dynamic_can_shard = False
        if not is_input_dynamic and not is_output_dynamic:
            if input_total_size != output_total_size:
                raise ValueError(f"The total number of elements in input shape {input_shape} and output shape "
                                 f"{output_shape} differ.")
            return input_shape, output_shape, dynamic_can_shard

        if not is_input_dynamic:
            accurate_output_shape = output_shape
            accurate_output_shape[output_dynamic_axis] = -input_total_size // output_total_size
            return input_shape, accurate_output_shape, dynamic_can_shard

        if not is_output_dynamic:
            accurate_input_shape = input_shape
            accurate_input_shape[input_dynamic_axis] = -output_total_size // input_total_size
            return accurate_input_shape, output_shape, dynamic_can_shard

        if output_total_size >= input_total_size:
            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
            dynamic_can_shard = True
        else:
            input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
        return input_shape, output_shape, dynamic_can_shard
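
    # Illustrative values (sketch): with only the output dynamic, the -1 axis
    # is resolved exactly; with both dynamic, the relative multiple is encoded
    # as a negative entry:
    # >>> op._handle_dynamic_shape([2, 4, 8], [4, -1, 8])
    # ([2, 4, 8], [4, 2, 8], False)
    # >>> op._handle_dynamic_shape([2, -1, 8], [4, -1, 8])
    # ([2, -2, 8], [4, -1, 8], False)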

    def _merge_unshared_axis(self, global_shape, tensor_map):
        """
        Merge each run of unsharded axes into the nearest sharded axis on its left; a leading run of unsharded
        axes is kept as a single merged axis.
        shape [4, 2, 6, 8], tensor map [-1, -1, 0, -1] -> merged shape [8, 48], merged tensor map [-1, 0]
        """
        merged_size = 1
        merged_shape = []
        merged_tensor_map = []
        for axis in range(len(global_shape) - 1, -1, -1):
            merged_size *= global_shape[axis]
            if tensor_map[axis] != -1:
                merged_shape.insert(0, merged_size)
                merged_tensor_map.insert(0, tensor_map[axis])
                merged_size = 1
        if tensor_map[0] == -1:
            merged_shape.insert(0, merged_size)
            merged_tensor_map.insert(0, -1)
        return merged_shape, merged_tensor_map
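
    # Illustrative values (the docstring example, spelled out): axis 2 is the
    # only sharded axis, so the trailing axis of size 8 folds into it and the
    # two leading unsharded axes collapse into one:
    # >>> op._merge_unshared_axis([4, 2, 6, 8], [-1, -1, 0, -1])
    # ([8, 48], [-1, 0])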

    def _cal_output_layout_and_dst_shape(self, output_tensor_map, dst_shape, x_dict):
        """
        Calculate the output layout tensor map (as mesh-axis alias names) and the local destination shape.
        """
        x_mesh_shape = x_dict["mesh_shape"]
        output_map = []
        local_dst_shape = []
        for idx, map_id in enumerate(output_tensor_map):
            if isinstance(map_id, tuple):
                shard_size = 1
                map_idx = []
                for shard_id in map_id:
                    map_idx.append(x_dict["alias_name"][-1 - shard_id])
                    shard_size *= x_mesh_shape[-1 - shard_id]
                output_map.append(tuple(map_idx))
                local_dst_shape.append(dst_shape[idx] // shard_size if dst_shape[idx] > 0 else -1)
                continue
            if map_id < 0:
                output_map.append("None")
                local_dst_shape.append(dst_shape[idx] if dst_shape[idx] > 0 else -1)
            else:
                output_map.append(x_dict["alias_name"][-1 - map_id])
                local_dst_shape.append(dst_shape[idx] // x_mesh_shape[-1 - map_id] if dst_shape[idx] > 0 else -1)
        return output_map, local_dst_shape
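
    # Illustrative values (hypothetical mesh with mesh_shape [4, 2] and alias
    # names ("dp", "mp")): a tensor-map entry of 0 names the rightmost mesh
    # dim, so that axis is sharded 2-way and its local size halves:
    # >>> x_dict = {"mesh_shape": [4, 2], "alias_name": ("dp", "mp")}
    # >>> op._cal_output_layout_and_dst_shape([-1, 0, -1], [8, 6, 8], x_dict)
    # (['None', 'mp', 'None'], [8, 3, 8])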

    def _parse_shape_args(self, extra_args):
        """Parse shape arguments from extra_args.

        Args:
            extra_args: Extra arguments containing shape info

        Returns:
            tuple: (dst_shape, input_shape)
        """
        if self.op_name in ["reshape", "view"]:
            return self._parse_torch_shape_args(extra_args)
        return self._parse_mindspore_shape_args(extra_args)

    def _parse_torch_shape_args(self, extra_args):
        """Parse PyTorch style shape arguments."""
        if len(extra_args) < 2:
            raise ValueError(f"{self.op_name} requires output shape and input shape.")

        input_shape = extra_args[-1]
        shape_args = extra_args[:-1]

        if len(shape_args) == 1:
            first_arg = shape_args[0]
            if isinstance(first_arg, (list, tuple)):
                dst_shape = first_arg
            elif isinstance(first_arg, Tensor):
                dst_shape = first_arg.tolist()
            else:
                dst_shape = shape_args
        else:
            dst_shape = shape_args

        return dst_shape, input_shape
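
    # Illustrative calls (sketch): PyTorch-style reshape accepts the shape as a
    # single sequence or as unpacked ints; the caller appends the input shape
    # as the final element of extra_args:
    # >>> op._parse_torch_shape_args(((2, 8), (4, 4)))  # shape as a tuple
    # ((2, 8), (4, 4))
    # >>> op._parse_torch_shape_args((2, 8, (4, 4)))    # shape as ints
    # ((2, 8), (4, 4))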

    def _parse_mindspore_shape_args(self, extra_args):
        """Parse MindSpore style shape arguments."""
        if len(extra_args) != 2:
            raise ValueError("Reshape requires output shape and input shape.")

        return extra_args[0], extra_args[1]

    def _normalize_shape(self, dst_shape):
        """Normalize dst_shape to a list or tuple."""
        if isinstance(dst_shape, Tensor):
            dst_shape = dst_shape.tolist()
        if not isinstance(dst_shape, (list, tuple)):
            raise ValueError("Shape should be a Tensor, a tuple or a list.")
        return dst_shape

    def _compute_output_tensor_map(self, merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard,
                                   input_shape, x_map):
        """Compute the output tensor_map from the merged shape information.

        Args:
            merged_shape: Merged shape from _merge_unshared_axis
            merge_tensor_map: Merged tensor_map from _merge_unshared_axis
            dst_shape: Target shape
            x_mesh_shape: Mesh shape
            dynamic_can_shard: Whether a dynamic shape can be sharded
            input_shape: Original input shape
            x_map: Input tensor_map

        Returns:
            list: Output tensor_map
        """
        output_tensor_map = []
        cur_axis = len(merged_shape) - 1
        cur_size = merged_shape[cur_axis]

        # Walk the destination shape from its last axis, consuming one merged
        # input axis at a time; the merged axis's sharding is emitted on the
        # destination axis that consumes it exactly.
        for shape in reversed(dst_shape):
            if cur_size % shape != 0:
                raise ValueError(f"Cannot reshape {input_shape} to {dst_shape} with tensor map {x_map}")
            cur_size = cur_size // shape

            if cur_size == 1:
                map_val = self._handle_sharded_axis(
                    merge_tensor_map, cur_axis, x_mesh_shape, shape, dynamic_can_shard, input_shape, x_map, dst_shape
                )
                output_tensor_map.insert(0, map_val)
                cur_axis -= 1
                cur_size = merged_shape[cur_axis]
            else:
                output_tensor_map.insert(0, -1)

        return output_tensor_map
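
    # Illustrative values (continuing the _merge_unshared_axis sketch): merging
    # [4, 2, 6, 8] with map [-1, -1, 0, -1] gives ([8, 48], [-1, 0]); reshaping
    # to [8, 6, 8] on a [4, 2] mesh carries the sharding to the axis of size 6:
    # >>> op._compute_output_tensor_map([8, 48], [-1, 0], [8, 6, 8], [4, 2],
    # ...                               False, [4, 2, 6, 8], [-1, -1, 0, -1])
    # [-1, 0, -1]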

    def _handle_sharded_axis(self, merge_tensor_map, cur_axis, x_mesh_shape, shape, dynamic_can_shard,
                             input_shape, x_map, dst_shape):
        """Handle a sharded axis in the tensor_map computation."""
        map_val = merge_tensor_map[cur_axis]

        # Tensor-map entries index mesh dims from the right, hence the negated
        # index; a tuple means the axis is sharded over several mesh dims, so
        # their sizes multiply.
        if isinstance(map_val, tuple):
            shard_size = 1
            for axis in map_val:
                shard_size *= x_mesh_shape[-axis - 1]
        else:
            shard_size = x_mesh_shape[-map_val - 1]

        if shape < 0:
            if not dynamic_can_shard:
                raise ValueError(f"Cannot reshape {input_shape} to {dst_shape} with tensor map {x_map}")
        elif shard_size > shape or shape % shard_size != 0:
            raise ValueError(f"Cannot reshape {input_shape} to {dst_shape} with tensor map {x_map}")

        return map_val
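
    # Illustrative check (sketch): a destination axis of size 6 sharded over a
    # mesh dim of size 2 divides evenly, so the mapping is kept:
    # >>> op._handle_sharded_axis([-1, 0], 1, [4, 2], 6, False,
    # ...                         [4, 2, 6, 8], [-1, -1, 0, -1], [8, 6, 8])
    # 0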

    def _apply_partial_status(self, x_layout, out_layout):
        """Apply partial status from input to output layout."""
        if x_layout.is_partial():
            input_partial = x_layout.partial
            for i, partial_op in enumerate(input_partial):
                if partial_op is not None and i < len(out_layout.alias_name):
                    out_layout.set_partial_by_dev_axis(out_layout.alias_name[i], partial_op)

264 

265 def infer_layout(self, layouts, extra_args=None): 

266 """ 

267 Infer output layout for reshape operator. 

268 

269 For reshape operations, data slice on each device after reshape should be same as data slice before reshape. 

270 

271 Args: 

272 layouts (Layout): Layout of input x 

273 extra_args: 

274 For MindSpore Reshape: (destination shape, original shape) 

275 For PyTorch reshape/view: (shape_arg1, shape_arg2, ..., original shape) or (shape_tuple, original shape) 

276 

277 Returns: 

278 tuple: Layout for output tensor 

279 """ 

280 x_layout = layouts[0] 

281 x_dict = x_layout.to_dict() 

282 

283 dst_shape, input_shape = self._parse_shape_args(extra_args) 

284 dst_shape = self._normalize_shape(dst_shape) 

285 

286 x_map = _filter_none_split_tensor_map(x_dict["tensor_map"], x_dict["mesh_shape"]) 

287 x_mesh_shape = x_dict["mesh_shape"] 

288 

289 input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape) 

290 merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map) 

291 

292 output_tensor_map = self._compute_output_tensor_map( 

293 merged_shape, merge_tensor_map, dst_shape, x_mesh_shape, dynamic_can_shard, input_shape, x_map 

294 ) 

295 

296 output_layout = Layout( 

297 mesh_shape=x_mesh_shape, 

298 alias_name=x_layout.alias_name, 

299 rank_list=x_layout.rank_list 

300 ) 

301 output_map, local_dst_shape = self._cal_output_layout_and_dst_shape(output_tensor_map, dst_shape, x_dict) 

302 out_layout = output_layout(*output_map) 

303 

304 self._apply_partial_status(x_layout, out_layout) 

305 

306 return out_layout, local_dst_shape
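
    # Illustrative end-to-end trace (hypothetical values; only the inferred
    # pieces are shown, since constructing a real Layout needs a rank list):
    # a global input of shape [4, 2, 6, 8] with tensor_map [-1, -1, 0, -1] on
    # a mesh of shape [4, 2] aliased ("dp", "mp"), reshaped to [8, 6, 8],
    # yields output_map ['None', 'mp', 'None'] and per-device local_dst_shape
    # [8, 3, 8].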