Coverage for hyper_parallel / core / shard / ops / parallel_reshape.py: 71%

141 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Reshape operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from hyper_parallel.platform import get_platform 

21from .parallel_ops import DistributedOp 

# Resolve the active backend once at import time; per infer_layout below this
# module supports both MindSpore and PyTorch style reshape ops.
platform = get_platform()
# Backend tensor type; used to detect tensor-valued shape arguments in infer_layout.
Tensor = platform.Tensor

24 

25 

26def _filter_none_split_tensor_map(tensor_map, mesh_shape): 

27 """ 

28 Filter out the elements in tensor_map where the size of the corresponding dimension in device_matrix is 1. 

29 

30 Args: 

31 tensor_map (list): A list of tensor mappings, which may contain integers or tuples. 

32 device_matrix (list): A device matrix representing the device distribution across each dimension. 

33 

34 Returns: 

35 list: The filtered list of tensor mappings, where invalid mappings are replaced with -1 or valid mappings are 

36 retained. 

37 """ 

38 filtered_tensor_map = [] 

39 for item in tensor_map: 

40 if isinstance(item, tuple): 

41 filtered = [] 

42 for i in item: 

43 if mesh_shape[-1 - i] != 1: 

44 filtered.append(i) 

45 if len(filtered) == 0: 

46 filtered_tensor_map.append(-1) 

47 elif len(filtered) == 1: 

48 filtered_tensor_map.append(filtered[0]) 

49 else: 

50 filtered_tensor_map.append(tuple(filtered)) 

51 else: 

52 filtered_tensor_map.append(item if mesh_shape[-1 - item] != 1 else -1) 

53 return filtered_tensor_map 

54 

55 

class ReshapeDistributedOp(DistributedOp):
    """Distributed implementation for Reshape operator.

    Infers the output layout of a reshape so that the data slice held on each
    device after the reshape is identical to the slice it held before.
    """

    def _get_dynamic_shape_info(self, shape):
        """Scan ``shape`` for dynamic (negative) axes.

        Returns:
            tuple: (is_dynamic, dynamic_axis, total_size) where
                - is_dynamic (bool): True iff the signed product of all entries
                  is negative, i.e. the shape has an odd number of negative
                  entries (the usual single dynamic axis).
                - dynamic_axis (int): index of the last negative entry, or -1
                  if there is none.
                - total_size (int): signed product of all entries.
        """
        total_size = 1
        dynamic_axis = -1
        for axis, s in enumerate(shape):
            total_size *= s
            if s < 0:
                dynamic_axis = axis
        return total_size < 0, dynamic_axis, total_size

    def _handle_dynamic_shape(self, input_shape, output_shape):
        """
        Check dynamic shape. Calculate unknown axis if one of input and output shape is known. If both are unknown,
        calculate the relative multiple.
        [2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8]

        Returns:
            tuple: (input_shape, output_shape, dynamic_can_shard). The shapes
            are lists with the dynamic axis (if any) resolved to either its
            actual length or a negative relative multiple; dynamic_can_shard
            tells the caller whether sharding the dynamic output axis is
            permitted.

        Raises:
            ValueError: if both shapes are fully static but their element
            counts differ.
        """
        input_shape = list(input_shape)
        output_shape = list(output_shape)
        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)
        dynamic_can_shard = False
        if not is_input_dynamic and not is_output_dynamic:
            # Fully static: only a sanity check on total element counts.
            if input_total_size != output_total_size:
                raise ValueError(f"The total elements number of input shape {input_shape} and output shape "
                                 f"{output_shape} are different.")
            return input_shape, output_shape, dynamic_can_shard

        if not is_input_dynamic:
            # Static input, dynamic output: resolve the -1 axis exactly.
            # output_total_size is negative here, so the division is positive.
            accurate_output_shape = output_shape
            accurate_output_shape[output_dynamic_axis] = -input_total_size // output_total_size
            return input_shape, accurate_output_shape, dynamic_can_shard

        if not is_output_dynamic:
            # Dynamic input, static output: resolve the input's -1 axis.
            accurate_input_shape = input_shape
            accurate_input_shape[input_dynamic_axis] = -output_total_size // input_total_size
            return accurate_input_shape, output_shape, dynamic_can_shard

        # Both dynamic: encode the relative multiple as a negative entry on the
        # side whose static factor is smaller, e.g. -1 -> -2 in the docstring
        # example above.
        if output_total_size >= input_total_size:
            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
            dynamic_can_shard = True
        else:
            input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
        return input_shape, output_shape, dynamic_can_shard

    def _merge_unshared_axis(self, global_shape, tensor_map):
        """
        Merge those axes that are not sharded to the high dimension which is shared.
        shape[4, 2, 6, 8], tensor map[-1, -1, 0, -1] -> merged shape[8, 48]
        """
        merged_size = 1
        merged_shape = []
        merged_tensor_map = []
        # Walk from the innermost axis outward, folding every unsharded axis
        # into the running product until a sharded axis is met.
        for axis in range(len(global_shape) - 1, -1, -1):
            merged_size *= global_shape[axis]
            if tensor_map[axis] != -1:
                merged_shape.insert(0, merged_size)
                merged_tensor_map.insert(0, tensor_map[axis])
                merged_size = 1
        # Leftover leading unsharded axes form one unsharded merged axis.
        if tensor_map[0] == -1:
            merged_shape.insert(0, merged_size)
            merged_tensor_map.insert(0, -1)
        return merged_shape, merged_tensor_map

    def _cal_output_layout_and_dst_shape(self, output_tensor_map, dst_shape, x_dict):
        """
        calculate output layout tensor map and local dst shape.

        Args:
            output_tensor_map (list): per-axis mesh-dim index, tuple of
                indices, or -1 for unsharded axes.
            dst_shape (list): global destination shape (negative entries mark
                dynamic axes).
            x_dict (dict): input layout dict providing "mesh_shape" and
                "alias_name".

        Returns:
            tuple: (alias tensor map usable as Layout(*output_map), per-device
            destination shape, with -1 for dynamic axes).
        """
        x_mesh_shape = x_dict["mesh_shape"]
        output_map = []
        local_dst_shape = []
        for idx, map_id in enumerate(output_tensor_map):
            if isinstance(map_id, tuple):
                # Axis sharded over several mesh dims: divide by their product.
                shard_size = 1
                map_idx = []
                for shard_id in map_id:
                    map_idx.append(x_dict["alias_name"][-1 - shard_id])
                    shard_size *= x_mesh_shape[-1 - shard_id]
                output_map.append(tuple(map_idx))
                local_dst_shape.append(dst_shape[idx] // shard_size if dst_shape[idx] > 0 else -1)
                continue
            if map_id < 0:
                # Unsharded axis; "None" is the Layout placeholder alias.
                output_map.append("None")
                local_dst_shape.append(dst_shape[idx] if dst_shape[idx] > 0 else -1)
            else:
                output_map.append(x_dict["alias_name"][-1 - map_id])
                local_dst_shape.append(dst_shape[idx] // x_mesh_shape[-1 - map_id] if dst_shape[idx] > 0 else -1)
        return output_map, local_dst_shape

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for reshape operator.

        For reshape operations, data slice on each device after reshape should be same as data slice before reshape.

        Args:
            layouts (Layout): Layout of input x
            extra_args:
                For MindSpore Reshape: (destination shape, original shape)
                For PyTorch reshape/view: (shape_arg1, shape_arg2, ..., original shape) or (shape_tuple, original shape)

        Returns:
            tuple: Layout for output tensor

        Raises:
            ValueError: if the shape arguments are malformed, or the requested
            reshape is incompatible with the input's sharding.
        """
        x_layout = layouts[0]
        x_dict = x_layout.to_dict()

        dst_shape = None
        input_shape = None

        if self.op_name in ["reshape", "view"]:
            # PyTorch style: extra_args contains shape args + input_shape (appended by system)
            if len(extra_args) < 2:
                raise ValueError(f"{self.op_name} requires output shape and input shape.")

            input_shape = extra_args[-1]
            shape_args = extra_args[:-1]

            # Handle variable arguments vs tuple/list argument
            if len(shape_args) == 1:
                if isinstance(shape_args[0], (list, tuple)):
                    dst_shape = shape_args[0]
                elif isinstance(shape_args[0], Tensor):
                    dst_shape = shape_args[0].tolist()
                else:
                    # Single int arg (e.g. flatten to 1D)
                    dst_shape = shape_args
            else:
                dst_shape = shape_args
        else:
            # MindSpore Reshape style
            if len(extra_args) != 2:
                raise ValueError("Reshape requires output shape and input shape.")

            dst_shape = extra_args[0]
            input_shape = extra_args[1]

        # Common processing
        if isinstance(dst_shape, Tensor):
            dst_shape = dst_shape.tolist()
        if not isinstance(dst_shape, list) and not isinstance(dst_shape, tuple):
            raise ValueError("Shape should be a tensor or a tuple or a list.")

        # Drop mappings onto size-1 mesh dims, then merge unsharded axes so
        # every merged axis except possibly the first is sharded.
        x_map = _filter_none_split_tensor_map(x_dict["tensor_map"], x_dict["mesh_shape"])
        x_mesh_shape = x_dict["mesh_shape"]

        input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape)
        merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map)

        # Walk the destination shape from the innermost axis, consuming merged
        # input axes. A merged axis's sharding transfers to the dst axis that
        # completes it (cur_size reaches 1); dst axes in between stay unsharded.
        output_tensor_map = []
        cur_axis = len(merged_shape) - 1
        cur_size = merged_shape[cur_axis]
        for shape in reversed(dst_shape):
            if cur_size % shape != 0:
                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
            cur_size = cur_size // shape
            if cur_size == 1:
                if isinstance(merge_tensor_map[cur_axis], tuple):
                    shard_size = 1
                    for axis in merge_tensor_map[cur_axis]:
                        shard_size *= x_mesh_shape[-axis - 1]
                elif merge_tensor_map[cur_axis] == -1:
                    # Unsharded merged axis (only possible at index 0): one
                    # shard. Without this guard -1 would index mesh dim 0 and
                    # could raise a spurious divisibility error below.
                    shard_size = 1
                else:
                    shard_size = x_mesh_shape[-merge_tensor_map[cur_axis] - 1]
                if shape < 0:
                    # Dynamic dst axis: divisibility cannot be checked statically.
                    if not dynamic_can_shard:
                        raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
                elif shard_size > shape or shape % shard_size != 0:
                    raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
                output_tensor_map.insert(0, merge_tensor_map[cur_axis])
                cur_axis -= 1
                cur_size = merged_shape[cur_axis]
            else:
                output_tensor_map.insert(0, -1)

        output_layout = Layout(
            mesh_shape=x_mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map, local_dst_shape = self._cal_output_layout_and_dst_shape(output_tensor_map, dst_shape, x_dict)
        out_layout = output_layout(*output_map)
        return out_layout, local_dst_shape