Coverage for hyper_parallel / core / tensor_redistribution.py: 85%

209 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""tensor_redistribution""" 

16 

17from hyper_parallel.core.dtensor import DTensor 

18from hyper_parallel.core.redistribute_infer import RedistributionOperatorInfer 

19from hyper_parallel.platform import get_platform 

# Backend abstraction resolved once at import time; every collective and tensor
# helper used in this module (create_group, all-gather/all-to-all/all-reduce,
# chunk, strided slice, ...) is dispatched through it.
platform = get_platform()

21 

22 

23def _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape): 

24 """_construct_layout_tuple_for_transform_operator_list""" 

25 from_layout_dict = from_layout.to_dict() 

26 to_layout_dict = to_layout.to_dict() 

27 from_layout_tuple = (from_layout_dict["mesh_shape"], from_layout_dict["tensor_map"], list(from_full_shape)) 

28 to_layout_tuple = (to_layout_dict["mesh_shape"], to_layout_dict["tensor_map"], list(from_full_shape)) # TODO: 考虑reshape的场景 

29 return from_layout_tuple, to_layout_tuple 

30 

31 

class TensorRedistribution:
    """Move a DTensor's local shard from one layout to another on the current rank.

    The steps needed for a (source layout, target layout, rank[, global shape])
    combination are inferred once and memoised in ``self._transform_cache`` as a
    list of ``(op_name, op_args)`` entries; each entry is executed by looking up
    ``op_name`` in ``self._construct_op_operator`` and calling it with the local
    tensor followed by ``*op_args``.
    """

    def __init__(self):
        self.is_init = False
        self.rank_list = None  # rank list of the current stage; captured on the first redistribution()
        self.rank_id = None  # value of platform.get_rank(); captured on the first redistribution()
        self._transform_cache = {}  # cache key (str) -> list of (op_name, op_args)
        self._construct_op_operator = {
            "Reshape": self._construct_reshape,
            "AllConcat": self._construct_all_concat,
            "StridedSlice": self._construct_strided_slice,
            "all_concat": self._construct_all_concat_new,
            "all_split": self._construct_all_split,
            "all_to_all": self._construct_all_to_all,
        }

    def _construct_reshape(self, x, *args):
        """Reshape ``x`` with ``view``; ``args`` is the target shape, one entry per dim."""
        return x.view(args)

    def _construct_all_concat(self, x, *args):
        """All-gather ``x`` over a group and concatenate the pieces along one dim.

        args: ``(*rank_list, concat_dim)`` - all but the last element are the ranks of
        the communication group, the last element is the concat dimension.
        """
        rank_list = args[0:-1]
        concat_dim = args[-1]
        group = platform.create_group(rank_list)
        concat_size = len(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_strided_slice(self, x, *args):
        """Slice ``x``; args are ``(*begin, *end, *strides)`` with equal-length thirds."""
        dims = len(args) // 3
        return platform.construct_strided_slice(x, args[0:dims], args[dims:2 * dims], args[2 * dims:])

    def _construct_all_concat_new(self, x, *args):
        """All-gather ``x`` and concatenate; args: ``(concat_dim, concat_size, rank_list)``."""
        rank_list = args[2]
        concat_dim = args[0]
        concat_size = args[1]
        group = platform.create_group(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_all_split(self, x, *args):
        """Keep this rank's piece of ``x``; args: ``(split_dim, split_size, rank_list)``.

        The piece index is the position of ``self.rank_id`` in ``rank_list``
        (``list.index`` raises ValueError if the rank is not a member).
        """
        rank_list = list(args[2])
        split_dim = args[0]
        split_size = args[1]
        idx = rank_list.index(self.rank_id)
        return platform.chunk(x, split_dim, split_size, idx)

    def _construct_all_to_all(self, x, *args):
        """Split ``x`` along ``split_dim``, exchange the chunks, concat along ``concat_dim``.

        args: ``(split_dim, concat_dim, split_count, rank_list)``.
        ``x.shape[split_dim]`` is cut into ``split_count`` equal chunks that are exchanged
        across ``rank_list``. The received chunks are concatenated along ``concat_dim``.
        The result's ``split_dim`` is divided by ``split_count`` and its ``concat_dim`` is
        multiplied by it. When ``split_dim == concat_dim`` the shape is unchanged.

        Raises:
            ValueError: if ``x.shape[split_dim]`` is not divisible by ``split_count``.
        """
        split_dim, concat_dim, split_count, rank_list = args
        group = platform.create_group(rank_list)
        original_shape = x.shape

        dim_size = original_shape[split_dim]
        if dim_size % split_count != 0:
            raise ValueError(f"Dimension {split_dim} with size {dim_size} "
                             f"cannot be evenly split into {split_count} parts")
        split_size = dim_size // split_count

        final_shape = list(original_shape)
        if split_dim != concat_dim:
            final_shape[split_dim] = split_size
            final_shape[concat_dim] = final_shape[concat_dim] * split_count
        final_shape = tuple(final_shape)

        send_buf = self._pack_for_all_to_all(x, split_dim, split_count, split_size).contiguous()
        output_tensor = platform.differentiable_all_to_all(
            input_data=send_buf,
            output_shape=tuple(send_buf.shape),
            group=group
        )
        return self._unpack_after_all_to_all(output_tensor, final_shape, concat_dim, split_count)

    @staticmethod
    def _pack_for_all_to_all(x, split_dim, split_count, split_size):
        """Lay the ``split_count`` chunks of ``x`` (taken along ``split_dim``) out along dim 0.

        The result keeps the rank of ``x``. Dim 0 becomes ``split_count * shape[0]``
        (chunk index outermost) and ``split_dim`` becomes ``split_size``. All other
        dims are unchanged. This is the layout ``_unpack_after_all_to_all`` expects.
        """
        shape = tuple(x.shape)
        if all(shape[i] == 1 for i in range(split_dim)):
            # Every dim before split_dim has size 1, so splitting split_dim into
            # (split_count, split_size) and moving split_count to the front is a
            # pure reshape - a view suffices, no transpose. The full rank is kept
            # because the unpacking step indexes dims by concat_dim.
            packed_shape = list(shape)
            packed_shape[split_dim] = split_size
            packed_shape[0] *= split_count
            return x.view(tuple(packed_shape))

        # General case: expose the chunk index as its own dim, move it to the
        # front, then merge it into dim 0.
        reshape_dims = list(shape)
        reshape_dims[split_dim] = split_count
        reshape_dims.insert(split_dim + 1, split_size)
        trans_dims = list(range(len(reshape_dims)))
        trans_dims.remove(split_dim)
        trans_dims.insert(0, split_dim)
        permuted = x.reshape(reshape_dims).permute(trans_dims).contiguous()
        merged_shape = list(permuted.shape)
        merged_shape[0] = merged_shape[0] * merged_shape[1]
        merged_shape.pop(1)
        return permuted.reshape(tuple(merged_shape))

    @staticmethod
    def _unpack_after_all_to_all(output_tensor, final_shape, concat_dim, split_count):
        """Fold the received chunks (stacked along dim 0) into ``concat_dim``.

        Returns a tensor of shape ``final_shape``.
        """
        if all(final_shape[i] == 1 for i in range(concat_dim)):
            # Every dim before concat_dim has size 1, so the chunk order along dim 0
            # already matches the concatenated layout; a view is enough.
            return output_tensor.view(final_shape)

        output_reshape = list(output_tensor.shape)
        output_reshape[0] = split_count
        output_reshape.insert(1, output_tensor.shape[0] // split_count)

        # Move the chunk-index dim (dim 0) to sit just before the original concat_dim ...
        out_trans_dims = list(range(len(output_reshape)))
        first_dim = out_trans_dims.pop(0)
        if concat_dim >= len(out_trans_dims):
            out_trans_dims.append(first_dim)
        else:
            out_trans_dims.insert(concat_dim, first_dim)

        final_output = output_tensor.reshape(output_reshape).permute(out_trans_dims).contiguous()

        # ... then merge it (outer) with the dim that follows it.
        final_reshape = list(final_output.shape)
        if concat_dim < len(final_reshape) - 1:
            final_reshape[concat_dim] = final_reshape[concat_dim] * final_reshape[concat_dim + 1]
            final_reshape.pop(concat_dim + 1)
        return final_output.reshape(final_reshape)

    def _apply_transform_operators(self, local_x, operator_list):
        """Apply each ``(op_name, op_args)`` entry of ``operator_list`` to ``local_x`` in order."""
        for operator in operator_list:
            local_x = self._construct_op_operator[operator[0]](local_x, *operator[1])
        return local_x

    def _apply_eazy_redistribute(self, src_layout, dst_layout):
        """Return True when the shape-independent RedistributionOperatorInfer path applies.

        That requires equal ``mesh_shape`` and ``rank_list`` and tensor maps of equal
        length. Otherwise the shape-aware transform is used.
        """
        if (src_layout.mesh_shape != dst_layout.mesh_shape or
                src_layout.rank_list != dst_layout.rank_list):
            return False
        return len(dst_layout.tensor_map) == len(src_layout.tensor_map)

    def _redistribution_without_shape(self, local_x, src_layout, dst_layout, key):
        """Infer the operator list with RedistributionOperatorInfer, cache it under ``key``, apply it."""
        inferrer = RedistributionOperatorInfer(
            dev_mat=src_layout.mesh_shape,
            in_tensor_map=list(src_layout.tensor_map),
            out_tensor_map=list(dst_layout.tensor_map)
        )
        op_list = inferrer.InferOpsList(self.rank_id, self.rank_list)
        self._transform_cache[key] = op_list
        return self._apply_transform_operators(local_x, op_list)

    def redistribution(self, input_x, to_layout):
        """Redistribute DTensor ``input_x`` to ``to_layout``.

        Any partial status on the input layout is resolved first via ``reduce_partial``.
        The operator list that turns the local shard into the target shard is then
        read from ``self._transform_cache`` or inferred (and cached).

        Args:
            input_x: Source DTensor; ``input_x.layout`` describes its current distribution.
            to_layout: Target layout; must have the same rank list as the one recorded
                on the first call.

        Returns:
            ``DTensor.from_local(local, to_layout.mesh, to_layout.placements)``.

        Raises:
            ValueError: if the recorded rank list differs from ``to_layout.rank_list``
                (or propagated from ``reduce_partial``).
        """
        x_layout = input_x.layout
        x = input_x
        if x_layout.is_partial():
            # Solve partial status first.
            if x_layout.mesh_shape == to_layout.mesh_shape:
                x = self.reduce_partial(input_x, to_layout)
            else:
                # NOTE(review): x_layout is partial here, and reduce_partial raises
                # ValueError for a partial target layout, so this branch appears to
                # always raise - confirm the intended behavior for cross-mesh input.
                x = self.reduce_partial(input_x, x_layout)

        from_layout = x.layout
        if not self.is_init:
            # Rank information is captured once, from the first source layout seen.
            self.rank_id = platform.get_rank()
            self.rank_list = from_layout.rank_list
            self.is_init = True
        if self.rank_list != to_layout.rank_list:
            raise ValueError(f"The from_layout rank list: {self.rank_list} is not equal to "
                             f"to_layout rank list: {to_layout.rank_list}")

        key = from_layout.compact_str + to_layout.compact_str + str(self.rank_id)
        if key in self._transform_cache:
            # Shape-independent entry (written by _redistribution_without_shape).
            local_x = self._apply_transform_operators(x.to_local(), self._transform_cache[key])
            return DTensor.from_local(local_x, to_layout.mesh, to_layout.placements)

        full_shape = x.shape
        key_and_shape = key + str(full_shape)
        local_x = x.to_local()
        if key_and_shape in self._transform_cache:
            # Shape-dependent entry (written by _infer_transform_operator_list).
            local_x = self._apply_transform_operators(local_x, self._transform_cache[key_and_shape])
            return DTensor.from_local(local_x, to_layout.mesh, to_layout.placements)

        if self._apply_eazy_redistribute(from_layout, to_layout):
            if from_layout.is_partial():
                from_layout.reset_partial()
            local_x = self._redistribution_without_shape(local_x, from_layout, to_layout, key)
        else:
            operator_list = self._infer_transform_operator_list(from_layout, to_layout,
                                                                full_shape, key_and_shape)
            local_x = self._apply_transform_operators(local_x, operator_list)
        return DTensor.from_local(local_x, to_layout.mesh, to_layout.placements)

    def _infer_transform_operator_list(self, from_layout, to_layout, from_full_shape, key):
        """Infer the shape-aware operator list for this rank, cache it under ``key``, return it.

        Delegates to ``platform.get_tensor_transform().transform_tensor_sharding``.
        """
        from_layout_tuple, to_layout_tuple = \
            _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape)
        # NOTE(review): the meaning of the positional ``False`` argument is defined by
        # transform_tensor_sharding - confirm against its signature.
        self._transform_cache[key] = \
            platform.get_tensor_transform().transform_tensor_sharding(from_layout_tuple, to_layout_tuple,
                                                                      self.rank_list, False, self.rank_id)
        return self._transform_cache[key]

    def _allreduce_along_dev_dim(self, x, op, layout, dev_dim):
        """All-reduce ``x`` over the communication group of mesh axis ``dev_dim``.

        ``op`` is forwarded to ``platform.differentiable_all_reduce``, with two special
        cases. ``'avg'`` is a ``'sum'`` divided by the size of ``dev_dim`` in
        ``layout.mesh_shape``. ``'all'`` casts ``x.bool()`` to int32, reduces with
        ``'all'`` and returns the result as bool.
        """
        group = layout.get_comm_group_by_axis(dev_dim)
        zero_dim = x.dim() == 0
        if zero_dim:
            # Promote a 0-d tensor to 1-d around the collective and squeeze it back
            # afterwards (presumably the collective needs at least one dim).
            x = x.unsqueeze(0)
        if op == 'avg':
            dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
            x = platform.differentiable_all_reduce(x, 'sum', group)
            x = x / dev_num
        elif op == 'all':
            x_int32 = platform.tensor_type_cast(x.bool(), 'int32')  # True -> 1, False -> 0
            x = platform.differentiable_all_reduce(x_int32, 'all', group)
            x = x.bool()
        else:
            x = platform.differentiable_all_reduce(x, op, group)
        if zero_dim:
            x = x.squeeze(0)
        return x

    def _reduce_scatter_along_dev_dim_with_axis(self, x, axis, op, layout, dev_dim):
        """Reduce-scatter ``x`` along tensor dim ``axis`` over mesh axis ``dev_dim``'s group."""
        dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
        group = layout.get_comm_group_by_axis(dev_dim)
        # Use the module-level backend like every other method (there is no self.platform).
        return platform.reduce_scatter(x, dev_num, axis, op, group)

    def reduce_partial(self, input_x, to_layout):
        """Resolve the partial (pending-reduction) status of ``input_x``.

        For every mesh axis marked partial in ``input_x.layout.partial``, a
        ReduceScatter is issued when ``to_layout`` reports a shard dim for that axis,
        otherwise an AllReduce. Reduce-scattered axes are appended to the
        corresponding entry of the source alias tensor map. The partial status of the
        resulting layout is reset.

        Args:
            input_x: DTensor whose layout may be partial.
            to_layout: Non-partial target layout with the same mesh shape; decides which
                axes are reduce-scattered and the execution order.

        Returns:
            ``input_x`` itself when its layout is None or not partial, otherwise a new
            DTensor built with ``DTensor.from_local``.

        Raises:
            ValueError: if the mesh shapes differ or ``to_layout`` is partial.
        """
        from_layout = input_x.layout
        x = input_x
        # is_partial must be called: a bound method object is always truthy.
        if from_layout is None or not from_layout.is_partial():
            return x

        x = x.to_local()
        if from_layout.mesh_shape != to_layout.mesh_shape:
            raise ValueError(f"For reduce partial, mesh_shape between from_layout and to_layout must be the same, "
                             f"but got {from_layout.mesh_shape} and {to_layout.mesh_shape}")
        if to_layout.is_partial():
            raise ValueError(f"For reduce partial, to_layout must be non-partial status, but got to_layout.partial: "
                             f"{to_layout.partial}")

        # Position of each mesh axis inside a multi-axis tensor-map entry (0 for a single axis).
        dev_map_order = {}
        for dev_axis in to_layout.alias_tensor_map:
            if isinstance(dev_axis, tuple):
                for i, sub_dev_axis in enumerate(dev_axis):
                    dev_map_order[sub_dev_axis] = i
            else:
                dev_map_order[dev_axis] = 0

        pending_reduce_op_list = []  # List[Tuple[comm_op, op, dev_axis, reduce_dim]]
        for dev_axis_index, op in enumerate(from_layout.partial):
            if op is None:
                continue
            dev_axis = from_layout.alias_name[dev_axis_index]
            apply_shard_dim = to_layout.get_dev_axis_apply_shard_axis(dev_axis)
            # NOTE(review): if this can return tensor dim 0, that value is falsy and
            # selects AllReduce - confirm whether an ``is not None`` test was intended.
            comm_op = "ReduceScatter" if apply_shard_dim else "AllReduce"
            pending_reduce_op_list.append((comm_op, op, dev_axis, apply_shard_dim))

        # Execution order:
        # 1. ReduceScatter before AllReduce.
        # 2. For a tensor dim sharded on several axes, the outer axis first,
        #    e.g. ("cp", "tp") reduce-scatters along "cp" before "tp".
        # 3. Then lower mesh axis id before higher.
        sorted_pending_reduce_op_list = \
            sorted(pending_reduce_op_list, key=lambda reduce_pair: (reduce_pair[0] != "ReduceScatter",
                                                                    dev_map_order.get(reduce_pair[2], 0),
                                                                    to_layout.mesh.axis_id(reduce_pair[2])))

        output_alias_tensor_map = list(from_layout.alias_tensor_map)
        for reduce_op_pair in sorted_pending_reduce_op_list:
            comm_op = reduce_op_pair[0]
            op = reduce_op_pair[1]
            dev_axis = reduce_op_pair[2]
            if comm_op == "AllReduce":
                x = self._allreduce_along_dev_dim(x, op, from_layout, dev_axis)
            elif comm_op == "ReduceScatter":
                reduce_axis = reduce_op_pair[3]
                x = self._reduce_scatter_along_dev_dim_with_axis(x, reduce_axis, op, from_layout, dev_axis)
                # The alias tensor map appears to use the string "None" for an unsharded dim.
                if output_alias_tensor_map[reduce_axis] == "None":
                    output_alias_tensor_map[reduce_axis] = dev_axis
                elif isinstance(output_alias_tensor_map[reduce_axis], tuple):
                    output_alias_tensor_map[reduce_axis] += (dev_axis,)
                else:
                    output_alias_tensor_map[reduce_axis] = (output_alias_tensor_map[reduce_axis], dev_axis)

        output_layout = from_layout(*output_alias_tensor_map)
        output_layout.reset_partial()
        return DTensor.from_local(x, output_layout.mesh, output_layout.placements)

322 

323 

# Module-level instance: its operator-list cache (_transform_cache) and the rank
# information captured on its first redistribution() call persist across calls.
_tensor_redistribution = TensorRedistribution()