Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_gather.py: 70%

211 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Gather operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from hyper_parallel.platform import get_platform 

21from .parallel_ops import DistributedOp 

22 

23 

class IndexSelectDistributedOp(DistributedOp):
    """Distributed implementation for the Index Select operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for index_select.

        Args:
            layouts: Layouts of the input tensors, ordered (input, dim, index).
                layouts[1] corresponds to the scalar `dim` argument and is unused.
            extra_args: Non-tensor arguments; must contain exactly one entry,
                the gather axis.

        Returns:
            Layout: Layout for the output tensor. If the gather axis of the
            input is sharded, the returned layout is marked Partial('sum') on
            the corresponding mesh dimension(s) so the dispatcher inserts the
            AllReduce when needed.

        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Check inputs. `extra_args` defaults to None, so guard it explicitly:
        # calling len(None) would raise TypeError instead of the intended
        # ValueError when the argument is omitted.
        if len(layouts) != 3:
            raise ValueError(f"Gather ops requires 3 layouts, but {len(layouts)}")
        num_extra = 0 if extra_args is None else len(extra_args)
        if num_extra != 1:
            raise ValueError(f"Gather ops requires 1 extra args, but {num_extra}")

        # Parse layout info
        p_layout, i_layout = layouts[0], layouts[2]
        axis = extra_args[0]

        p_tensor_map = p_layout.alias_tensor_map
        i_tensor_map = i_layout.alias_tensor_map

        # 1. Validate the axis range before any manipulation
        if axis < -len(p_tensor_map) or axis >= len(p_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: dim value {axis} is out of valid range"
            )

        # 2. Convert negative axis to positive index to avoid Python slicing bugs
        if axis < 0:
            axis += len(p_tensor_map)

        if len(i_tensor_map) != 1:
            raise ValueError(
                f"Operation {self.op_name}: index is not a one-dimensional Tensor"
            )

        # 3. Create output layout map.
        # index_select replaces the `axis` dimension of the input with the
        # (one-dimensional) index dimension, so the output map is the input
        # map with position `axis` substituted by the index map.
        output_tensor_map = (
            list(p_tensor_map[:axis]) + list(i_tensor_map) + list(p_tensor_map[axis + 1:])
        )

        output_layout = Layout(
            mesh_shape=p_layout.mesh_shape,
            alias_name=p_layout.alias_name,
            rank_list=p_layout.rank_list,
        )
        output_layout = output_layout(*output_tensor_map)

        # 4. Implicit communication via Partial layout.
        # If the gather axis was sharded, the local output is only a masked
        # partial result. Mark the layout Partial('sum') on the sharded mesh
        # dimension(s) so the OpDispatcher performs the AllReduce automatically
        # when this tensor is redistributed later.
        shard_mesh_dim_name = p_tensor_map[axis]
        if isinstance(shard_mesh_dim_name, tuple):
            # Multi-axis sharding: mark every real (non-"None") mesh dimension.
            for dim_name in shard_mesh_dim_name:
                if dim_name != "None":
                    output_layout.set_partial_by_dev_axis(dim_name, 'sum')
        elif shard_mesh_dim_name != "None":
            output_layout.set_partial_by_dev_axis(shard_mesh_dim_name, 'sum')

        return output_layout

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Get the expanded execution implementation for index_select.

        Args:
            func: The native index_select implementation.
            infer_result: The inferred output layout (unused here).
            layouts: Layouts of the input tensors, ordered (input, dim, index).
            extra_args: Non-tensor arguments; extra_args[0] is the gather axis.

        Returns:
            Callable: `func` unchanged when the gather axis is replicated,
            otherwise a wrapper performing a masked local index_select that
            returns a partial result (the AllReduce is left to the layout
            engine, see infer_layout).
        """
        p_layout = layouts[0]
        axis = extra_args[0]
        if axis < 0:
            axis += len(p_layout.alias_tensor_map)

        shard_mesh_dim_name = p_layout.alias_tensor_map[axis]

        # Resolve which mesh dimension the gather axis is sharded on. A plain
        # "None", or a tuple containing only "None" entries, means the axis is
        # replicated. The previous code treated an all-"None" tuple as sharded
        # and then raised StopIteration from next(); now it falls back to the
        # native implementation.
        if isinstance(shard_mesh_dim_name, tuple):
            target_dim_name = next(
                (d for d in shard_mesh_dim_name if d != "None"), None
            )
        elif shard_mesh_dim_name != "None":
            target_dim_name = shard_mesh_dim_name
        else:
            target_dim_name = None

        # If the axis is NOT sharded, fall back to standard execution.
        if target_dim_name is None:
            return func

        # If the axis IS sharded, return a custom function with masking ONLY.
        # The explicit AllReduce is intentionally omitted (handled through the
        # Partial output layout).
        def expand_impl(input_tensor, dim, index, **kwargs):
            platform = get_platform()
            mesh = p_layout.mesh

            # Fetch the communication group for the sharded mesh dimension.
            comm_group_info = mesh.get_comm_group_by_axis(target_dim_name)
            group = comm_group_info.group if hasattr(comm_group_info, 'group') else comm_group_info

            # Rank of the current device within this communication group.
            group_rank = platform.get_group_local_rank(group=group)

            # Global index boundaries of the local chunk.
            # NOTE(review): assumes every rank holds an equal-sized shard
            # along `dim` — confirm the framework guarantees even partitioning.
            local_dim_size = input_tensor.shape[dim]
            start_idx = group_rank * local_dim_size
            end_idx = start_idx + local_dim_size

            # 1. Mask: True for indices owned by the current rank.
            mask = (index >= start_idx) & (index < end_idx)

            # 2. Shift global indices to local indices.
            safe_index = index - start_idx

            # Clamp to the valid local range to prevent out-of-bounds access
            # during the local index_select (invalid entries are masked out
            # afterwards anyway).
            safe_index = safe_index.clamp(min=0, max=local_dim_size - 1)

            # 3. Perform the local index_select via the tensor's built-in method.
            local_out = input_tensor.index_select(dim, safe_index, **kwargs)

            # 4. Zero out rows gathered for indices that belong to other ranks;
            # reshape the 1D mask so it broadcasts along `dim`.
            mask_shape = [1] * local_out.ndim
            mask_shape[dim] = -1
            mask_reshaped = mask.reshape(mask_shape).to(local_out.dtype)

            local_out = local_out * mask_reshaped

            # Return the partial local tensor directly. The framework's layout
            # engine and OpDispatcher trigger the AllReduce when this Partial
            # tensor is redistributed to a non-partial layout.
            return local_out

        return expand_impl

167 

168 

class GatherDDistributedOp(DistributedOp):
    """Distributed implementation for GatherD operator.

    GatherD gathers values along a specified axis from the input tensor using the index tensor.

    Signature: GatherD(input, dim, index) -> output

    Key constraints:
    - Input and index must have the same number of dimensions
    - Input and index must use the same sharding on every non-dim axis
    - The output layout is derived from the index tensor's layout; when the
      input's dim axis is sharded, the output is additionally marked
      Partial('sum') on that mesh axis (see infer_layout)
    """

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for GatherD operations.

        Args:
            layouts: Layouts of input tensors [input_layout, dim_layout, index_layout].
                dim is a scalar, so layouts[1] is expected to be None.
            extra_args: Extra arguments containing [dim]
        Returns:
            Layout: Layout for output tensor, built from the index layout. If
            the input's dim axis is sharded, the layout is marked
            Partial('sum') on the corresponding mesh axis so the dispatcher
            can AllReduce later.
        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Validate input count
        if len(layouts) != 3:
            raise ValueError(
                f"Operation {self.op_name}: requires 3 layouts (input, dim, index), "
                f"but got {len(layouts)}"
            )
        # Validate extra_args (should contain dim)
        if len(extra_args) != 1:
            raise ValueError(
                f"Operation {self.op_name}: requires 1 extra arg (dim), "
                f"but got {len(extra_args)}"
            )
        # Parse layouts: [input, dim (non-tensor), index]
        # Note: dim is a scalar, so layouts[1] should be None
        input_layout = layouts[0]
        index_layout = layouts[2]
        dim = extra_args[0]
        # Validate layouts exist
        if input_layout is None or not hasattr(input_layout, "tensor_map"):
            raise ValueError(f"Operation {self.op_name}: input layout cannot be None")
        if index_layout is None or not hasattr(index_layout, "tensor_map"):
            raise ValueError(f"Operation {self.op_name}: index layout cannot be None")
        input_tensor_map = input_layout.tensor_map
        index_tensor_map = index_layout.tensor_map
        # Validate same rank
        if len(input_tensor_map) != len(index_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: input and index must have the same number of dimensions. "
                f"Got input rank={len(input_tensor_map)}, index rank={len(index_tensor_map)}"
            )
        # Validate dim is in valid range
        rank = len(input_tensor_map)
        if dim < -rank or dim >= rank:
            raise ValueError(
                f"Operation {self.op_name}: dim value {dim} is out of valid range [{-rank}, {rank-1}]"
            )
        # Normalize negative dim
        if dim < 0:
            dim = dim + rank
        # Non-dim axes must be sharded identically on input and index,
        # otherwise the element-wise gather would pair mismatched shards.
        for axis, (input_axis_map, index_axis_map) in enumerate(zip(input_tensor_map, index_tensor_map)):
            if axis == dim:
                continue
            if input_axis_map != index_axis_map:
                raise ValueError(
                    f"Operation {self.op_name}: input and index must use the same sharding on non-dim axis {axis}. "
                    f"Got input tensor_map={input_tensor_map}, index tensor_map={index_tensor_map}, dim={dim}"
                )
        # Output inherits index layout
        output_layout = Layout(
            mesh_shape=index_layout.mesh_shape,
            alias_name=index_layout.alias_name,
            rank_list=index_layout.rank_list,
        )
        output_layout.set_tensor_map(index_layout.tensor_map)
        # tensor_map[dim] == -1 means replicated; any other value identifies
        # the (reverse-indexed) mesh axis the dim dimension is sharded on.
        if input_tensor_map[dim] != -1:
            # pylint: disable=protected-access
            # Inherit current partial state from index layout
            output_layout._partial = list(index_layout.partial)
            # Calculate the device axis name for the dim dimension
            # tensor_map uses reverse indexing: tensor_map[i] = len(alias_name) - 1 - device_axis
            device_axis_idx = len(index_layout.alias_name) - 1 - input_tensor_map[dim]
            dim_axis_name = index_layout.alias_name[device_axis_idx]
            output_layout.set_partial_by_dev_axis(dim_axis_name, 'sum')
        # pylint: disable=protected-access
        # Rebuild readable alias tensor map
        output_layout._alias_tensor_map = output_layout._build_readable_tensor_map()
        # pylint: disable=protected-access
        # Sync tensor_map to placement representation
        output_layout.tensor_map_to_placement()
        # Update compact string description
        output_layout.update_compact_str()
        return output_layout

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Returns the execution implementation wrapper for distributed GatherD.

        When the dim axis is sharded, each rank gathers from its local slice of the input tensor.
        The indices need to be adjusted to account for the local partition offset.

        Args:
            func: The original GatherD function to wrap
            infer_result: The inferred output layout (unused here)
            layouts: Layouts of input tensors [input_layout, dim_layout, index_layout]
            extra_args: Extra arguments containing [dim]

        Returns:
            Callable: Distributed implementation wrapper, or None if no sharding
        """
        input_layout = layouts[0]
        dim = extra_args[0]
        # Get tensor maps
        input_tensor_map = input_layout.tensor_map
        # Check if dim axis is sharded (enhanced MP)
        # tensor_map[dim] == -1 means replicated, otherwise sharded
        # NOTE(review): a negative `dim` is not normalized here (unlike
        # infer_layout); it relies on Python negative indexing into
        # tensor_map/shape producing the same axis — confirm.
        if input_tensor_map[dim] == -1: # native sharding, no need for custom implementation
            return None

        def distributed_gatherd_impl(*args, **kwargs):
            """
            Distributed GatherD implementation for sharded dim axis.

            Each rank gathers from its local slice of input tensor.
            Indices are adjusted by subtracting the local partition offset.
            """
            input_tensor = args[0]
            index_tensor = args[2]
            # Calculate local partition offset for the dim axis
            mesh = input_layout.mesh
            # Convert tensor_map index to mesh axis index (reverse order)
            mesh_dim_idx = len(mesh.mesh_shape) - 1 - input_tensor_map[dim]
            # Get the coordinate of current rank along the mesh dimension
            dim_coord = mesh.get_local_rank(mesh_dim_idx)
            # Calculate the size of input tensor's dim dimension per partition
            # (assumes every rank holds an equal-sized shard — TODO confirm)
            input_dim_size = input_tensor.shape[dim]
            # Calculate the starting index of local partition
            local_start_index = int(dim_coord * input_dim_size)
            local_end_index = int(local_start_index + input_dim_size)
            # Adjust indices: subtract local_start_index to map global indices to local range
            # This is similar to how Embedding shifts indices for Row Parallelism
            adjusted_index = index_tensor - local_start_index
            # Create mask to identify out-of-bounds indices: global indices
            # outside [local_start_index, local_end_index) belong to other partitions
            mask = (index_tensor >= local_start_index) & (index_tensor < local_end_index)
            # Cross-platform cast to matching int dtype
            mask_int = mask.to(index_tensor.dtype)
            # Zero out invalid indices to prevent out-of-bounds access
            safe_index = adjusted_index * mask_int
            # Replace original index tensor with adjusted index
            new_args = list(args)
            new_args[2] = safe_index
            # Execute native GatherD with adjusted indices
            output = func(*new_args, **kwargs)
            # Zero out outputs corresponding to invalid indices
            mask_int = mask_int.to(output.dtype)
            output = output * mask_int
            return output
        return distributed_gatherd_impl

334 

335 

class GatherNdDistributedOp(DistributedOp):
    """Distributed implementation for the GatherNd operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for GatherNd.

        GatherNd semantics: out.shape = indices.shape[:-1] + input_x.shape[K:]
        with K = indices.shape[-1]. Accordingly, the output layout:

        - inherits sharding from indices[:-1],
        - inherits sharding from the trailing input dims input_x[K:]
          (treated as replicated when no input layout is supplied),
        - requires the indexed dims input_x[:K] and the last indices dim
          to be replicated ("None").

        Args:
            layouts: Input layouts; layouts[0] is input_x (may be None),
                layouts[1] is indices.
            extra_args: Must end with the (input_shape, indices_shape) pair
                supplied by the WithShape suffix.

        Returns:
            Layout: Layout of the GatherNd output.

        Raises:
            ValueError: On malformed arguments or unsupported sharding.
        """
        input_layout, indices_layout = self._parse_input_layouts(layouts)

        input_shape, indices_shape = self._get_input_shapes(extra_args)
        k, trail_rank = self._get_k_and_trailing_rank(input_shape, indices_shape)

        input_map, indices_map = self._validate_tensor_maps(
            input_layout, indices_layout, k
        )

        # Batch dims come from indices[:-1]; trailing dims follow input_x[K:],
        # or stay replicated when no input layout is available.
        batch_part = tuple(indices_map[:-1])
        if input_map is None:
            trailing_part = ("None",) * trail_rank
        else:
            trailing_part = tuple(input_map[k:])
        out_map = batch_part + trailing_part

        result_layout = Layout(
            mesh_shape=indices_layout.mesh_shape,
            alias_name=indices_layout.alias_name,
            rank_list=indices_layout.rank_list,
        )

        # A rank-0 output still needs a concrete placement: use "None".
        return result_layout(*out_map) if out_map else result_layout("None")

    def _parse_input_layouts(self, layouts):
        """Split `layouts` into (input, indices) and reject surplus tensor inputs."""
        if len(layouts) < 2:
            raise ValueError(
                f"Operation {self.op_name} requires at least 2 input layouts, but got {len(layouts)}"
            )

        primary, indices = layouts[0], layouts[1]

        # Additional positional inputs are tolerated only when they carry no
        # layout (i.e. they are non-tensor arguments).
        for surplus in layouts[2:]:
            if surplus is None:
                continue
            raise ValueError(
                f"Operation {self.op_name} only supports 2 tensor inputs, but got extra tensor layout: "
                f"{surplus}"
            )

        # input_x may legitimately come without a layout (fully replicated),
        # but indices must always provide one.
        if indices is None or not hasattr(indices, "alias_tensor_map"):
            raise ValueError(f"Operation {self.op_name}: Indices layout cannot be None")

        return primary, indices

    def _validate_tensor_maps(self, input_layout, indices_layout, k):
        """Check GatherNd sharding constraints; return (input_map or None, indices_map)."""
        indices_map = indices_layout.alias_tensor_map

        if not indices_map:
            raise ValueError(f"Operation {self.op_name}: indices tensor_map cannot be empty")

        # The K dimension (last axis of indices) must never be split.
        tail_axis = indices_map[-1]
        if not self._is_none_axis(tail_axis):
            raise ValueError(
                f"Operation {self.op_name}: The last dimension of indices cannot be split. "
                f"Got indices[-1] = {tail_axis}"
            )

        # Without an input layout there is nothing further to validate.
        if input_layout is None:
            return None, indices_map

        input_map = input_layout.alias_tensor_map

        if k > len(input_map):
            raise ValueError(
                f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank "
                f"({len(input_map)})"
            )

        # The indexed dims [0:K) of input_x must all be replicated.
        if any(not self._is_none_axis(name) for name in input_map[:k]):
            raise ValueError(
                f"Operation {self.op_name}: input_x cannot be split on indexed dims [0:{k}). "
                f"These dims must be 'None', but got tensor_map: {input_map}"
            )

        return input_map, indices_map

    def _get_input_shapes(self, extra_args):
        """Extract (input_shape, indices_shape) from the trailing WithShape extra arg."""
        shapes = None
        if extra_args:
            candidate = extra_args[-1]
            if hasattr(candidate, "__len__") and len(candidate) >= 2:
                shapes = candidate

        if shapes is None:
            raise ValueError(
                f"Operation {self.op_name}: missing input_shapes in extra_args. "
                f"Please configure yaml with infer_layout_suffix: WithShape."
            )

        raw_input, raw_indices = shapes[0], shapes[1]
        if raw_input is None or raw_indices is None:
            raise ValueError(f"Operation {self.op_name}: input_shapes contains None: {shapes}")

        input_shape = self._normalize_shape(raw_input, "input")
        indices_shape = self._normalize_shape(raw_indices, "indices")

        if not indices_shape:
            raise ValueError(f"Operation {self.op_name}: indices shape invalid: {indices_shape}")

        return input_shape, indices_shape

    def _normalize_shape(self, shape, name):
        """Coerce a shape-like value into a tuple of ints, raising ValueError otherwise."""
        try:
            as_tuple = tuple(shape)
        except TypeError as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape is not iterable: {shape}") from err

        try:
            return tuple(int(dim) for dim in as_tuple)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape contains non-integer dims: {as_tuple}") from err

    def _get_k_and_trailing_rank(self, input_shape, indices_shape):
        """Return (K, trailing rank) with K = indices_shape[-1] and trailing rank = len(input_shape) - K."""
        raw_k = indices_shape[-1]
        try:
            k = int(raw_k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) is invalid: {raw_k}") from err

        if k <= 0:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) must be positive, but got {k}")

        trail_rank = len(input_shape) - k
        if trail_rank < 0:
            raise ValueError(
                f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank ({len(input_shape)})"
            )

        return k, trail_rank

    def _is_none_axis(self, axis_name):
        """Return True when `axis_name` denotes a replicated (unsharded) axis."""
        # A tuple counts as replicated only if every entry is "None".
        if isinstance(axis_name, tuple):
            return all(entry == "None" for entry in axis_name)
        return axis_name == "None"