Coverage for hyper_parallel / core / shard / api.py: 77%

305 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""shard""" 

16import inspect 

17from typing import Union, Callable, Dict, List 

18from functools import wraps 

19from hyper_parallel.core.layout import Layout, DeviceMesh 

20from hyper_parallel.core.dtensor import DTensor 

21from hyper_parallel.core.placement_types import Placement 

22from hyper_parallel.core.shard.sharding_plan import ShardingPlan 

23from hyper_parallel.platform import get_platform 

24 

# Resolve the active backend platform once at import time and alias its core
# types so the rest of this module stays backend-agnostic (the platform is
# presumably MindSpore-based — `parallelize_value_and_grad` imports
# `mindspore.ops` — but any object exposing these attributes works).
platform = get_platform()
Parameter = platform.Parameter
Tensor = platform.Tensor
Module = platform.Module

29 

30 

31def _has_kwargs(func): 

32 """_has_kwargs""" 

33 sig = inspect.signature(func) 

34 return any( 

35 param.default != inspect.Parameter.empty 

36 for param in sig.parameters.values() 

37 ) 

38 

39 

40def _get_param_name(func): 

41 """_get_param_name""" 

42 sig = inspect.signature(func) 

43 return list(sig.parameters.keys()) 

44 

45 

def _convert_sharding_plan(sharding_plan: Dict, device_mesh: DeviceMesh) -> Dict:
    """
    Convert sharding_plan values to Layout objects.

    This function recursively traverses the sharding_plan and converts
    placement tuples (e.g., (Shard(0), Replicate())) to Layout objects.
    Keys ending in "input"/"output" (both at top level and inside the
    "forward" sub-plan) get special treatment: a single placement tuple is
    wrapped in a one-element list so downstream code can always index
    layouts positionally.

    Args:
        sharding_plan: The original sharding plan with tuple specifications
        device_mesh: The DeviceMesh to use for conversion

    Returns:
        Dict: Converted sharding plan with Layout objects
    """

    def _is_placement_tuple(value):
        """Check if value is a placement specification tuple.

        A placement tuple contains Placement instances (Shard, Replicate) or
        alias strings ("dp", "None"). It should NOT be a tuple of placement tuples.

        Examples of placement tuples:
            (Shard(0), Replicate(), Shard(1)) -> True
            ("dp", "tp", "None") -> True
            (("dp", "tp"), "None") -> True (multi-axis sharding)

        Examples of NON-placement tuples:
            ((Shard(0),), (Shard(1),)) -> False (tuple of placement tuples)
        """
        if not isinstance(value, tuple) or len(value) == 0:
            return False

        for item in value:
            # Placement instance is valid
            if isinstance(item, Placement):
                continue
            # String (alias name) is valid
            if isinstance(item, str):
                continue
            # Nested tuple needs special handling
            if isinstance(item, tuple):
                # Nested tuple of strings is valid (multi-axis sharding)
                if len(item) > 0 and all(isinstance(x, str) for x in item):
                    continue
                # Nested tuple containing Placement means this is a tuple of placement tuples
                if len(item) > 0 and any(isinstance(x, Placement) for x in item):
                    return False
                # Empty tuple or other cases - not valid
                return False
            # Any other type is not valid in a placement tuple
            return False

        return True

    def _to_layout(value):
        """Convert a single sharding specification to Layout.

        A fresh Layout is built from the device mesh per specification;
        calling the layout with the spec resolves aliases/placements.
        """
        layout = Layout.from_device_mesh(device_mesh)
        result = layout(value)
        return result

    def _convert_value(value, wrap_single_as_list=False):
        """Recursively convert value based on its structure.

        ``wrap_single_as_list`` is only True for forward input/output values,
        where callers expect a positional list of layouts.
        """
        if value is None:
            return None

        # Case 1: It's a placement tuple - convert to Layout
        if _is_placement_tuple(value):
            layout = _to_layout(value)
            # Wrap single layout in list if required (for input/output in forward)
            return [layout] if wrap_single_as_list else layout

        # Case 2: It's a dict - recursively process each value
        if isinstance(value, dict):
            converted_dict = {}
            for k, v in value.items():
                converted_dict[k] = _convert_value(v, wrap_single_as_list=False)
            return converted_dict

        # Case 3: It's a list - recursively process each element
        if isinstance(value, list):
            return [_convert_value(v, wrap_single_as_list=False) for v in value]

        # Case 4: It's a tuple but not a placement tuple - treat as list
        # (note: deliberately normalizes tuples to lists in the output)
        if isinstance(value, tuple):
            return [_convert_value(v, wrap_single_as_list=False) for v in value]

        # Case 5: Other types (e.g., primitives) - return as is
        return value

    def _convert_forward_plan(forward_plan):
        """Convert forward plan with special handling for input/output."""
        if forward_plan is None:
            return None

        converted = {}
        for key, value in forward_plan.items():
            if key.endswith("input") or key.endswith("output"):
                # input/output need special handling:
                # - dict format: convert each value, keep as dict
                # - list/tuple format: convert each element, keep as list
                # - single placement tuple: convert and wrap in list
                if value is None:
                    converted[key] = None
                elif isinstance(value, dict):
                    # Dict format for kwargs: {"x": placements, "activation": placements}
                    converted[key] = {k: _convert_value(v) for k, v in value.items()}
                elif isinstance(value, (list, tuple)):
                    # Check if it's a single placement tuple or a list/tuple of placement tuples
                    if _is_placement_tuple(value):
                        # Single placement tuple - wrap in list
                        converted[key] = [_to_layout(value)]
                    else:
                        # List/tuple of placements for multiple positional args
                        converted[key] = [_convert_value(v) for v in value]
                else:
                    converted[key] = _convert_value(value, wrap_single_as_list=True)
            else:
                # Other keys in forward plan
                converted[key] = _convert_value(value)
        return converted

    # Main conversion logic
    converted_plan = {}

    for key, value in sharding_plan.items():
        if key == "forward":
            converted_plan[key] = _convert_forward_plan(value)
        elif key.endswith("input") or key.endswith("output"):
            # Top-level input/output (for callable sharding) — mirrors the
            # branch structure of _convert_forward_plan above.
            if value is None:
                converted_plan[key] = None
            elif isinstance(value, dict):
                converted_plan[key] = {k: _convert_value(v) for k, v in value.items()}
            elif isinstance(value, (list, tuple)):
                if _is_placement_tuple(value):
                    converted_plan[key] = [_to_layout(value)]
                else:
                    converted_plan[key] = [_convert_value(v) for v in value]
            else:
                converted_plan[key] = _convert_value(value, wrap_single_as_list=True)
        else:
            # parameter and other keys - use standard recursive conversion
            converted_plan[key] = _convert_value(value)

    return converted_plan

191 

192 

def _parallel_in(func, args, kwargs, layouts):
    """Redistribute DTensor inputs of ``func`` to the requested layouts.

    Args:
        func: The callable receiving the inputs; only used to recover
            parameter names when ``layouts`` is a dict.
        args: Positional arguments, possibly containing DTensors.
        kwargs: Keyword arguments, possibly containing DTensors.
        layouts: Target layouts — a list/tuple indexed by argument position,
            or a dict keyed by parameter name.

    Returns:
        tuple: ``(processed_args, processed_kwargs)`` where every DTensor
        with a configured layout has been redistributed; all other values
        pass through unchanged.

    Raises:
        ValueError: If ``layouts`` is not a list, tuple or dict, or if a
            keyword argument is a DTensor while ``layouts`` is a list/tuple
            (positional layouts cannot be matched to keywords).
    """
    if not isinstance(layouts, (list, dict, tuple)):
        raise ValueError(f"The in_layout must be a list, tuple or dict, but got {type(layouts)}.")

    params_name = _get_param_name(func)
    processed_args = list(args)
    processed_kwargs = dict(kwargs)
    is_list = isinstance(layouts, (list, tuple))

    def _get_layout(index):
        """Layout for the positional argument at ``index``."""
        if is_list:
            return layouts[index]
        return layouts[params_name[index]]

    for i, arg in enumerate(args):
        if not isinstance(arg, DTensor):
            continue
        to_layout = _get_layout(i)
        processed_args[i] = arg.redistribute(to_layout.mesh, to_layout.placements)

    for k, v in kwargs.items():
        if not isinstance(v, DTensor):
            # Already present in the processed_kwargs copy; nothing to do.
            continue
        # Bug fix: the previous code called layouts.get(k) unconditionally,
        # raising AttributeError whenever layouts was a list/tuple and a
        # keyword argument held a DTensor. Fail with a clear message instead.
        if is_list:
            raise ValueError(f"The in_layout must be a dict to redistribute keyword argument '{k}', "
                             f"but got {type(layouts)}.")
        to_layout = layouts.get(k)
        if to_layout is None:
            continue
        processed_kwargs[k] = v.redistribute(to_layout.mesh, to_layout.placements)

    return tuple(processed_args), processed_kwargs

224 

225 

def _parallel_out(outputs, layouts):
    """Redistribute DTensor outputs to the requested layouts.

    Args:
        outputs: A single output value, or a tuple/list of outputs.
        layouts: A list/tuple of target layouts, one per output.

    Returns:
        The outputs with every DTensor redistributed to its layout;
        non-DTensor entries (including None) pass through unchanged. A
        tuple/list input is always returned as a tuple.

    Raises:
        ValueError: If ``layouts`` is not a list/tuple, or its length does
            not match the number of outputs.
    """
    if not isinstance(layouts, (list, tuple)):
        raise ValueError(f"The out_layout must be a list or tuple, but got {type(layouts)}.")
    if isinstance(outputs, (tuple, list)):
        if len(outputs) != len(layouts):
            raise ValueError(f"The size of outputs and out_layout must be equal, but got {len(outputs)} and "
                             f"{len(layouts)}")
        new_outputs = []
        # The original also tested `arg is None` after the isinstance check;
        # that branch was unreachable (None is never a DTensor) and is dropped.
        for out, to_layout in zip(outputs, layouts):
            if not isinstance(out, DTensor):
                new_outputs.append(out)
                continue
            new_outputs.append(out.redistribute(to_layout.mesh, to_layout.placements))
        return tuple(new_outputs)
    if len(layouts) != 1:
        raise ValueError(f"The size of outputs and out_layout must be equal, but got 1 and "
                         f"{len(layouts)}")

    return outputs.redistribute(layouts[0].mesh, layouts[0].placements) if isinstance(outputs, DTensor) else outputs

247 

248 

249def _forward_pre_hook(cell, args): 

250 """_forward_pre_hook""" 

251 if cell.in_layout is None: 

252 return args 

253 processed_args, _ = _parallel_in(platform.get_cell_construct(cell), args, {}, cell.in_layout) 

254 return processed_args 

255 

256 

257def _forward_pre_with_kwargs_hook(cell, args, kwargs): 

258 """_forward_pre_with_kwargs_hook""" 

259 if cell.in_layout is None: 

260 return args, kwargs 

261 return _parallel_in(platform.get_cell_construct(cell), args, kwargs, cell.in_layout) 

262 

263 

264def _forward_hook(cell, inputs, outputs): # pylint: disable=unused-argument 

265 """_forward_hook""" 

266 if cell.out_layout is None: 

267 return outputs 

268 return _parallel_out(outputs, cell.out_layout) 

269 

270 

def _forward_with_kwargs_hook(cell, inputs, kwargs, outputs):  # pylint: disable=unused-argument
    """Kwargs-aware forward hook; kwargs are ignored since only outputs are redistributed."""
    return _forward_hook(cell, inputs, outputs)

274 

275 

def _register_hook(model: Module, sharding_plan: Dict):
    """Register forward pre/post hooks on submodules per the forward sharding plan.

    Args:
        model: Root module whose submodules receive redistribution hooks.
        sharding_plan: Mapping from "<submodule>.input" / "<submodule>.output"
            (bare "input"/"output" addresses the root module) to layout
            specifications; ``None`` values are skipped.

    Raises:
        ValueError: If a key does not end with "input"/"output", or if the
            submodule path does not match any submodule of ``model``.
    """

    def _register_cell_hook(cell, has_inputs_layout, has_outputs_layout):
        """Attach hooks, choosing kwargs-aware variants if the construct takes kwargs."""
        has_kwargs = _has_kwargs(platform.get_cell_construct(cell))
        pre_hook = _forward_pre_with_kwargs_hook if has_kwargs else _forward_pre_hook
        hook = _forward_with_kwargs_hook if has_kwargs else _forward_hook
        if has_inputs_layout:
            cell.register_forward_pre_hook(pre_hook, with_kwargs=has_kwargs)

        if has_outputs_layout:
            cell.register_forward_hook(hook, with_kwargs=has_kwargs)

    def _set_layouts(cell, layouts, set_inputs_layout, set_outputs_layout):
        """Stash the layouts on the cell where the hooks read them back."""
        if set_inputs_layout:
            cell.in_layout = layouts

        if set_outputs_layout:
            cell.out_layout = layouts

    cell_dict = dict(platform.get_cells_and_names(model))

    valid_suffix = ("input", "output")
    for key, value in sharding_plan.items():
        if value is None:
            continue
        # rpartition yields ("", "", key) when there is no dot, which matches
        # the root-module name "" from get_cells_and_names.
        prefix, _, suffix = key.rpartition('.')
        if suffix not in valid_suffix:
            raise ValueError(f"In python shard_module, sharding_plan's forward key must end with input or output, "
                             f"but got type {suffix}")

        set_inputs_layout = suffix == "input"
        set_outputs_layout = not set_inputs_layout
        register_cell = cell_dict.get(prefix)
        if register_cell is None:
            # Bug fix: previously a bare KeyError with no context escaped here.
            raise ValueError(f"Sharding plan key '{key}' refers to submodule '{prefix}', "
                             f"but no submodule with that name was found in the model.")

        _set_layouts(register_cell, value, set_inputs_layout, set_outputs_layout)
        _register_cell_hook(register_cell, set_inputs_layout, set_outputs_layout)

320 

321 

def _register_local_tensor_hook(cell: Module, return_local_tensor_list: List[str]):
    """Register forward hooks that replace DTensor outputs with their local shards.

    Args:
        cell: Root module to search for submodules.
        return_local_tensor_list: Names of submodules whose outputs should be
            converted from DTensor to local tensors.

    Raises:
        ValueError: If a listed name matches no submodule of ``cell``.
    """

    def hook_func(cell, inputs, outputs):  # pylint: disable=unused-argument
        """Recursively convert every DTensor in the outputs to a local tensor."""
        def _recursive_to_local(out):
            if isinstance(out, (tuple, list)):
                converted = [_recursive_to_local(item) for item in out]
                # Preserve the container kind (tuple stays tuple, list stays list).
                return tuple(converted) if isinstance(out, tuple) else converted
            if isinstance(out, DTensor):
                return out.to_local()
            return out

        return _recursive_to_local(outputs)

    cell_dict = dict(platform.get_cells_and_names(cell))

    for cell_name in return_local_tensor_list:
        register_cell = cell_dict.get(cell_name)
        if register_cell is None:
            # Bug fix: previously an uninformative KeyError escaped here.
            raise ValueError(f"return_local_tensor entry '{cell_name}' does not match any submodule.")
        register_cell.register_forward_hook(hook_func)

345 

346 

347def _shard_callable(func: Callable, sharding_plan: Dict): 

348 """_shard_callable""" 

349 forward_sharding_plan = sharding_plan.get("forward") 

350 if forward_sharding_plan is None: 

351 return func 

352 

353 @wraps(func) 

354 def _shard_wrapper(*args, **kwargs): 

355 """_shard_wrapper""" 

356 input_layout = forward_sharding_plan.get("input") 

357 output_layout = forward_sharding_plan.get("output") 

358 if input_layout is not None: 

359 args, kwargs = _parallel_in(func, args, kwargs, input_layout) 

360 outputs = func(*args, **kwargs) 

361 if output_layout is not None: 

362 outputs = _parallel_out(outputs, output_layout) 

363 return outputs 

364 

365 return _shard_wrapper 

366 

367 

def shard_module(model: Union[Module, Callable], device_mesh: DeviceMesh, sharding_plan: ShardingPlan):
    """
    Define the input, output and parameter layouts of this cell or Callable.

    Note:
        - It is valid only in pynative mode.

    .. warning::
        The method is currently not supported in Graph mode.

    Args:
        model (Module or Callable): The model to be sharded.
        device_mesh (DeviceMesh): The device mesh for sharding.
        sharding_plan (ShardingPlan): Define the layout for the specified parameters, inputs or outputs.
            The sharding specification can be:
            - tuple of strings for alias format, e.g., ("dp", "None")
            - tuple of Placements, e.g., (Shard(0), Replicate())

    Returns:
        Module or Callable: The sharded model. Returns ``None`` when the
        world size is 1, in which case no sharding is applied at all.

    Raises:
        TypeError: If `sharding_plan` is not a ShardingPlan instance, or if
            its input_plan/output_plan are not dicts.
        ValueError: If a parameter layout is invalid or names a parameter
            that does not exist in `model`.

    Examples:
        >>> # Usage with device_mesh and alias format
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"mlp.weight": ("None", "tp")},
        ...     input_plan={"input": ("dp", "None")},
        ...     output_plan={"output": ("dp", "tp")}
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)

        >>> # Usage with device_mesh and Placement format
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"mlp.weight": (Replicate(), Shard(1))},
        ...     input_plan={"input": (Shard(0), Replicate())},
        ...     output_plan={"output": (Shard(0), Shard(1))}
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)
    """
    # Single-device run: nothing to shard.
    # NOTE(review): returning None (not `model`) here forces callers to keep
    # their own reference in the single-device case — confirm this is intended.
    if platform.get_world_size() == 1:
        return None

    if not isinstance(sharding_plan, ShardingPlan):
        raise TypeError(f"The 'sharding_plan' must be an instance of ShardingPlan, "
                        f"but got {type(sharding_plan)}. Direct dict input is not supported.")

    # Normalize the ShardingPlan fields into the internal dict shape
    # {"parameter": ..., "forward": {...}} consumed by the helpers below.
    normalized_plan = {}
    return_local_tensor_list = None

    if sharding_plan.plan:
        normalized_plan["parameter"] = sharding_plan.plan

    forward_part = {}

    if sharding_plan.input_plan:
        if not isinstance(sharding_plan.input_plan, dict):
            raise TypeError(f"input_plan must be a dict, but got {type(sharding_plan.input_plan)}")
        forward_part.update(sharding_plan.input_plan)

    if sharding_plan.output_plan:
        if not isinstance(sharding_plan.output_plan, dict):
            raise TypeError(f"output_plan must be a dict, but got {type(sharding_plan.output_plan)}")
        forward_part.update(sharding_plan.output_plan)

    if forward_part:
        normalized_plan["forward"] = forward_part

    if sharding_plan.return_local_tensor:
        return_local_tensor_list = sharding_plan.return_local_tensor

    # Convert sharding_plan to Layout objects
    converted_plan = _convert_sharding_plan(normalized_plan, device_mesh)

    # Plain callables only get input/output redistribution via a wrapper.
    if not isinstance(model, Module):
        return _shard_callable(model, converted_plan)

    param_sharding_plan = converted_plan.get("parameter")
    forward_sharding_plan = converted_plan.get("forward")

    # Attach layouts to the named parameters in place.
    if param_sharding_plan is not None:
        for param_name, layout in param_sharding_plan.items():
            if not isinstance(layout, Layout):
                raise ValueError(f"In python shard_module, the type of setting in parameter_plan must be Layout, "
                                 f"but got type {type(layout)}")
            result = platform.search_parameter_by_name(model, param_name)
            if not result:
                raise ValueError(f"{param_name} is configured with a layout, but no instance was found.")
            _, _, param = result
            layout.placement_to_tensor_map(param.dim())
            param = platform.set_layout_into_parameter(param, layout)
            platform.update_parameter_by_name(model, result, param)

    # Hooks perform the actual input/output redistribution at forward time.
    if forward_sharding_plan is not None:
        _register_hook(model, forward_sharding_plan)

    if return_local_tensor_list is not None:
        _register_local_tensor_hook(model, return_local_tensor_list)

    return model

468 

469 

def parallelize_value_and_grad(fn, weights, sens=None):
    """
    A wrapper function to generate the function to calculate forward output and gradient for the parallel scenario.

    Args:
        fn (Union[Cell, Function]): Function to do grad operation.
        weights (Union[ParameterTuple, Parameter, list[Parameter]]):
            The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
        sens (Union[list(float), tuple(float)], optional): The sensitivity for grad operation. Default: ``None``.
            - If the fn only have one output, the sens must be None, and it will be attached automatically.
            - If the fn have multiple outputs:
              1) If the sens is None, only handle the first sensitivity, and set the remaining sensitivity to 0.
              2) If the sens is not None, the lengths of sens and outputs of fn must be equal.

    Returns:
        Function, the derivative function used to compute the gradient of a given function.
        For example, as for `out1, out2 = fn(*args)` , gradient function will return outputs like
        `((out1, out2), gradient)` .

    Raises:
        TypeError: If type of Args does not belong to required ones.

    Supported Platforms:
        ``Ascend``
    """
    from mindspore import ops  # pylint: disable=import-outside-toplevel
    # get_by_list=True: differentiate w.r.t. the weight list; sens_param=True:
    # the explicit sensitivities built below are fed in as `sens=`.
    grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)

    # use CellWrapper to solve two problems:
    # 1. avoid running the forward fn or cell twice
    # 2. if the input of parallelize_value_and_grad is cell and it is directly used as the input for grad,
    #    the operations before and after its __call__ function will not enter the auto-diff process.
    class CellWrapper(Module):
        # Thin pass-through wrapper; auto_prefix=False keeps the wrapped
        # network's parameter names unchanged.
        def __init__(self, net):
            super().__init__(auto_prefix=False)
            self.network = net

        # Both entry points delegate, so the wrapper works regardless of
        # which method name the backend's Module dispatches to.
        def construct(self, *args, **kwargs):
            return self.network(*args, **kwargs)

        def forward(self, *args, **kwargs):
            return self.network(*args, **kwargs)

    fn = CellWrapper(fn)
    fn.set_grad()  # avoid running the forward fn or cell twice

    def wrapper(*args, **kwargs):
        # Run forward once, build sensitivities matching the output structure,
        # then take gradients with those sensitivities.
        loss_value = fn(*args, **kwargs)
        p_sens = None

        if isinstance(loss_value, (list, tuple)):
            # There are multiple outputs, requiring multiple sens
            p_sens = []

            if sens is None:
                # if sens is None, only handle the first sens, and set the remaining sens to 0
                loss_0 = loss_value[0]
                if isinstance(loss_0, DTensor):
                    # Scale by 1/repeat_num so replicated shards do not
                    # over-accumulate the gradient across devices.
                    repeat_num = loss_0.layout.repeat_num()
                    sens_0 = ops.fill(ops.DType()(loss_0), loss_0.local_shape, 1.0 / repeat_num)
                else:
                    sens_0 = ops.fill(ops.DType()(loss_0), loss_0.shape, 1.0)
                p_sens.append(sens_0)

                for i in range(1, len(loss_value)):
                    loss_i = loss_value[i]
                    if isinstance(loss_i, DTensor):
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 0.0)
                    else:
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 0.0)
                    p_sens.append(sens_i)

            else:
                # sens is not None
                if not isinstance(sens, list) and not isinstance(sens, tuple):
                    raise TypeError("if the loss is list or tuple, the sens must be None or list or tuple")

                all_float = all(isinstance(item, float) for item in sens)
                if not all_float:
                    raise TypeError("if sens is not None, it should be list of float or tuple of float")

                if len(sens) != len(loss_value):
                    raise TypeError(f"the len of loss is {len(loss_value)}, but the len of sens is {len(sens)}")

                # NOTE(review): the validated `sens` float values are never
                # used below — every output gets a fill of 1.0 (or
                # 1.0/repeat_num). It looks like the fill value should be the
                # corresponding `sens` entry; confirm intended behavior.
                for _, loss_i in enumerate(loss_value):
                    if isinstance(loss_i, DTensor):
                        repeat_num = loss_i.layout.repeat_num()
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 1.0 / repeat_num)
                    else:
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 1.0)
                    p_sens.append(sens_i)

        else:
            # loss is tensor
            if sens is not None:
                raise TypeError(f"the fn only have one output, the sens must be None, but it is {sens}")
            if isinstance(loss_value, DTensor):
                repeat_num = loss_value.layout.repeat_num()
                p_sens = ops.fill(ops.DType()(loss_value), loss_value.local_shape, 1.0 / repeat_num)

            else:
                p_sens = ops.fill(ops.DType()(loss_value), loss_value.shape, 1.0)

        grads = grad_fn(fn, weights)(*args, **kwargs, sens=p_sens)
        return loss_value, grads

    return wrapper