Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / dtensor / tensor_redistribution.py: 48%

214 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""tensor_redistribution""" 

16 

17from hyper_parallel.core.dtensor.dtensor import DTensor 

18from hyper_parallel.core.dtensor.redistribute_infer import RedistributionOperatorInfer 

19from hyper_parallel.platform import get_platform 

# Resolve the backend abstraction (communication groups, collectives, tensor
# ops) once at import time; all methods below dispatch through it.
platform = get_platform()

21 

22 

23def _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape): 

24 """_construct_layout_tuple_for_transform_operator_list""" 

25 from_layout_dict = from_layout.to_dict() 

26 to_layout_dict = to_layout.to_dict() 

27 from_layout_tuple = ( 

28 from_layout_dict["mesh_shape"], from_layout_dict["tensor_map"], list(from_full_shape) 

29 ) 

30 # NOTE: consider reshape scenario when to_full_shape differs from from_full_shape 

31 to_layout_tuple = ( 

32 to_layout_dict["mesh_shape"], to_layout_dict["tensor_map"], list(from_full_shape) 

33 ) 

34 return from_layout_tuple, to_layout_tuple 

35 

36 

class TensorRedistribution:
    """
    TensorRedistribution.

    Moves a DTensor's local shards from one sharding layout to another.
    Infers the list of transform operators (reshape, all-gather/concat,
    strided slice, split, all-to-all) needed for the move, executes them
    through the active ``platform`` backend, and caches inferred operator
    lists keyed by (from_layout, to_layout, rank[, shape]).
    """
    def __init__(self):
        # Rank id is resolved lazily on the first redistribution() call
        # (the communication backend must be initialised first).
        self.is_init = False
        self.rank_id = None # current rank_id (global)
        # cache-key string -> inferred transform operator list
        self._transform_cache = {}
        # Dispatch table: operator tag -> method that applies that transform
        # to a local tensor. Capitalised tags are the legacy operator names;
        # lower-case tags are the new-style ones.
        self._construct_op_operator = {
            "Reshape": self._construct_reshape,
            "AllConcat": self._construct_all_concat,
            "StridedSlice": self._construct_strided_slice,
            "all_concat": TensorRedistribution._construct_all_concat_new,
            "all_split": self._construct_all_split,
            "all_to_all": self._construct_all_to_all
        }

    def _construct_reshape(self, x, *args):
        """args: (*shape) — view the local tensor as the given shape."""
        return x.view(args)

    def _construct_all_concat(self, x, *args):
        """args: (*rank_list, concat_dim) — all-gather x across rank_list and
        concatenate the gathered shards along concat_dim."""
        rank_list = args[0:-1]
        concat_dim = args[-1]
        group = platform.create_group(rank_list)
        concat_size = len(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)


    def _construct_strided_slice(self, x, *args):
        """args: (begin, end, strides) — flattened into three equal-length
        groups, one entry per tensor dimension."""
        dims = len(args) // 3
        return platform.construct_strided_slice(x, args[0: dims], args[dims: 2 * dims], args[2 * dims:])

    @staticmethod
    def _construct_all_concat_new(x, *args):
        """args: (concat_dim, concat_size, group)

        NOTE(review): args[2] is passed to ``platform.create_group``, so it
        appears to be a rank list rather than a pre-built group — confirm.
        """
        rank_list = args[2]
        concat_dim = args[0]
        concat_size = args[1]
        group = platform.create_group(rank_list)
        return platform.differentiable_all_gather_concat(x, group, concat_size, concat_dim)

    def _construct_all_split(self, x, *args):
        """args: (split_dim, split_size, group) — keep this rank's chunk.

        Purely local: the chunk index is this rank's position inside the
        rank list (args[2]); no communication is performed.
        """
        rank_list = list(args[2])
        split_dim = args[0]
        split_size = args[1]
        idx = rank_list.index(self.rank_id)
        return platform.chunk(x, split_dim, split_size, idx)

    def _construct_all_to_all(self, x, *args):
        """args: (split_dim, concat_dim, permute_size, group)

        Splits x into split_count parts along split_dim, exchanges the parts
        within the group via all-to-all, then concatenates the received
        parts along concat_dim. The pre/post reshape+permute steps move the
        split axis to dim 0 because the backend all-to-all exchanges along
        the leading dimension.

        Raises:
            ValueError: if split_dim's size is not divisible by split_count.
        """
        split_dim, concat_dim, split_count, rank_list = args
        group = platform.create_group(rank_list)
        original_shape = x.shape

        dim_size = original_shape[split_dim]
        if dim_size % split_count != 0:
            raise ValueError(f"Dimension {split_dim} with size {dim_size} "
                             f"cannot be evenly split into {split_count} parts")

        split_size = dim_size // split_count
        # Expected result shape: split_dim shrinks by split_count while
        # concat_dim grows by split_count (unchanged when they coincide).
        final_shape = list(original_shape)
        if split_dim != concat_dim:
            final_shape[split_dim] = split_size
            final_shape[concat_dim] = final_shape[concat_dim] * split_count
        final_shape = tuple(final_shape)

        # Fast path: every dim before split_dim is 1, so split_dim can be
        # collapsed into the leading dim with a plain view (no permute/copy).
        pre_special_handle = all(original_shape[i] == 1 for i in range(split_dim))
        if pre_special_handle:
            reshape_shape = (split_count * split_size,) + original_shape[split_dim + 1:]
            x_reshaped = x.view(reshape_shape)
        else:
            # General path: factor split_dim into (split_count, split_size)
            # and bring the split_count axis to the front.
            reshape_dims = list(original_shape)
            reshape_dims[split_dim] = split_count
            reshape_dims.insert(split_dim + 1, split_size)

            trans_dims = list(range(len(reshape_dims)))
            trans_dims.remove(split_dim)
            trans_dims.insert(0, split_dim)

            x_reshaped = x.reshape(reshape_dims).permute(trans_dims).contiguous()

            # Merge (split_count, split_size) back into one leading dim so
            # the backend can exchange along dim 0.
            reshape_shape = list(x_reshaped.shape)
            reshape_shape[0] = reshape_shape[0] * reshape_shape[1]
            reshape_shape.pop(1)
            reshape_shape = tuple(reshape_shape)
            x_reshaped = x_reshaped.reshape(reshape_shape)
        x_reshaped = x_reshaped.contiguous()
        output_tensor = platform.differentiable_all_to_all(
            input_data=x_reshaped,
            output_shape=reshape_shape,
            group=group
        )

        # If every dim before concat_dim is 1, the exchanged tensor already
        # matches final_shape up to a view.
        post_special_handle = all(final_shape[i] == 1 for i in range(concat_dim))
        if post_special_handle:
            return output_tensor.view(final_shape)

        # When pre_special_handle collapsed leading size-1 dims, the A2A was executed
        # in a reduced-rank space where the effective concat axis is shifted left by
        # split_dim positions. Use recon_concat_dim for all post-A2A reshaping so
        # that split_count is merged into the correct dimension.
        recon_concat_dim = (concat_dim - split_dim) if pre_special_handle else concat_dim

        # Factor the leading dim back into (split_count, per-rank size) ...
        output_reshape = list(output_tensor.shape)
        output_reshape[0] = split_count
        output_reshape.insert(1, output_tensor.shape[0] // split_count)

        # ... then move the split_count axis next to the concat axis.
        out_trans_dims = list(range(len(output_reshape)))
        first_dim = out_trans_dims.pop(0)
        if recon_concat_dim >= len(out_trans_dims):
            out_trans_dims.append(first_dim)
        else:
            out_trans_dims.insert(recon_concat_dim, first_dim)

        final_output = output_tensor.reshape(output_reshape).permute(out_trans_dims).contiguous()

        # Merge split_count into the concat axis to realise the final shape.
        final_reshape = list(final_output.shape)
        if recon_concat_dim < len(final_reshape) - 1:
            final_reshape[recon_concat_dim] = (
                final_reshape[recon_concat_dim] * final_reshape[recon_concat_dim + 1]
            )
            final_reshape.pop(recon_concat_dim + 1)

        result = final_output.reshape(final_reshape)
        if pre_special_handle:
            # Restore the leading size-1 dims dropped by the fast path.
            result = result.view(final_shape)
        return result

    def _apply_eazy_redistribute(self, src_layout, dst_layout):
        """Return True when src/dst share mesh shape, rank list and
        tensor-map length, so the shape-free inference path can be used.

        NOTE(review): "eazy" looks like a typo for "easy"; name kept as-is
        since this file may not be the only caller.
        """
        if (src_layout.mesh_shape != dst_layout.mesh_shape or
                src_layout.rank_list != dst_layout.rank_list):
            return False

        tensor_map_size = len(src_layout.tensor_map)
        if len(dst_layout.tensor_map) != tensor_map_size:
            return False
        return True

    def _redistribution_without_shape(self, local_x, src_layout, dst_layout, key, rank_list):
        """Infer the operator list without the tensor's global shape, cache
        it under *key*, apply it to *local_x*, and return the result."""
        inferrer = RedistributionOperatorInfer(
            dev_mat=src_layout.mesh_shape,
            in_tensor_map=list(src_layout.tensor_map),
            out_tensor_map=list(dst_layout.tensor_map)
        )
        op_list = inferrer.infer_ops_list(self.rank_id, rank_list)
        self._transform_cache[key] = op_list
        for op in op_list:
            # op == (tag, args): dispatch through the operator table.
            local_x = self._construct_op_operator[op[0]](local_x, *op[1])
        return local_x

    def redistribution(self, input_x, to_layout):
        """tensor redistribution

        Redistributes DTensor *input_x* to *to_layout* and returns a new
        DTensor. Any partial (pending-reduce) status is resolved first; then
        a cached operator list is replayed if available, otherwise a fresh
        one is inferred (shape-free fast path or shape-aware path).

        Raises:
            ValueError: if the two layouts' rank lists differ.
        """
        x_layout = input_x.layout
        x = input_x
        if input_x.layout.is_partial():
            # Solve partial status first
            if input_x.layout.mesh_shape == to_layout.mesh_shape:
                x = self.reduce_partial(input_x, to_layout)
            else:
                x = self.reduce_partial(input_x, x_layout)

        from_layout = x.layout
        if not self.is_init:
            # First call: bind this process's global rank.
            self.rank_id = platform.get_rank()
            self.is_init = True
        if from_layout.rank_list != to_layout.rank_list:
            raise ValueError(f"The from_layout rank list: {from_layout.rank_list} is not equal to "
                             f"to_layout rank list: {to_layout.rank_list}")
        # Shape-independent cache key (layouts + rank).
        key = from_layout.compact_str + to_layout.compact_str + str(self.rank_id)
        if key in self._transform_cache:
            x = x.to_local()
            transform_operator_list = self._transform_cache[key]
            for transform_operator in transform_operator_list:
                x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
            return DTensor.from_local(x, to_layout.mesh, to_layout.alias_placements)

        # Shape-aware cache key for the general inference path.
        full_shape = x.shape
        key_and_shape = key + str(full_shape)
        x = x.to_local()
        if key_and_shape in self._transform_cache:
            transform_operator_list = self._transform_cache[key_and_shape]
            for transform_operator in transform_operator_list:
                x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
            return DTensor.from_local(x, to_layout.mesh, to_layout.alias_placements)

        rank_list = from_layout.rank_list
        if self._apply_eazy_redistribute(from_layout, to_layout):
            if from_layout.is_partial():
                from_layout.reset_partial()
            x = self._redistribution_without_shape(x, from_layout, to_layout, key, rank_list)
        else:
            transform_operator_list = self._infer_transform_operator_list(from_layout, to_layout,
                                                                          full_shape, key_and_shape, rank_list)
            for transform_operator in transform_operator_list:
                x = self._construct_op_operator[transform_operator[0]](x, *transform_operator[1])
        return DTensor.from_local(x, to_layout.mesh, to_layout.alias_placements)

    def _infer_transform_operator_list(self, from_layout, to_layout, from_full_shape, key, rank_list):
        """infer transform operator list (shape-aware path); caches under *key*"""
        from_layout_tuple, to_layout_tuple = \
            _construct_layout_tuple_for_transform_operator_list(from_layout, to_layout, from_full_shape)
        self._transform_cache[key] = \
            platform.get_tensor_transform().transform_tensor_sharding(from_layout_tuple, to_layout_tuple,
                                                                      rank_list, False, self.rank_id)
        return self._transform_cache[key]

    @staticmethod
    def _allreduce_along_dev_dim(x, op, layout, dev_dim):
        """Do allreduce at specified axis along dev_dim.

        op 'avg' is implemented as sum followed by division by the device
        count; 'all' is a boolean reduction routed through int32 because
        the backend collective operates on numeric tensors.
        """
        group = layout.get_comm_group_by_axis(dev_dim)
        # Backend collectives need at least rank-1 input; wrap scalars.
        zero_dim = x.dim() == 0
        if zero_dim:
            x = x.unsqueeze(0)
        if op == 'avg':
            dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
            x = platform.differentiable_all_reduce(x, 'sum', group)
            x = x / dev_num
        elif op == 'all':
            x_int32 = platform.tensor_type_cast(x.bool(), 'int32') # True→1, False→0
            x = platform.differentiable_all_reduce(x_int32, 'all', group)
            x = x.bool()
        else:
            x = platform.differentiable_all_reduce(x, op, group)
        if zero_dim:
            x = x.squeeze(0)
        return x

    def _reduce_scatter_along_dev_dim_with_axis(self, x, axis, op, layout, dev_dim):
        """Do reduce_scatter at specified axis along dev_dim."""
        dev_num = layout.mesh_shape[layout.alias_name.index(dev_dim)]
        group = layout.get_comm_group_by_axis(dev_dim)
        output_tensor = platform.differentiable_reduce_scatter(x, dev_num, axis, op, group)
        return output_tensor

    def reduce_partial(self, input_x, to_layout):
        """Reduce partial status.

        Resolves every pending partial reduce recorded in input_x.layout:
        a device axis whose result will be sharded in *to_layout* uses
        ReduceScatter (reduce + shard in one collective), otherwise
        AllReduce. Returns a non-partial DTensor.

        Raises:
            ValueError: if mesh shapes differ or to_layout is itself partial.
        """
        from_layout = input_x.layout
        x = input_x
        if from_layout is None or not from_layout.is_partial():
            return x

        x = x.to_local()
        if from_layout.mesh_shape != to_layout.mesh_shape:
            raise ValueError(f"For reduce partial, mesh_shape between from_layout and to_layout must be the same, "
                             f"but got {from_layout.mesh_shape} and {to_layout.mesh_shape}")
        if to_layout.is_partial():
            raise ValueError(f"For reduce partial, to_layout must be non-partial status, but got to_layout.partial: "
                             f"{to_layout.partial}")

        # Position of each device axis inside a multi-axis shard tuple; used
        # below to run outer splits before inner ones.
        dev_map_order = {}
        for dev_axis in to_layout.alias_tensor_map:
            if isinstance(dev_axis, tuple):
                for i, sub_dev_axis in enumerate(dev_axis):
                    dev_map_order[sub_dev_axis] = i
            else:
                dev_map_order[dev_axis] = 0

        pending_reduce_op_list = [] # List[Tuple[comm_op, op, dev_dim, reduce_dim]]
        for dev_axis_index, op in enumerate(from_layout.partial):
            if op is None:
                continue
            dev_axis = from_layout.alias_name[dev_axis_index]
            apply_shard_dim = to_layout.get_dev_axis_apply_shard_axis(dev_axis)
            comm_op = "ReduceScatter" if apply_shard_dim else "AllReduce"
            pending_reduce_op_list.append((comm_op, op, dev_axis, apply_shard_dim))

        # sort reduce op
        # 1. ReduceScatter is executed before AllReduce
        # 2. If multiple split, the dev axis split outer will be execute first.
        #    e.g. ("cp", "tp"), will execute reduce_scatter along "cp" before "tp"
        # 3. Lower dev_id execute before higher dev_id
        sorted_pending_reduce_op_list = \
            sorted(pending_reduce_op_list, key=lambda reduce_pair: (reduce_pair[0] != "ReduceScatter",
                                                                    dev_map_order.get(reduce_pair[2], 0),
                                                                    to_layout.mesh.axis_id(reduce_pair[2])))

        output_alias_tensor_map = list(from_layout.alias_tensor_map)
        for reduce_op_pair in sorted_pending_reduce_op_list:
            comm_op = reduce_op_pair[0]
            op = reduce_op_pair[1]
            dev_axis = reduce_op_pair[2]
            if comm_op == "AllReduce":
                x = TensorRedistribution._allreduce_along_dev_dim(x, op, from_layout, dev_axis)
            elif comm_op == "ReduceScatter":
                reduce_axis = reduce_op_pair[3]
                x = self._reduce_scatter_along_dev_dim_with_axis(x, reduce_axis, op, from_layout, dev_axis)
                # Record that dev_axis now shards reduce_axis in the output
                # layout; extend an existing tuple if already multi-sharded.
                if output_alias_tensor_map[reduce_axis] == "None":
                    output_alias_tensor_map[reduce_axis] = dev_axis
                elif isinstance(output_alias_tensor_map[reduce_axis], tuple):
                    output_alias_tensor_map[reduce_axis] += (dev_axis,)
                else:
                    output_alias_tensor_map[reduce_axis] = (output_alias_tensor_map[reduce_axis], dev_axis)

        output_layout = from_layout(*output_alias_tensor_map)
        output_layout.reset_partial()
        return DTensor.from_local(x, output_layout.mesh, output_layout.alias_placements)

339 

340 

# Module-level singleton; redistribution state (rank id, operator caches) is
# shared by all users of this module.
_tensor_redistribution = TensorRedistribution()