
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore HSDP parameter group with fused communication."""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Any, List, NamedTuple, Optional

import mindspore as ms
from mindspore import ops
from mindspore.common.api import _no_grad
import mindspore.mint.distributed as dist
from mindspore.ops.function.comm_func import CommHandle

from hyper_parallel.core.fully_shard.utils import DDPMeshInfo, FSDPMeshInfo, HSDPMeshInfo, MixedPrecisionPolicy
from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version
from hyper_parallel.platform.mindspore.fully_shard.pack_utils import build_rs_plan, pack_for_reduce_scatter
from hyper_parallel.platform.mindspore.fully_shard.param import MindSporeHSDPParamV2


def _normalize_device(device: Any) -> str:
    if isinstance(device, str):
        return device.split(":", 1)[0]
    return str(device).split(":", 1)[0]


def _shape_numel(shape) -> int:
    return math.prod(int(dim) for dim in shape)

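# Both helpers above are pure. Illustrative values, derived directly from the
# code rather than from any runtime contract:
#   _normalize_device("Ascend:0")  -> "Ascend"
#   _normalize_device("cpu")       -> "cpu"
#   _shape_numel((4, 8))           -> 32
#   _shape_numel(())               -> 1   (math.prod of an empty iterable)
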

def get_all_gather_metadata(hsdp_params):
    """Collect metadata required for fused all-gather."""
    param_input_dtypes = []
    param_input_numels = []
    inp_split_sizes = []
    total_input_numel = 0
    first_dtype = None

    for hsdp_param in hsdp_params:
        inputs = hsdp_param.all_gather_inputs
        if first_dtype is None:
            first_dtype = inputs[0].dtype
        elif first_dtype != inputs[0].dtype:
            raise ValueError("All parameters in the group must have a uniform dtype.")
        param_dtypes = [t.dtype for t in inputs]
        param_numels = [t.numel() for t in inputs]
        param_input_dtypes.append(param_dtypes)
        param_input_numels.append(param_numels)
        inp_split_sizes.extend(param_numels)
        total_input_numel += sum(param_numels)

    return AllGatherMetadata(
        param_input_dtypes,
        param_input_numels,
        first_dtype,
        inp_split_sizes,
        total_input_numel,
    )

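# Illustrative layout (hypothetical sizes): for two params whose local shards
# flatten to 6 and 4 elements, the metadata records
#     param_input_numels = [[6], [4]]
#     inp_split_sizes    = [6, 4]
#     total_input_numel  = 10
# so each rank's slice of the fused all-gather buffer is 10 elements, laid out
# as [param0 shard | param1 shard].
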

@dataclass
class AllGatherMetadata:
    """Metadata describing the fused all-gather buffer layout."""

    param_input_dtypes: list[list[Any]]
    param_input_numels: list[list[int]]
    dtype: Any
    inp_split_sizes: list[int]
    total_input_numel: int
    hash_key: int = field(init=False)

    def __post_init__(self):
        self.hash_key = hash(
            (
                tuple(tuple(d) for d in self.param_input_dtypes),
                tuple(tuple(n) for n in self.param_input_numels),
                self.dtype,
                tuple(self.inp_split_sizes),
                self.total_input_numel,
            )
        )


class AllGatherResult(NamedTuple):
    """Result of a fused all-gather operation."""

    all_gather_output: Optional[ms.Tensor]
    metadata: Optional[AllGatherMetadata]
    handle: Optional[CommHandle]


@dataclass
class CommContext:
    """Global communication context for pipelined fused reductions."""

    comm_handle: Optional[CommHandle] = None
    all_reduce_handle: Optional[CommHandle] = None
    pre_param_group: Optional["HSDPParamGroup"] = None
    all_reduce_param_group: Optional["HSDPParamGroup"] = None


comm_ctx = CommContext()


def get_comm_ctx():
    """Return the global communication context singleton."""
    return comm_ctx

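# The singleton carries in-flight communication state across param groups:
# foreach_reduce(async_op=True) stores the reduce-scatter handle in
# `comm_handle` and registers itself as `pre_param_group`. A scheduling step
# outside this file then appears to drive wait_reduce_scatter_and_issue_all_reduce()
# and wait_all_reduce_and_apply_grad() on the registered group, which is what
# pipelines the collectives of consecutive groups.
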

@dataclass
class ReplicateBucket:
    """One fused all-reduce bucket sharing the same replicate process group."""

    key: int
    group: Any
    group_size: int
    param_indices: list[int]
    flat_numel: int
    buffer: Optional[ms.Tensor] = None


@dataclass
class PendingBucketAllReduce:
    """One in-flight async all-reduce launched for a replicate bucket."""

    bucket_key: int
    handle: Any


class AllGatherMetadataCache:
    """Cache for all-gather metadata across iterations."""

    _cache: dict[int, AllGatherMetadata] = {}

    @classmethod
    def get_metadata(cls, hsdp_params, fn):
        param_key = tuple((id(p), getattr(p, "version", 0)) for p in hsdp_params)
        key = hash(param_key)
        if key in cls._cache:
            return cls._cache[key]
        metadata = fn(hsdp_params)
        cls._cache[key] = metadata
        return metadata

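# The cache key combines each param's identity with its `version` attribute,
# so metadata is recomputed when a parameter object is rebuilt or bumps its
# version, and reused otherwise. Note that `_cache` is class-level state:
# every AllGatherMetadataCache instance shares one dict, which lets cached
# metadata survive across iterations and across param-group instances.
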

@_no_grad()
def all_gather_copy_in(all_gather_inputs, all_gather_output, inp_split_sizes, all_gather_input_numel, rank):
    """Copy per-parameter local shards into one fused rank-local all-gather slice."""
    all_gather_input = all_gather_output.narrow(0, all_gather_input_numel * rank, all_gather_input_numel)
    offset = 0
    for src, size in zip(all_gather_inputs, inp_split_sizes):
        src_flat = src.view(-1)
        all_gather_input.narrow(0, offset, size).copy_(src_flat)
        offset += size
    return all_gather_input, all_gather_output

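# Buffer layout sketch (W ranks, fused input of N = all_gather_input_numel
# elements per rank):
#
#   all_gather_output: [ rank0 slice | rank1 slice | ... | rank{W-1} slice ]
#
# Each rank writes its local shards only into its own N-element slice; the
# collective fills in the other slices, so the output buffer doubles as the
# input staging area and no separate input buffer is needed on this path.
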

@_no_grad()
def split_with_sizes_copy(all_gather_output, split_sizes, dim, out):
    """Copy split views from a fused all-gather output into pre-allocated outputs."""
    if dim != 1:
        raise NotImplementedError("split_with_sizes_copy currently only supports dim=1")
    offset = 0
    for dst, size in zip(out, split_sizes):
        src = all_gather_output.narrow(dim, offset, size)
        copy_without_bumping_version(dst, src)
        offset += size


@_no_grad()
def reduce_scatter_copy_in(
    hsdp_params: List[MindSporeHSDPParamV2],
    unsharded_grads: List[ms.Tensor],
    reduce_scatter_input: ms.Tensor,
    world_size: int,
) -> None:
    """Pack all unsharded gradients into one fused reduce-scatter input buffer."""
    if len(hsdp_params) != len(unsharded_grads):
        raise AssertionError(
            "reduce_scatter_copy_in expects one hsdp_param per unsharded_grad, but got "
            f"{len(hsdp_params)} params and {len(unsharded_grads)} grads"
        )
    packed_rows = reduce_scatter_input.view(world_size, -1)
    col_offset = 0
    for hsdp_param, grad in zip(hsdp_params, unsharded_grads):
        grad = grad.contiguous()
        plan = build_rs_plan(hsdp_param, grad, world_size)
        packed_grad = pack_for_reduce_scatter(grad, plan)
        next_col_offset = col_offset + packed_grad.shape[1]
        for row_idx in range(world_size):
            packed_rows[row_idx].narrow(0, col_offset, packed_grad.shape[1]).copy_(
                packed_grad[row_idx].view(-1)
            )
        col_offset = next_col_offset
    if col_offset != packed_rows.shape[1]:
        raise AssertionError(
            "reduce_scatter_copy_in packed an unexpected number of elements: "
            f"{col_offset} != {packed_rows.shape[1]}"
        )

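# Packed layout sketch: viewing the fused input as (world_size, cols), row r
# holds the slice of every gradient that rank r should own after the
# reduce-scatter:
#
#   row 0: [ grad0 shard for rank0 | grad1 shard for rank0 | ... ]
#   row 1: [ grad0 shard for rank1 | grad1 shard for rank1 | ... ]
#   ...
#
# reduce_scatter_tensor then reduces across rows and leaves each rank exactly
# its own row, i.e. its sharded gradients, already concatenated in param order.
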

class HSDPParamGroup:
    """Group HSDP parameters within a module for fused collectives."""

    def __init__(
        self,
        hsdp_params,
        mesh_info: FSDPMeshInfo,
        device: Optional[str] = None,
        mp_policy: Optional[MixedPrecisionPolicy] = None,
        enable_zero_copy_param_buffer: bool = False,
    ):
        self.mesh_info = mesh_info
        self.device = device
        self.hsdp_params = hsdp_params
        self.enable_zero_copy_param_buffer = enable_zero_copy_param_buffer
        if isinstance(self.mesh_info, (FSDPMeshInfo, HSDPMeshInfo)):
            self.shard_rank = self.mesh_info.shard_mesh_rank
            self.shard_world_size = self.mesh_info.shard_mesh_size
        else:
            self.shard_rank = 0
            self.shard_world_size = 1
        self.shard_group = self.mesh_info.shard_process_group
        self.replicate_group = None
        if isinstance(self.mesh_info, (HSDPMeshInfo, DDPMeshInfo)):
            self.replicate_group = self.mesh_info.replicate_process_group
        elif isinstance(self.mesh_info, FSDPMeshInfo):
            self.replicate_group = self._infer_layout_replicate_group()
        self.ag_output: Optional[ms.Tensor] = None
        self.metadata_cache = None
        self.mp_policy = mp_policy
        self._result = None
        self._reduce_output = None
        self._reduce_op = None
        self._needs_avg_div = False
        self._reduce_hsdp_params = None
        self._active_replicate_buckets: dict[int, ReplicateBucket] = {}
        self._active_param_flat_offsets: list[int] = []
        self._pending_all_reduce_handles: list[PendingBucketAllReduce] = []
        self._flat_param_buffer: Optional[ms.Tensor] = None
        self._flat_cast_buffer: Optional[ms.Tensor] = None
        self._init_mp_dtypes()
        if self.enable_zero_copy_param_buffer:
            self._init_flat_param_buffer()

    def _infer_layout_replicate_group(self):
        replicate_groups = []
        for hsdp_param in self.hsdp_params:
            group_info = getattr(hsdp_param, "unsharded_group_info", None)
            group = getattr(group_info, "group", None)
            if group is None or getattr(hsdp_param, "replicate_world_size", 1) <= 1:
                continue
            replicate_groups.append(group)
        if not replicate_groups:
            return None
        return replicate_groups[0]

    def _build_active_replicate_buckets(self, hsdp_params):
        buckets: dict[int, ReplicateBucket] = {}
        for idx, hsdp_param in enumerate(hsdp_params):
            group_info = getattr(hsdp_param, "unsharded_group_info", None)
            group = getattr(group_info, "group", None)
            group_size = getattr(group_info, "rank_size", getattr(hsdp_param, "replicate_world_size", 1))
            if group is None or group_size <= 1:
                continue
            key = id(group)
            if key not in buckets:
                buckets[key] = ReplicateBucket(
                    key=key,
                    group=group,
                    group_size=group_size,
                    param_indices=[],
                    flat_numel=0,
                )
            buckets[key].param_indices.append(idx)
            buckets[key].flat_numel += _shape_numel(hsdp_param.sharded_size)
        return buckets
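
    # Buckets are keyed by id(group): parameters that report the same replicate
    # process group via `unsharded_group_info` coalesce into a single flat
    # all-reduce buffer, so the cross-replica reduction runs once per distinct
    # group rather than once per parameter.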

    def _init_flat_param_buffer(self):
        """Rebase local shards into one flat buffer when storage semantics allow it."""
        if not self.enable_zero_copy_param_buffer:
            return
        if self.shard_world_size <= 1 or len(self.hsdp_params) == 0:
            return
        if any(p.offload_to_cpu or str(p.sharded_param.device) == "meta" for p in self.hsdp_params):
            return

        total_numel = sum(hsdp_param._sharded_param_data.numel() for hsdp_param in self.hsdp_params)
        orig_dtype = self.hsdp_params[0]._sharded_param_data.dtype
        flat_buffer = ms.mint.empty((total_numel,), dtype=orig_dtype, device=_normalize_device(self.device))

        offset = 0
        original_locals = []
        try:
            for hsdp_param in self.hsdp_params:
                original_locals.append((hsdp_param, hsdp_param._sharded_param_data, hsdp_param._sharded_local_tensor))
                numel = hsdp_param._sharded_param_data.numel()
                flat_slice = flat_buffer.narrow(0, offset, numel)
                flat_slice.copy_(hsdp_param._sharded_param_data)
                hsdp_param._sharded_param_data = flat_slice
                new_local = flat_slice.view(hsdp_param.sharded_size)
                req_grad = hsdp_param.sharded_param.requires_grad
                hsdp_param.sharded_param.set_data(new_local)
                hsdp_param.sharded_param._local_tensor = new_local
                if req_grad:
                    new_local.requires_grad_(True)
                    hsdp_param.sharded_param.requires_grad_(True)
                offset += numel
        except Exception:  # pylint: disable=W0718
            for hsdp_param, orig_flat, orig_local in original_locals:
                hsdp_param._sharded_param_data = orig_flat
                hsdp_param.sharded_param.set_data(orig_local)
                hsdp_param.sharded_param._local_tensor = orig_local
            self._flat_param_buffer = None
            self._flat_cast_buffer = None
            return

        self._flat_param_buffer = flat_buffer
        has_param_dtype = any(p.param_dtype is not None for p in self.hsdp_params)
        if has_param_dtype:
            cast_dtype = next(p.param_dtype for p in self.hsdp_params if p.param_dtype is not None)
            self._flat_cast_buffer = ms.mint.empty(
                (total_numel,), dtype=cast_dtype, device=_normalize_device(self.device)
            )
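
    # With the flat buffer in place, every param's sharded data is a narrow()
    # view into one contiguous tensor, so foreach_all_gather can hand the whole
    # buffer (or its cast copy) to the collective and skip the per-parameter
    # copy-in loop. The try/except restores the original per-param storage if
    # any rebinding step fails partway through.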

    def _is_flat_buffer_valid(self):
        """Check whether the flat buffer still backs the params' sharded data."""
        if self._flat_param_buffer is None or len(self.hsdp_params) == 0:
            return False
        first_param = self.hsdp_params[0]
        return (
            first_param._sharded_param_data.untyped_storage().data_ptr()
            == self._flat_param_buffer.untyped_storage().data_ptr()
        )

    def _allocate_bucket_buffers_if_needed(self, device, dtype):
        normalized_device = _normalize_device(device)
        for bucket in self._active_replicate_buckets.values():
            if bucket.flat_numel == 0:
                continue
            needs_new_buffer = (
                bucket.buffer is None
                or bucket.buffer.numel() != bucket.flat_numel
                or bucket.buffer.dtype != dtype
            )
            if needs_new_buffer:
                bucket.buffer = ms.mint.empty((bucket.flat_numel,), dtype=dtype, device=normalized_device)

    def _pack_bucket_from_reduce_output(self, bucket: ReplicateBucket) -> ms.Tensor:
        if bucket.buffer is None:
            raise AssertionError("Bucket buffer must be allocated before packing from reduce output")
        if self._reduce_output is None or self._reduce_hsdp_params is None:
            raise AssertionError("Bucket packing requires an active fused reduce output")
        dst_offset = 0
        for idx in bucket.param_indices:
            hsdp_param = self._reduce_hsdp_params[idx]
            src_offset = self._active_param_flat_offsets[idx]
            numel = _shape_numel(hsdp_param.sharded_size)
            bucket.buffer.narrow(0, dst_offset, numel).copy_(
                self._reduce_output.narrow(0, src_offset, numel)
            )
            dst_offset += numel
        return bucket.buffer

    def _unpack_bucket_to_reduce_output(self, bucket: ReplicateBucket) -> None:
        if bucket.buffer is None:
            raise AssertionError("Bucket buffer must exist before unpacking to reduce output")
        if self._reduce_output is None or self._reduce_hsdp_params is None:
            raise AssertionError("Bucket unpack requires an active fused reduce output")
        src_offset = 0
        for idx in bucket.param_indices:
            hsdp_param = self._reduce_hsdp_params[idx]
            dst_offset = self._active_param_flat_offsets[idx]
            numel = _shape_numel(hsdp_param.sharded_size)
            self._reduce_output.narrow(0, dst_offset, numel).copy_(
                bucket.buffer.narrow(0, src_offset, numel)
            )
            src_offset += numel

    def unshard(self, async_op: bool = False):
        """Trigger fused all-gather for all parameters in this group."""
        if self._result is not None:
            return
        if self.shard_world_size == 1:
            self._result = AllGatherResult(None, None, None)
            return
        self.foreach_all_gather(async_op=async_op)

    def _init_mp_dtypes(self):
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[MindSporeHSDPParamV2] = [
            p for p in self.hsdp_params if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None

    def wait_for_unshard(self):
        """Wait for fused all-gather and materialize per-parameter unsharded views."""
        if self._result is None:
            return
        if self.shard_world_size == 1:
            for hsdp_param in self.hsdp_params:
                all_gather_input = hsdp_param.all_gather_inputs[0]
                hsdp_param.init_all_gather_outputs(
                    [all_gather_input.numel()],
                    [all_gather_input.dtype],
                    self.shard_world_size,
                    _normalize_device(self.device),
                )
                hsdp_param.alloc_all_gather_outputs()
                copy_without_bumping_version(hsdp_param.all_gather_outputs[0], all_gather_input)
            self._result = None
        else:
            self.foreach_all_gather_copy_out()
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_unsharded_param()
            hsdp_param.to_unsharded()

    def alloc_all_gather_output(self, total_output_numel, dtype):
        normalized_device = _normalize_device(self.device)
        if self.ag_output is None or self.ag_output.dtype != dtype:
            self.ag_output = ms.mint.empty((total_output_numel,), dtype=dtype, device=normalized_device)
            return
        storage = self.ag_output.untyped_storage()
        expected_size = total_output_numel * self.ag_output.itemsize
        if storage.size() != expected_size:
            storage.resize_(expected_size)

    def free_all_gather_output(self):
        if self.ag_output is None:
            return
        storage = self.ag_output.untyped_storage()
        if storage.size() != 0:
            storage.resize_(0)

    @_no_grad()
    def foreach_all_gather(self, async_op=False):
        """Perform one fused all-gather across all parameters in the group."""
        if self.metadata_cache is None:
            self.metadata_cache = AllGatherMetadataCache()
        metadata = self.metadata_cache.get_metadata(self.hsdp_params, get_all_gather_metadata)
        if metadata.total_input_numel == 0:
            return
        world_size = self.shard_world_size
        rank = self.shard_rank
        total_output_numel = metadata.total_input_numel * world_size
        self.alloc_all_gather_output(total_output_numel, metadata.dtype)
        for hsdp_param in self.hsdp_params:
            hsdp_param.reset_sharded_param()
        if self.enable_zero_copy_param_buffer and not self._is_flat_buffer_valid():
            self._init_flat_param_buffer()

        use_flat_buffer = (
            self.enable_zero_copy_param_buffer
            and self._flat_param_buffer is not None
            and self._is_flat_buffer_valid()
        )
        if use_flat_buffer:
            if self._flat_cast_buffer is not None:
                self._flat_cast_buffer.copy_(self._flat_param_buffer)
                all_gather_input = self._flat_cast_buffer
            else:
                all_gather_input = self._flat_param_buffer
        else:
            all_gather_inputs = []
            for hsdp_param in self.hsdp_params:
                all_gather_inputs.extend(hsdp_param.all_gather_inputs)
            if len(all_gather_inputs) == 0:
                return
            all_gather_input, _ = all_gather_copy_in(
                all_gather_inputs,
                self.ag_output,
                metadata.inp_split_sizes,
                metadata.total_input_numel,
                rank,
            )
        handle = dist.all_gather_into_tensor(self.ag_output, all_gather_input, self.shard_group, async_op)
        self._result = AllGatherResult(self.ag_output, metadata, handle)

    @_no_grad()
    def foreach_all_gather_copy_out(self):
        """Scatter one fused all-gather result back into per-parameter buffers."""
        ag_output, metadata, handle = self._result
        if handle is not None:
            handle.wait()
        world_size = self.shard_world_size
        split_with_sizes_out = []
        for input_numels, input_dtypes, hsdp_param in zip(
            metadata.param_input_numels, metadata.param_input_dtypes, self.hsdp_params
        ):
            hsdp_param.init_all_gather_outputs(
                input_numels,
                input_dtypes,
                world_size,
                _normalize_device(ag_output.device),
            )
            hsdp_param.alloc_all_gather_outputs()
            split_with_sizes_out.extend(hsdp_param.all_gather_outputs)
        ag_output = ag_output.view(world_size, -1)
        out = [t.view(world_size, -1) for t in split_with_sizes_out]
        split_with_sizes_copy(ag_output, metadata.inp_split_sizes, dim=1, out=out)
        self._result = None
        self.free_all_gather_output()

    @_no_grad()
    def foreach_reduce(
        self,
        reduce_scatter_reduce_op: Optional[ops.ReduceOp] = ops.ReduceOp.SUM,
        async_op: bool = True,
        needs_avg_div: bool = False,
    ) -> Optional[ms.Tensor]:
        """Perform fused reduce-scatter and optional bucketed all-reduce."""
        hsdp_params: List[MindSporeHSDPParamV2] = []
        unsharded_grads: List[ms.Tensor] = []
        for hsdp_param in self.hsdp_params:
            if not hasattr(hsdp_param, "_unsharded_param"):
                continue
            if hsdp_param.unsharded_accumulated_grad is not None:
                hsdp_params.append(hsdp_param)
                unsharded_grads.append(hsdp_param.unsharded_accumulated_grad_data)
            elif hsdp_param._unsharded_param.grad is not None:
                hsdp_params.append(hsdp_param)
                unsharded_grads.append(hsdp_param.unsharded_grad_data)
        if not hsdp_params:
            return None
        grad_dtypes = {g.dtype for g in unsharded_grads}
        if len(grad_dtypes) != 1:
            raise ValueError(
                f"FSDP reduce-scatter expects uniform grad dtype but got {grad_dtypes}"
            )
        grad_dtype = unsharded_grads[0].dtype
        reduce_dtype = self._reduce_dtype or grad_dtype
        world_size = self.shard_world_size
        reduce_scatter_input_numel = sum(s.numel() for s in unsharded_grads)
        reduce_scatter_output_numel = reduce_scatter_input_numel // world_size
        device = _normalize_device(unsharded_grads[0].device)
        reduce_scatter_input = ms.mint.empty((reduce_scatter_input_numel,), dtype=reduce_dtype, device=device)
        reduce_scatter_copy_in(hsdp_params, unsharded_grads, reduce_scatter_input, world_size)
        reduce_output = ms.mint.empty((reduce_scatter_output_numel,), dtype=reduce_dtype, device=device)
        self._needs_avg_div = needs_avg_div
        self._reduce_op = reduce_scatter_reduce_op
        self._reduce_hsdp_params = hsdp_params
        self._active_param_flat_offsets = []
        flat_offset = 0
        for hsdp_param in hsdp_params:
            self._active_param_flat_offsets.append(flat_offset)
            flat_offset += _shape_numel(hsdp_param.sharded_size)
        self._active_replicate_buckets = self._build_active_replicate_buckets(hsdp_params)
        self._allocate_bucket_buffers_if_needed(reduce_output.device, reduce_output.dtype)
        self._pending_all_reduce_handles = []
        if self.shard_group is None or world_size <= 1:
            comm_ctx.comm_handle = None
            self._reduce_output = reduce_scatter_input
            if async_op:
                comm_ctx.pre_param_group = self
            else:
                self.apply_fusion_reduced_grad()
            return self._reduce_output
        rs_handle = dist.reduce_scatter_tensor(
            output=reduce_output,
            input=reduce_scatter_input,
            group=self.shard_group,
            op=reduce_scatter_reduce_op,
            async_op=async_op,
        )
        comm_ctx.comm_handle = rs_handle
        self._reduce_output = reduce_output
        if async_op:
            comm_ctx.pre_param_group = self
        else:
            self.apply_fusion_reduced_grad()
        return reduce_output
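
    # Async control flow for one group, with the handoff points through
    # comm_ctx:
    #
    #   foreach_reduce(async_op=True)
    #       -> launches reduce-scatter, saves handle in comm_ctx.comm_handle
    #   wait_reduce_scatter_and_issue_all_reduce()
    #       -> waits, optionally divides, launches per-bucket async all-reduces
    #   wait_all_reduce_and_apply_grad()
    #       -> waits on buckets, unpacks them, writes sharded grads back
    #
    # With async_op=False, apply_fusion_reduced_grad() performs all three steps
    # synchronously.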

    def wait_reduce_scatter_and_issue_all_reduce(self):
        """Wait for reduce-scatter and issue async all-reduces for active buckets."""
        if comm_ctx.comm_handle is not None:
            comm_ctx.comm_handle.wait()
            comm_ctx.comm_handle = None
        if self._needs_avg_div and self._reduce_output is not None and self.shard_world_size > 1:
            self._reduce_output.div_(self.shard_world_size)
        if not self._active_replicate_buckets:
            self._apply_reduced_grad()
            return
        self._pending_all_reduce_handles = []
        for bucket in self._active_replicate_buckets.values():
            packed = self._pack_bucket_from_reduce_output(bucket)
            ar_handle = dist.all_reduce(
                packed,
                group=bucket.group,
                op=self._reduce_op,
                async_op=True,
            )
            self._pending_all_reduce_handles.append(
                PendingBucketAllReduce(bucket_key=bucket.key, handle=ar_handle)
            )
        comm_ctx.all_reduce_param_group = self

    def wait_all_reduce_and_apply_grad(self):
        """Wait for pending bucket all-reduces and apply reduced grads."""
        for pending in self._pending_all_reduce_handles:
            bucket = self._active_replicate_buckets[pending.bucket_key]
            pending.handle.wait()
            if self._needs_avg_div and bucket.group_size > 1:
                bucket.buffer.div_(bucket.group_size)
            self._unpack_bucket_to_reduce_output(bucket)
        self._pending_all_reduce_handles = []
        comm_ctx.all_reduce_handle = None
        self._apply_reduced_grad()

    def apply_fusion_reduced_grad(self):
        """Synchronous fallback: wait, all-reduce buckets, then apply grads."""
        if comm_ctx.comm_handle is not None:
            comm_ctx.comm_handle.wait()
            comm_ctx.comm_handle = None
        if self._needs_avg_div and self._reduce_output is not None and self.shard_world_size > 1:
            self._reduce_output.div_(self.shard_world_size)
        for bucket in self._active_replicate_buckets.values():
            packed = self._pack_bucket_from_reduce_output(bucket)
            dist.all_reduce(
                packed,
                group=bucket.group,
                op=self._reduce_op,
            )
            if self._needs_avg_div and bucket.group_size > 1:
                packed.div_(bucket.group_size)
            self._unpack_bucket_to_reduce_output(bucket)
        self._apply_reduced_grad()

    def _apply_reduced_grad(self):
        """Write reduced gradients from the fused output buffer back to params."""
        if self._reduce_hsdp_params is None or self._reduce_output is None:
            return
        flat_grad_offset = 0
        for hsdp_param in self._reduce_hsdp_params:
            shard_numel = _shape_numel(hsdp_param.sharded_size)
            new_sharded_grad = self._reduce_output.narrow(0, flat_grad_offset, shard_numel)
            hsdp_param.apply_reduced_grad(new_sharded_grad, self._orig_dtype)
            flat_grad_offset += shard_numel
        self._reduce_output = None
        self._reduce_hsdp_params = None
        self._active_param_flat_offsets = []
        self._active_replicate_buckets = {}
        self._pending_all_reduce_handles = []
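

# Usage sketch (illustrative only; how `hsdp_params` and `mesh_info` are built
# is an assumption about the surrounding framework, not part of this module):
#
#   group = HSDPParamGroup(hsdp_params, mesh_info, device="Ascend")
#   group.unshard(async_op=True)          # launch the fused all-gather
#   group.wait_for_unshard()              # materialize unsharded params
#   ...                                   # forward/backward on unsharded params
#   group.foreach_reduce(async_op=False)  # fused reduce-scatter (+ bucketed
#                                         # all-reduce), applied synchronously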