Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / platform / mindspore / fully_shard / state.py: 60%

311 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore HSDP cell state""" 

16from typing import Optional 

17import mindspore as ms 

18from mindspore import ops 

19import mindspore.mint.distributed as dist 

20from hyper_parallel.core.fully_shard.hsdp_state import HSDPState 

21from hyper_parallel.core.fully_shard.hsdp_utils import ( 

22 _get_param_module_infos, 

23 FullyShardParamMode, 

24 infer_fully_shard_param_mode, 

25) 

26from hyper_parallel.platform.mindspore.fully_shard.pack_utils import build_rs_plan 

27from hyper_parallel.platform.mindspore.fully_shard.param import MindSporeHSDPParamV2 

28from hyper_parallel.platform.mindspore.fully_shard._version_utils import copy_without_bumping_version 

29from hyper_parallel.platform.mindspore.fully_shard.param_group import HSDPParamGroup, get_comm_ctx 

30from hyper_parallel.platform.mindspore.utils import normalize_runtime_device 

31from hyper_parallel.core.fully_shard.utils import CPUOffloadPolicy 

32 

33 

def _to_dtype_if_needed(
    tensor: ms.Tensor, dtype: Optional[ms.Type]
) -> ms.Tensor:
    """Return ``tensor`` cast to ``dtype`` when a different target dtype is given.

    Args:
        tensor: The input tensor to potentially cast.
        dtype: Target dtype. ``None`` or the tensor's current dtype makes
            this a no-op.

    Returns:
        The original tensor when no cast is needed, otherwise a tensor
        converted to ``dtype`` via ``Tensor.to``.
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor

46 

47 

class MindSporeHSDPStateV2(HSDPState):
    """MindSpore HSDP cell state."""

    # DTensor compat parameters in pure-TP mode can accumulate gradients
    # directly on ``sharded_param.grad`` without materializing an
    # ``_unsharded_param``. Track those async all-reduces separately from the
    # standard unsharded-gradient queues.
    # NOTE: class-level (shared) FIFO drained by ``reduce_params``; entries are
    # (handle, reduced_grad, target_grad, reduce_group_size, need_div) tuples
    # appended by ``_queue_direct_compat_all_reduce``.
    pre_direct_all_reduce_grads = []

55 

56 @staticmethod 

57 def _get_pending_unsharded_grad(hsdp_param): 

58 """Return the pending unsharded gradient tensor for reduction paths.""" 

59 if hsdp_param.unsharded_accumulated_grad is not None: 

60 return hsdp_param.unsharded_accumulated_grad_data 

61 return hsdp_param.unsharded_grad_data 

62 

63 @staticmethod 

64 def _has_pending_unsharded_grad(hsdp_param): 

65 """Whether the parameter currently has a gradient waiting for reduction.""" 

66 if hsdp_param.unsharded_accumulated_grad is not None: 

67 return True 

68 if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None: 

69 return False 

70 return hsdp_param.unsharded_param.grad is not None 

71 

72 @staticmethod 

73 def _get_local_sharded_grad(hsdp_param): 

74 """Return the local gradient tensor currently stored on ``sharded_param``.""" 

75 grad = hsdp_param.sharded_param.grad 

76 if grad is None: 

77 return None 

78 to_local = getattr(grad, "to_local", None) 

79 if callable(to_local): 

80 return to_local() 

81 return grad 

82 

83 @staticmethod 

84 def _synchronize_current_stream_if_needed(need_synchronize: bool) -> None: 

85 """Synchronize the current device stream after non-blocking CPU offload.""" 

86 if not need_synchronize: 

87 return 

88 ms.runtime.current_stream().synchronize() 

89 

    def __init__(self, cell, mesh_info, config, platform, device=None):
        """Build MindSpore HSDP state for ``cell``.

        Args:
            cell: The MindSpore cell whose parameters are managed.
            mesh_info: Device-mesh description used for sharding/replication.
            config: fully_shard configuration (comm_fusion, policies, ...).
            platform: Platform backend handle, forwarded to the base class.
            device: Optional target device, forwarded to the base class.
        """
        super().__init__(cell, mesh_info, config, platform, device)
        self.comm_fusion = config.comm_fusion
        # Do ReduceScatter/AllReduce for grad
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        self.reduce_grads = True
        # Reshard parameter after backward
        self.reshard_after_backward = True
        # Requires AllReduce for grad When HSDP
        self.requires_all_reduce = True
        # Keep historical AVG behavior for local parameters while DTensor-aware
        # paths default to SUM semantics without extra division.
        self.reduce_op_type = ops.ReduceOp.SUM
        # Divide after SUM only when every managed parameter is a plain
        # LOCAL_PARAM (i.e. no DTensor-aware parameter is present).
        self._need_div = not any(
            getattr(param, "param_mode", FullyShardParamMode.LOCAL_PARAM)
            != FullyShardParamMode.LOCAL_PARAM
            for param in self._iter_managed_params()
        )
        # Pending (param, reduced_grad, group_size) tuples for replicate_params
        # all-reduces; drained by _finish_ignored_allreduce().
        self._ignored_allreduce_works = []
        # One-shot flag consumed by lazy_init() to refresh sharded views once.
        self._reset_sharded_params = False
        self._init_param_group()

112 

113 def _iter_managed_params(self): 

114 """Return all fully_shard-managed parameters, including replicate_params.""" 

115 return [*self.hsdp_params, *self.replicate_params] 

116 

117 @staticmethod 

118 def _comm_fusion_unsupported_reason(hsdp_param) -> Optional[str]: 

119 """Return the reason why ``hsdp_param`` cannot participate in comm_fusion.""" 

120 if not hsdp_param.enable_fsdp_shard: 

121 return "non-sharded parameters such as replicate_params are not supported" 

122 if hsdp_param.param_mode not in ( 

123 FullyShardParamMode.LOCAL_PARAM, 

124 FullyShardParamMode.DTENSOR_UNIFIED, 

125 ): 

126 return f"param_mode {hsdp_param.param_mode} is not supported" 

127 local_shard = getattr(hsdp_param, "_sharded_local_tensor", None) 

128 if local_shard is None: 

129 return "missing local shard tensor for comm_fusion plan validation" 

130 plan_world_size = getattr(hsdp_param, "shard_world_size", None) 

131 if plan_world_size is None: 

132 plan_world_size = getattr(hsdp_param, "shard_size", 1) 

133 try: 

134 build_rs_plan(hsdp_param, local_shard, plan_world_size) 

135 except NotImplementedError as exc: 

136 return str(exc) 

137 except (AssertionError, ValueError) as exc: 

138 return f"cannot build comm_fusion pack plan: {exc}" 

139 return None 

140 

141 def _init_param_group(self): 

142 """Initialize fused parameter group when comm_fusion is enabled.""" 

143 if self.config.comm_fusion: 

144 unsupported_param = next( 

145 ( 

146 hsdp_param 

147 for hsdp_param in self.hsdp_params 

148 if self._comm_fusion_unsupported_reason(hsdp_param) is not None 

149 ), 

150 None, 

151 ) 

152 if unsupported_param is not None: 

153 param_fqn = getattr(unsupported_param, "_param_fqn", "<unknown>") 

154 reason = self._comm_fusion_unsupported_reason(unsupported_param) 

155 raise NotImplementedError( 

156 f"comm_fusion does not support parameter {param_fqn}: {reason}." 

157 ) 

158 self.param_group = None 

159 if self.hsdp_params: 

160 self.param_group = HSDPParamGroup( 

161 self.hsdp_params, 

162 self.mesh_info, 

163 self.device, 

164 self.mp_policy, 

165 self.config.comm_fusion_zero_copy, 

166 ) 

167 

168 def zero_grad(self): 

169 """zero grad""" 

170 for hsdp_param in self.hsdp_params: 

171 hsdp_param.zero_grad() 

172 for hsdp_param in self.replicate_params: 

173 hsdp_param.zero_grad() 

174 

175 @staticmethod 

176 def _div_if_needed(x, divisor, need_div: bool): 

177 """Apply gradient averaging only when the caller-provided policy requires it. 

178 

179 ``need_div`` may come from the current state or from metadata captured when 

180 async reduce work was queued, so this helper is safe for both immediate and 

181 deferred gradient materialization paths. 

182 """ 

183 if not need_div: 

184 return 

185 if divisor == 1: 

186 return 

187 x.div_(divisor) 

188 

    def _move_states_to_device(self):
        """Move parameters and buffers of managed modules onto ``self.device``.

        Parameters already claimed by HSDP, already on the target device, or
        still on the meta device are left untouched.
        """
        for mod in self.modules:
            for param in mod.get_parameters():
                # Skip params already wrapped by a (nested) fully_shard call.
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                param_device = normalize_runtime_device(param.device)
                if param_device in (self.device, "meta"):
                    continue
                param.data = param.to(self.device)
            for buffer in mod.buffers():
                # NOTE(review): unlike the parameter path above, buffer.device
                # is compared without normalize_runtime_device — confirm
                # buffers always report an already-normalized device string.
                if buffer.device in (self.device, "meta"):
                    continue
                buffer.data = buffer.to(self.device)

203 

    def _init_hsdp_params(self):
        """init hsdp parameters for cell and replicate parameters for cell."""
        # all parameters in the module tree(s), deduplicated
        visited_params = set()
        replicate_params = set(self.config.replicate_params or ())
        ignored_params = set(self.config.ignored_params or ())
        filtered_params = []
        for mod in self.modules:
            for _, param in mod.parameters_and_names():
                # Skip params already claimed by a (nested) fully_shard call.
                if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                    continue
                if param in ignored_params:
                    continue
                if param in visited_params:
                    continue
                visited_params.add(param)
                filtered_params.append(param)

        # module_infos[i] describes where filtered_params[i] lives in the tree.
        module_infos = _get_param_module_infos(filtered_params, tuple(self.modules))
        for param, module_info in zip(filtered_params, module_infos):
            param_mode = infer_fully_shard_param_mode(self.config.mesh, [param])
            # Parameters listed in config.replicate_params are kept whole
            # (DDP-style) rather than FSDP-sharded.
            enable_fsdp_shard = param not in replicate_params
            hsdp_param = MindSporeHSDPParamV2(
                param,
                module_info,
                self.mesh_info,
                shard_placement_fn=self.config.shard_placement_fn,
                mp_policy=self.mp_policy,
                offload_policy=self.offload_policy,
                device=self.device,
                param_mode=param_mode,
                enable_fsdp_shard=enable_fsdp_shard,
            )
            if param in replicate_params:
                self.replicate_params.append(hsdp_param)
            else:
                self.hsdp_params.append(hsdp_param)
                # sharded_hsdp_params is the subset that actually got a shard.
                if hsdp_param.is_sharded:
                    self.sharded_hsdp_params.append(hsdp_param)

243 

244 def _init_mp_dtypes(self): 

245 """init mp dtypes for hsdp parameters and replicate parameters""" 

246 for hsdp_param in self.hsdp_params: 

247 hsdp_param.init_dtype_attrs(self.mp_policy) 

248 for replicate_param in self.replicate_params: 

249 replicate_param.init_dtype_attrs(self.mp_policy) 

250 trainable_params: list[MindSporeHSDPParamV2] = [ 

251 p for p in self._iter_managed_params() if p.sharded_param.requires_grad 

252 ] 

253 orig_dtypes = {p.orig_dtype for p in trainable_params} 

254 reduce_dtypes = {p.reduce_dtype for p in trainable_params} 

255 if len(trainable_params) > 0 and len(orig_dtypes) != 1: 

256 raise AssertionError( 

257 f"hsdp expects uniform original parameter dtype but got {orig_dtypes}" 

258 ) 

259 self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None 

260 if len(trainable_params) > 0 and len(reduce_dtypes) != 1: 

261 raise AssertionError( 

262 f"hsdp expects uniform reduce dtype but got {reduce_dtypes}" 

263 ) 

264 self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None 

265 

266 def lazy_init(self): 

267 """Refresh parameter views and validate runtime state before first execution.""" 

268 if not self._reset_sharded_params: 

269 for hsdp_param in self.hsdp_params: 

270 if hsdp_param.is_sharded: 

271 hsdp_param.reset_sharded_param() 

272 self._reset_sharded_params = True 

273 self._validate_no_meta_params() 

274 self._validate_cpu_offload_params() 

275 self._init_mp_dtypes() 

276 

277 def _validate_cpu_offload_params(self): 

278 """Validate that all parameters are on CPU when CPU offload policy is enabled.""" 

279 if not isinstance(self.offload_policy, CPUOffloadPolicy): 

280 return 

281 hsdp_params_not_on_cpu = [ 

282 hsdp_param 

283 for hsdp_param in self._iter_managed_params() 

284 if not str(hsdp_param.sharded_param.device).lower().startswith("cpu") 

285 ] 

286 if hsdp_params_not_on_cpu: 

287 raise RuntimeError( 

288 "HSDP parameters should be materialized on CPU when enabling CPU offloading. " 

289 "For example, load a CPU state dict before training. " 

290 "Found following parameters on non-CPU device: " 

291 f"{[(p._param_fqn, p.sharded_param.device) for p in hsdp_params_not_on_cpu]}\n" 

292 ) 

293 

294 def _validate_no_meta_params(self): 

295 """Validate that all parameters have been materialized from meta device.""" 

296 param_names_on_meta = [ 

297 hsdp_param._param_fqn 

298 for hsdp_param in self._iter_managed_params() 

299 if hsdp_param.sharded_param.device == "meta" 

300 ] 

301 if param_names_on_meta: 

302 raise RuntimeError( 

303 "HSDP parameters should be materialized from meta device before training, " 

304 f"but the following were still on meta device: {param_names_on_meta}\n" 

305 "For example, initialize the module weights on a real device before running training." 

306 ) 

307 

    def _allreduce_replicate_params(self, async_op=True) -> None:
        """
        DDP-style all-reduce for parameters in config.replicate_params.

        Use the parameter's layout-driven unsharded group so DTensor-aware
        compatibility and unified modes reduce over the correct axes.

        Args:
            async_op: Issue collectives asynchronously; completion is awaited
                later in ``_finish_ignored_allreduce``.
        """
        for param in self.replicate_params:
            # Nothing to reduce until an unsharded view and a gradient exist.
            if not hasattr(param, "_unsharded_param") or param.unsharded_param is None:
                continue
            if (
                param.unsharded_accumulated_grad is None
                and param.unsharded_param.grad is None
            ):
                continue

            # Prefer the accumulated-grad buffer; fall back to the live grad.
            reduced_grad = param.unsharded_accumulated_grad_data
            if reduced_grad is None:
                reduced_grad = param.unsharded_grad_data
            reduced_grad = _to_dtype_if_needed(reduced_grad, self._reduce_dtype)
            reduce_group_info = getattr(param, "unsharded_group_info", None)
            reduce_group = reduce_group_info.group if reduce_group_info is not None else None
            reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1

            if reduce_group is not None and reduce_group_size > 1:
                param.all_reduce_handle = dist.all_reduce(
                    reduced_grad, group=reduce_group, op=self.reduce_op_type, async_op=async_op
                )
            # Queued even when no collective was issued, so that
            # _finish_ignored_allreduce still materializes param.grad.
            self._ignored_allreduce_works.append((param, reduced_grad, reduce_group_size))

337 

    def _finish_ignored_allreduce(self) -> None:
        """
        Wait for async all-reduce of replicate_params and materialize param.grad.

        For each pending work, this:
        Waits on all associated handles to complete;
        Casts reduced_grad back to _orig_dtype if needed;
        Assigns the final tensor to param.grad.
        """
        if not self._ignored_allreduce_works:
            return

        need_synchronize = False
        for param, reduced_grad, reduce_group_size in self._ignored_allreduce_works:
            # Handle may be unset when no collective was issued for this param.
            if param.all_reduce_handle:
                param.all_reduce_handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, self._need_div)
            # apply_reduced_grad returns True when a stream synchronize is
            # required afterwards; any single True wins.
            need_synchronize = (
                param.apply_reduced_grad(reduced_grad, self._orig_dtype)
                or need_synchronize
            )

        self._synchronize_current_stream_if_needed(need_synchronize)
        self._ignored_allreduce_works.clear()

362 

    def reduce_params(self):
        """Drain pending sharded parameter reductions and materialize sharded grads."""
        need_synchronize = False
        # FIFO drain of queued async reduce-scatters.
        while HSDPState.pre_reduce_scatter_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_reduce_scatter_params.pop(0)
            reduced_grad = hsdp_param.reduce_scatter_output()
            self._div_if_needed(reduced_grad, hsdp_param.shard_world_size, need_div)
            hsdp_param.clear_reduce_scatter_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )

        # FIFO drain of queued async all-reduces (HSDP replicate axis / compat).
        while HSDPState.pre_all_reduce_params:
            hsdp_param, pre_orig_dtype, need_div = HSDPState.pre_all_reduce_params.pop(0)
            reduced_grad = hsdp_param.all_reduce_output()
            self._div_if_needed(reduced_grad, hsdp_param.replicate_world_size, need_div)
            hsdp_param.clear_all_reduce_output()
            need_synchronize = (
                hsdp_param.apply_reduced_grad(reduced_grad, pre_orig_dtype)
                or need_synchronize
            )
        # Drain direct DTENSOR_COMPAT all-reduces queued on this class by
        # _queue_direct_compat_all_reduce.
        while MindSporeHSDPStateV2.pre_direct_all_reduce_grads:
            handle, reduced_grad, target_grad, reduce_group_size, need_div = (
                MindSporeHSDPStateV2.pre_direct_all_reduce_grads.pop(0)
            )
            if handle is not None:
                handle.wait()
            self._div_if_needed(reduced_grad, reduce_group_size, need_div)
            if reduced_grad is not target_grad:
                # The reduce buffer was a dtype-cast copy; write the result
                # back into the original grad storage without bumping its
                # version counter.
                if reduced_grad.dtype != target_grad.dtype:
                    reduced_grad = reduced_grad.to(target_grad.dtype)
                copy_without_bumping_version(target_grad, reduced_grad)
        self._synchronize_current_stream_if_needed(need_synchronize)

397 

    def post_backward_for_comm_fusion(self):
        """Drive the fused gradient-reduction pipeline for sharded params."""
        # Flush any non-fused queues first.
        self.reduce_params()
        comm_ctx = get_comm_ctx()
        # Finish the previously issued fused stages before starting this
        # group's reduce: wait the pending all-reduce, then the pending
        # reduce-scatter (which issues its follow-up all-reduce).
        if comm_ctx.all_reduce_param_group is not None:
            comm_ctx.all_reduce_param_group.wait_all_reduce_and_apply_grad()
            comm_ctx.all_reduce_param_group = None
        if comm_ctx.pre_param_group is not None:
            comm_ctx.pre_param_group.wait_reduce_scatter_and_issue_all_reduce()
            comm_ctx.pre_param_group = None
        if self.param_group is not None:
            self.param_group.foreach_reduce(
                reduce_scatter_reduce_op=self.reduce_op_type,
                needs_avg_div=self._need_div,
            )
        # replicate_params are reduced DDP-style outside the fused plan.
        self._allreduce_replicate_params()

414 

    def _post_backward_without_reduce(self):
        """Finish backward when gradient communication is disabled.

        Reshards first (when configured), then folds any pending grads into
        their accumulation buffers.
        """
        if self.reshard_after_backward:
            self.shard()
        for hsdp_param in self._iter_managed_params():
            hsdp_param.to_accumulated_grad_if_needed()

421 

422 def _should_run_all_reduce(self, hsdp_param) -> bool: 

423 """Whether the current parameter should issue an all-reduce in this backward pass.""" 

424 return self.requires_all_reduce and hsdp_param.dp_size > 1 

425 

    def _queue_reduce_scatter_then_all_reduce(self, hsdp_param):
        """Queue the standard FSDP/HSDP reduction path."""
        hsdp_param.reduce_scatter_grad(
            async_op=True,
            dtype=self._reduce_dtype,
            reduce_op=self.reduce_op_type
        )
        HSDPState.pre_reduce_scatter_params.append((hsdp_param, self._orig_dtype, self._need_div))
        # FSDP-only case: reduce_params() will finalize from the queue above.
        if not self._should_run_all_reduce(hsdp_param):
            return
        # HSDP case: chain an all-reduce over the replicate axis onto the
        # reduce-scatter output and finalize via the all-reduce queue instead.
        reduced_grad = hsdp_param.reduce_scatter_output()
        if (
            HSDPState.pre_reduce_scatter_params
            and HSDPState.pre_reduce_scatter_params[-1][0] == hsdp_param
        ):
            # Undo the queueing done above for this parameter.
            HSDPState.pre_reduce_scatter_params.pop()
        hsdp_param.clear_reduce_scatter_output()
        # NOTE(review): divides by shard_size here while reduce_params() uses
        # shard_world_size for the queued path — confirm they are equal here.
        self._div_if_needed(reduced_grad, hsdp_param.shard_size, self._need_div)
        hsdp_param.all_reduce_grad(
            grad=reduced_grad,
            dtype=self._reduce_dtype,
            async_op=True,
            reduce_op=self.reduce_op_type,
        )
        HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div))

451 

452 def _queue_compat_all_reduce(self, hsdp_param): 

453 """Queue the compatibility all-reduce path without FSDP sharding.""" 

454 if not self._should_run_all_reduce(hsdp_param): 

455 return 

456 hsdp_param.all_reduce_grad( 

457 grad=self._get_pending_unsharded_grad(hsdp_param), 

458 dtype=self._reduce_dtype, 

459 async_op=True, 

460 reduce_op=self.reduce_op_type, 

461 ) 

462 HSDPState.pre_all_reduce_params.append((hsdp_param, self._orig_dtype, self._need_div)) 

463 

464 def _can_direct_all_reduce_compat_grad(self, hsdp_param) -> bool: 

465 """Whether ``hsdp_param`` should reduce its existing ``sharded_param.grad`` directly.""" 

466 return ( 

467 hsdp_param.param_mode == FullyShardParamMode.DTENSOR_COMPAT 

468 and hsdp_param.enable_fsdp_shard 

469 and not hsdp_param.is_sharded 

470 and hsdp_param.shard_size == 1 

471 and hsdp_param.sharded_param.requires_grad 

472 and self._should_run_all_reduce(hsdp_param) 

473 and self._get_local_sharded_grad(hsdp_param) is not None 

474 ) 

475 

    def _queue_direct_compat_all_reduce(self, hsdp_param):
        """Queue all-reduce for DTENSOR_COMPAT params whose grad stays on ``sharded_param``."""
        grad = self._get_local_sharded_grad(hsdp_param)
        if grad is None:
            return
        # May be a dtype-cast copy of ``grad``; reduce_params() copies the
        # result back into ``grad`` when the two differ.
        reduced_grad = _to_dtype_if_needed(grad, self._reduce_dtype)
        reduce_group_info = getattr(hsdp_param, "unsharded_group_info", None)
        reduce_group = reduce_group_info.group if reduce_group_info is not None else None
        reduce_group_size = reduce_group_info.rank_size if reduce_group_info is not None else 1
        handle = None
        if reduce_group_size > 1:
            if reduce_group is None:
                raise RuntimeError("Expected a valid unsharded all-reduce group when rank_size > 1")
            handle = dist.all_reduce(
                reduced_grad,
                group=reduce_group,
                op=self.reduce_op_type,
                async_op=True,
            )
        # Entry layout consumed by reduce_params():
        # (handle, reduced_grad, target_grad, group_size, need_div).
        MindSporeHSDPStateV2.pre_direct_all_reduce_grads.append(
            (handle, reduced_grad, grad, reduce_group_size, self._need_div)
        )

498 

    def post_backward(self, *_):
        """Reduce (or defer) gradients after backward, then optionally reshard."""
        for hsdp_param in self._iter_managed_params():
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            # Gradient sync disabled (see set_requires_grad_sync): keep grads local.
            self._post_backward_without_reduce()
            return
        if not self.comm_fusion:
            # Drain reductions queued by earlier post_backward calls before
            # issuing new ones.
            self.reduce_params()
            self._allreduce_replicate_params()
            for hsdp_param in self.hsdp_params:
                if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                    # No unsharded view: only the direct DTENSOR_COMPAT path
                    # can reduce this parameter's grad.
                    if self._can_direct_all_reduce_compat_grad(hsdp_param):
                        self._queue_direct_compat_all_reduce(hsdp_param)
                    continue
                if not hsdp_param.sharded_param.requires_grad:
                    continue
                if not self._has_pending_unsharded_grad(hsdp_param):
                    continue
                if hsdp_param.shard_size > 1:
                    # Standard FSDP/HSDP: reduce-scatter (+ all-reduce).
                    self._queue_reduce_scatter_then_all_reduce(hsdp_param)
                elif self._should_run_all_reduce(hsdp_param):
                    # Unsharded but data-parallel: all-reduce only.
                    self._queue_compat_all_reduce(hsdp_param)
                else:
                    # No communication needed: materialize the grad directly.
                    need_synchronize = hsdp_param.apply_reduced_grad(
                        self._get_pending_unsharded_grad(hsdp_param),
                        self._orig_dtype,
                    )
                    self._synchronize_current_stream_if_needed(need_synchronize)
            self._finish_ignored_allreduce()
        else:
            self.post_backward_for_comm_fusion()
        if self.reshard_after_backward:
            self.shard()

532 

    def set_requires_grad_sync(self, requires_grad_sync):
        """set requires grad sync flag to control gradient sync.

        Args:
            requires_grad_sync: When False, post_backward skips all gradient
                communication and only accumulates local grads (see
                ``_post_backward_without_reduce``).
        """
        self.reduce_grads = requires_grad_sync

536 

537 def set_reduce_op_type(self, reduce_op_type: str): 

538 """set reduce op type for gradient reduction.""" 

539 fsdp_support_reduce_op = { 

540 "sum": ops.ReduceOp.SUM, 

541 "avg": ops.ReduceOp.SUM, 

542 } 

543 if reduce_op_type not in fsdp_support_reduce_op: 

544 raise ValueError( 

545 f"Unsupported reduce op type {reduce_op_type}, " 

546 f"supported types are {list(fsdp_support_reduce_op.keys())}") 

547 self._need_div = reduce_op_type == "avg" 

548 reduce_op: str = reduce_op_type.lower().strip() 

549 self.reduce_op_type = fsdp_support_reduce_op[reduce_op]