Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / api.py: 8%

310 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""shard""" 

16import inspect 

17from typing import Union, Callable, Dict, List 

18from functools import wraps 

19from hyper_parallel.core.dtensor.layout import Layout, DeviceMesh 

20from hyper_parallel.core.dtensor.dtensor import DTensor, _is_alias_placements 

21from hyper_parallel.core.dtensor.placement_types import Placement 

22from hyper_parallel.core.shard.sharding_plan import ShardingPlan 

23from hyper_parallel.platform import get_platform 

24 

# Resolve the framework-specific primitives once at import time so the rest
# of this module can stay backend-agnostic (e.g. MindSpore vs. other backends).
platform = get_platform()
Parameter = platform.Parameter
Tensor = platform.Tensor
Module = platform.Module

29 

30 

31def _has_kwargs(func): 

32 """_has_kwargs""" 

33 sig = inspect.signature(func) 

34 return any( 

35 param.default != inspect.Parameter.empty 

36 for param in sig.parameters.values() 

37 ) 

38 

39 

40def _get_param_name(func): 

41 """_get_param_name""" 

42 sig = inspect.signature(func) 

43 return list(sig.parameters.keys()) 

44 

45 

def _convert_sharding_plan(sharding_plan: Dict, device_mesh: DeviceMesh) -> Dict:
    """
    Convert sharding_plan values to Layout objects.

    This function recursively traverses the sharding_plan and converts
    placement tuples (e.g., (Shard(0), Replicate())) to Layout objects.

    Args:
        sharding_plan: The original sharding plan with tuple specifications
        device_mesh: The DeviceMesh to use for conversion

    Returns:
        Dict: Converted sharding plan with Layout objects
    """

    def _is_placement_tuple(value):
        """Check if value is a placement specification tuple.

        A placement tuple contains Placement instances (Shard, Replicate) or
        alias strings ("dp", "None"). It should NOT be a tuple of placement tuples.

        Examples of placement tuples:
            (Shard(0), Replicate(), Shard(1)) -> True
            ("dp", "tp", "None") -> True
            (("dp", "tp"), "None") -> True (multi-axis sharding)

        Examples of NON-placement tuples:
            ((Shard(0),), (Shard(1),)) -> False (tuple of placement tuples)
        """
        if not isinstance(value, tuple) or len(value) == 0:
            return False

        for item in value:
            # Placement instance is valid
            if isinstance(item, Placement):
                continue
            # String (alias name) is valid
            if isinstance(item, str):
                continue
            # Nested tuple needs special handling
            if isinstance(item, tuple):
                # Nested tuple of strings is valid (multi-axis sharding)
                if len(item) > 0 and all(isinstance(x, str) for x in item):
                    continue
                # Nested tuple containing Placement means this is a tuple of placement tuples
                if len(item) > 0 and any(isinstance(x, Placement) for x in item):
                    return False
                # Empty tuple or other cases - not valid
                return False
            # Any other type is not valid in a placement tuple
            return False

        return True

    def _to_layout(value):
        """Convert a single sharding specification to Layout."""
        layout = Layout.from_device_mesh(device_mesh)
        # Alias specs such as ("dp", "None") are splatted so each alias maps
        # to one mesh axis; Placement tuples are passed as a single argument.
        if _is_alias_placements(value):
            result = layout(*value)
        else:
            result = layout(value)
        return result

    def _convert_value(value, wrap_single_as_list=False):
        """Recursively convert value based on its structure.

        Args:
            value: The plan entry to convert (placement tuple, dict, list,
                tuple, None, or a primitive that is passed through).
            wrap_single_as_list: When True, a single placement tuple is
                wrapped as ``[layout]`` (input/output entries are list-shaped
                downstream even for a single tensor).
        """
        if value is None:
            return None

        # Case 1: It's a placement tuple - convert to Layout
        if _is_placement_tuple(value):
            layout = _to_layout(value)
            # Wrap single layout in list if required (for input/output in forward)
            return [layout] if wrap_single_as_list else layout

        # Case 2: It's a dict - recursively process each value
        if isinstance(value, dict):
            converted_dict = {}
            for k, v in value.items():
                converted_dict[k] = _convert_value(v, wrap_single_as_list=False)
            return converted_dict

        # Case 3: It's a list - recursively process each element
        if isinstance(value, list):
            return [_convert_value(v, wrap_single_as_list=False) for v in value]

        # Case 4: It's a tuple but not a placement tuple - treat as list
        if isinstance(value, tuple):
            return [_convert_value(v, wrap_single_as_list=False) for v in value]

        # Case 5: Other types (e.g., primitives) - return as is
        return value

    def _convert_forward_plan(forward_plan):
        """Convert forward plan with special handling for input/output."""
        if forward_plan is None:
            return None

        converted = {}
        for key, value in forward_plan.items():
            if key.endswith("input") or key.endswith("output"):
                # input/output need special handling:
                # - dict format: convert each value, keep as dict
                # - list/tuple format: convert each element, keep as list
                # - single placement tuple: convert and wrap in list
                if value is None:
                    converted[key] = None
                elif isinstance(value, dict):
                    # Dict format for kwargs: {"x": placements, "activation": placements}
                    converted[key] = {k: _convert_value(v) for k, v in value.items()}
                elif isinstance(value, (list, tuple)):
                    # Check if it's a single placement tuple or a list/tuple of placement tuples
                    if _is_placement_tuple(value):
                        # Single placement tuple - wrap in list
                        converted[key] = [_to_layout(value)]
                    else:
                        # List/tuple of placements for multiple positional args
                        converted[key] = [_convert_value(v) for v in value]
                else:
                    converted[key] = _convert_value(value, wrap_single_as_list=True)
            else:
                # Other keys in forward plan
                converted[key] = _convert_value(value)
        return converted

    # Main conversion logic
    converted_plan = {}

    for key, value in sharding_plan.items():
        if key == "forward":
            converted_plan[key] = _convert_forward_plan(value)
        elif key.endswith("input") or key.endswith("output"):
            # Top-level input/output (for callable sharding)
            # Mirrors the forward-plan handling above so callables and
            # modules accept the same plan shapes.
            if value is None:
                converted_plan[key] = None
            elif isinstance(value, dict):
                converted_plan[key] = {k: _convert_value(v) for k, v in value.items()}
            elif isinstance(value, (list, tuple)):
                if _is_placement_tuple(value):
                    converted_plan[key] = [_to_layout(value)]
                else:
                    converted_plan[key] = [_convert_value(v) for v in value]
            else:
                converted_plan[key] = _convert_value(value, wrap_single_as_list=True)
        else:
            # parameter and other keys - use standard recursive conversion
            converted_plan[key] = _convert_value(value)

    return converted_plan

194 

195 

196def _parallel_in(func, args, kwargs, layouts): 

197 """_parallel_in""" 

198 if not isinstance(layouts, (list, dict, tuple)): 

199 raise ValueError(f"The in_layout must be a list, tuple or dict, but got {type(layouts)}.") 

200 

201 params_name = _get_param_name(func) 

202 processed_args = list(args) 

203 processed_kwargs = dict(kwargs) 

204 

205 def _get_layout(index, is_list): 

206 """_get_layout""" 

207 if is_list: 

208 return layouts[index] 

209 param_name = params_name[index] 

210 return layouts[param_name] 

211 

212 is_list = isinstance(layouts, (list, tuple)) 

213 for i, arg in enumerate(args): 

214 if not isinstance(arg, DTensor): 

215 continue 

216 

217 to_layout = _get_layout(i, is_list) 

218 processed_args[i] = arg.redistribute(to_layout.mesh, to_layout.alias_placements) 

219 for k, v in kwargs.items(): 

220 if not isinstance(v, DTensor) or layouts.get(k) is None: 

221 processed_kwargs[k] = v 

222 continue 

223 to_layout = layouts[k] 

224 processed_kwargs[k] = v.redistribute(to_layout.mesh, to_layout.alias_placements) 

225 

226 return tuple(processed_args), processed_kwargs 

227 

228 

229def _parallel_out(outputs, layouts): 

230 """_parallel_out""" 

231 if not isinstance(layouts, (list, tuple)): 

232 raise ValueError(f"The out_layout must be a list or tuple, but got {type(layouts)}.") 

233 if isinstance(outputs, (tuple, list)): 

234 if len(outputs) != len(layouts): 

235 raise ValueError(f"The size of outputs and out_layout must be equal, but got {len(outputs)} and " 

236 f"{len(layouts)}") 

237 new_outputs = [] 

238 for i, arg in enumerate(outputs): 

239 if not isinstance(arg, DTensor) or arg is None: 

240 new_outputs.append(arg) 

241 continue 

242 to_layout = layouts[i] 

243 new_outputs.append(arg.redistribute(to_layout.mesh, to_layout.alias_placements)) 

244 return tuple(new_outputs) 

245 if len(layouts) != 1: 

246 raise ValueError(f"The size of outputs and out_layout must be equal, but got 1 and " 

247 f"{len(layouts)}") 

248 

249 if isinstance(outputs, DTensor): 

250 return outputs.redistribute( 

251 layouts[0].mesh, layouts[0].alias_placements) 

252 return outputs 

253 

254 

255def _forward_pre_hook(cell, args): 

256 """_forward_pre_hook""" 

257 if cell.in_layout is None: 

258 return args 

259 processed_args, _ = _parallel_in(platform.get_cell_construct(cell), args, {}, cell.in_layout) 

260 return processed_args 

261 

262 

263def _forward_pre_with_kwargs_hook(cell, args, kwargs): 

264 """_forward_pre_with_kwargs_hook""" 

265 if cell.in_layout is None: 

266 return args, kwargs 

267 return _parallel_in(platform.get_cell_construct(cell), args, kwargs, cell.in_layout) 

268 

269 

270def _forward_hook(cell, inputs, outputs): # pylint: disable=unused-argument 

271 """_forward_hook""" 

272 if cell.out_layout is None: 

273 return outputs 

274 return _parallel_out(outputs, cell.out_layout) 

275 

276 

def _forward_with_kwargs_hook(cell, inputs, kwargs, outputs):  # pylint: disable=unused-argument
    """Kwargs-aware forward hook: kwargs never affect outputs, so delegate."""
    return _forward_hook(cell, inputs, outputs)

280 

281 

def _register_hook(model: Module, sharding_plan: Dict):
    """Attach input/output redistribution hooks to sub-cells per ``sharding_plan``.

    Keys in ``sharding_plan`` have the form ``"<cell.path>.input"`` /
    ``"<cell.path>.output"`` (or bare ``"input"`` / ``"output"`` for the
    top-level cell). The converted layouts are stored on the target cell as
    ``in_layout`` / ``out_layout`` and applied by the registered hooks.

    Raises:
        ValueError: If a key does not end with ``input``/``output``, or names
            a sub-cell that does not exist in ``model``.
    """

    def _register_cell_hook(cell, has_inputs_layout, has_outputs_layout):
        """Register the kwargs-aware or positional-only hook variants on ``cell``."""
        has_kwargs = _has_kwargs(platform.get_cell_construct(cell))
        pre_hook = _forward_pre_with_kwargs_hook if has_kwargs else _forward_pre_hook
        hook = _forward_with_kwargs_hook if has_kwargs else _forward_hook
        if has_inputs_layout:
            cell.register_forward_pre_hook(pre_hook, with_kwargs=has_kwargs)

        if has_outputs_layout:
            cell.register_forward_hook(hook, with_kwargs=has_kwargs)

    def _set_layouts(cell, layouts, set_inputs_layout, set_outputs_layout):
        """Stash the layouts on ``cell`` where the forward hooks can find them."""
        if set_inputs_layout:
            cell.in_layout = layouts

        if set_outputs_layout:
            cell.out_layout = layouts

    cell_dict = {}
    for name, cell in platform.get_cells_and_names(model):
        cell_dict[name] = cell

    valid_suffix = ["input", "output"]
    for key, value in sharding_plan.items():
        if value is None:
            continue
        # "mlp.block.input" -> prefix "mlp.block", suffix "input";
        # a bare "input"/"output" targets the top-level cell (prefix "").
        has_dot = '.' in key
        split_key = key.rsplit('.', 1)
        prefix = split_key[0] if has_dot else ""
        suffix = split_key[1] if has_dot else key
        if suffix not in valid_suffix:
            raise ValueError(f"In python shard_module, sharding_plan's forward key must end with input or output, "
                             f"but got type {suffix}")

        set_inputs_layout = suffix == "input"
        set_outputs_layout = not set_inputs_layout
        # Fail with a descriptive message instead of a bare KeyError when the
        # plan names a sub-cell that does not exist in the model.
        if prefix not in cell_dict:
            raise ValueError(f"'{key}' is configured with a layout, but no sub-cell named '{prefix}' "
                             f"was found.")
        register_cell = cell_dict[prefix]

        _set_layouts(register_cell, value, set_inputs_layout, set_outputs_layout)
        _register_cell_hook(register_cell, set_inputs_layout, set_outputs_layout)

326 

327 

def _register_local_tensor_hook(cell: Module, return_local_tensor_list: List[str]):
    """Register forward hooks that convert DTensor outputs to local tensors.

    For every sub-cell named in ``return_local_tensor_list``, a forward hook
    is attached that recursively calls ``to_local()`` on every DTensor in the
    cell's outputs, descending into nested tuples/lists.

    Raises:
        ValueError: If a configured name does not match any sub-cell.
    """

    def hook_func(cell, inputs, outputs):  # pylint: disable=unused-argument
        def _recursive_to_local(out):
            # Rebuild containers element-wise; tuple-ness is preserved while
            # lists stay lists.
            if isinstance(out, (tuple, list)):
                new_out = [_recursive_to_local(item) for item in out]
                return tuple(new_out) if isinstance(out, tuple) else new_out
            if isinstance(out, DTensor):
                return out.to_local()
            return out

        return _recursive_to_local(outputs)

    cell_dict = {}
    for name, sub_cell in platform.get_cells_and_names(cell):
        cell_dict[name] = sub_cell

    for cell_name in return_local_tensor_list:
        # Fail with a descriptive message instead of a bare KeyError when the
        # configured name does not exist in the model.
        if cell_name not in cell_dict:
            raise ValueError(f"{cell_name} is configured in return_local_tensor, but no sub-cell "
                             f"with that name was found.")
        register_cell = cell_dict[cell_name]
        register_cell.register_forward_hook(hook_func)

351 

352 

353def _shard_callable(func: Callable, sharding_plan: Dict): 

354 """_shard_callable""" 

355 forward_sharding_plan = sharding_plan.get("forward") 

356 if forward_sharding_plan is None: 

357 return func 

358 

359 @wraps(func) 

360 def _shard_wrapper(*args, **kwargs): 

361 """_shard_wrapper""" 

362 input_layout = forward_sharding_plan.get("input") 

363 output_layout = forward_sharding_plan.get("output") 

364 if input_layout is not None: 

365 args, kwargs = _parallel_in(func, args, kwargs, input_layout) 

366 outputs = func(*args, **kwargs) 

367 if output_layout is not None: 

368 outputs = _parallel_out(outputs, output_layout) 

369 return outputs 

370 

371 return _shard_wrapper 

372 

373 

def shard_module(model: Union[Module, Callable], device_mesh: DeviceMesh, sharding_plan: ShardingPlan):
    """
    Defining the input, output and parameters layouts of this cell or Callable.

    Note:
        - It is valid only in pynative mode.

    .. warning::
        The method is currently not supported in Graph mode.

    Args:
        model (Module or Callable): The model to be sharded.
        device_mesh (DeviceMesh): The device mesh for sharding.
        sharding_plan (ShardingPlan): Define the layout for the specified parameters, inputs or outputs.
            The sharding specification can be:
            - tuple of strings for alias format, e.g., ("dp", "None")
            - tuple of Placements, e.g., (Shard(0), Replicate())

    Returns:
        Module or Callable: The sharded model. With a world size of 1 there is
        nothing to shard, so the model is returned unchanged.

    Raises:
        TypeError: If `sharding_plan` is not a ShardingPlan instance, or its
            input/output plans are not dicts.
        ValueError: If a parameter layout is not a Layout, or it refers to a
            parameter that does not exist in `model`.

    Examples:
        >>> # Usage with device_mesh and alias format
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"mlp.weight": ("None", "tp")},
        ...     input_plan={"input": ("dp", "None")},
        ...     output_plan={"output": ("dp", "tp")}
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)

        >>> # Usage with device_mesh and Placement format
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"mlp.weight": (Replicate(), Shard(1))},
        ...     input_plan={"input": (Shard(0), Replicate())},
        ...     output_plan={"output": (Shard(0), Shard(1))}
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)
    """
    # Single-device run: sharding is a no-op. Return the model itself (not
    # None) so the documented `model = shard_module(model, ...)` usage does
    # not destroy the caller's model on a world size of 1.
    if platform.get_world_size() == 1:
        return model

    if not isinstance(sharding_plan, ShardingPlan):
        raise TypeError(f"The 'sharding_plan' must be an instance of ShardingPlan, "
                        f"but got {type(sharding_plan)}. Direct dict input is not supported.")

    normalized_plan = {}
    return_local_tensor_list = None

    if sharding_plan.plan:
        normalized_plan["parameter"] = sharding_plan.plan

    # Merge the input and output plans into a single "forward" section so
    # the conversion and hook registration below handle both uniformly.
    forward_part = {}

    if sharding_plan.input_plan:
        if not isinstance(sharding_plan.input_plan, dict):
            raise TypeError(f"input_plan must be a dict, but got {type(sharding_plan.input_plan)}")
        forward_part.update(sharding_plan.input_plan)

    if sharding_plan.output_plan:
        if not isinstance(sharding_plan.output_plan, dict):
            raise TypeError(f"output_plan must be a dict, but got {type(sharding_plan.output_plan)}")
        forward_part.update(sharding_plan.output_plan)

    if forward_part:
        normalized_plan["forward"] = forward_part

    if sharding_plan.return_local_tensor:
        return_local_tensor_list = sharding_plan.return_local_tensor

    # Convert placement tuples / alias strings in the plan into Layout objects
    converted_plan = _convert_sharding_plan(normalized_plan, device_mesh)

    # Plain callables only get input/output redistribution via a wrapper;
    # there are no parameters or sub-cells to shard.
    if not isinstance(model, Module):
        return _shard_callable(model, converted_plan)

    param_sharding_plan = converted_plan.get("parameter")
    forward_sharding_plan = converted_plan.get("forward")

    if param_sharding_plan is not None:
        for param_name, layout in param_sharding_plan.items():
            if not isinstance(layout, Layout):
                raise ValueError(f"In python shard_module, the type of setting in parameter_plan must be Layout, "
                                 f"but got type {type(layout)}")
            result = platform.search_parameter_by_name(model, param_name)
            if not result:
                raise ValueError(f"{param_name} is configured with a layout, but no instance was found.")
            _, _, param = result
            # Derive the tensor map lazily, once the parameter's rank is known.
            if layout.tensor_map is None:
                layout.placement_to_tensor_map(param.dim())
            param = platform.set_layout_into_parameter(param, layout)
            platform.update_parameter_by_name(model, result, param)

    if forward_sharding_plan is not None:
        _register_hook(model, forward_sharding_plan)

    if return_local_tensor_list is not None:
        _register_local_tensor_hook(model, return_local_tensor_list)

    return model

475 

476 

def parallelize_value_and_grad(fn, weights, sens=None):
    """
    A wrapper function to generate the function to calculate forward output and gradient for the parallel scenario.

    Args:
        fn (Union[Cell, Function]): Function to do grad operation.
        weights (Union[ParameterTuple, Parameter, list[Parameter]]):
            The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
        sens (Union[list(float), tuple(float)], optional): The sensitivity for grad operation. Default: "None".
            - If the fn only have one output, the sens must be None, and it will be attached automatically.
            - If the fn have multiple outputs:
              1) If the sens is None, only handle the first sensitivity, and set the remaining sensitivity to 0.
              2) If the sens is not None, the lengths of sens and outputs of fn must be equal.

    Returns:
        Function, the derivative function used to compute the gradient of a given function.
        For example, as for `out1, out2 = fn(*args)` , gradient function will return outputs like
        `((out1, out2), gradient)` .

    Raises:
        TypeError: If type of Args does not belong to required ones.

    Supported Platforms:
        ``Ascend``
    """
    from mindspore import ops  # pylint: disable=import-outside-toplevel
    # sens_param=True means the sensitivity (initial gradient) is passed in
    # explicitly as the trailing `sens` argument of the grad function.
    grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)

    # use CellWrapper to solve two problems:
    # 1. avoid running the forward fn or cell twice
    # 2. if the input of parallize_value_and_grad is cell and it is directly used as the input for grad,
    #    the operations before and after its __call__ function will not enter the auto-diff process.
    class CellWrapper(Module):
        # Thin pass-through cell; auto_prefix=False keeps the wrapped
        # network's parameter names unchanged.
        def __init__(self, net):
            super().__init__(auto_prefix=False)
            self.network = net

        def construct(self, *args, **kwargs):
            return self.network(*args, **kwargs)

        def forward(self, *args, **kwargs):
            return self.network(*args, **kwargs)

    fn = CellWrapper(fn)
    fn.set_grad()  # avoid running the forward fn or cell twice

    def wrapper(*args, **kwargs):
        loss_value = fn(*args, **kwargs)
        p_sens = None

        if isinstance(loss_value, (list, tuple)):
            # There are multiple outputs, requiring multiple sens
            p_sens = []

            if sens is None:
                # if sens is None, only handle the first sens, and set the remaining sens to 0
                loss_0 = loss_value[0]
                if isinstance(loss_0, DTensor):
                    # NOTE(review): scaling by 1/repeat_num presumably compensates
                    # for replicated shards contributing to the reduced gradient
                    # — confirm against the DTensor layout semantics.
                    repeat_num = loss_0.layout.repeat_num()
                    sens_0 = ops.fill(ops.DType()(loss_0), loss_0.local_shape, 1.0 / repeat_num)
                else:
                    sens_0 = ops.fill(ops.DType()(loss_0), loss_0.shape, 1.0)
                p_sens.append(sens_0)

                # Remaining outputs get zero sensitivity: only the first
                # output drives the gradient in this mode.
                for i in range(1, len(loss_value)):
                    loss_i = loss_value[i]
                    if isinstance(loss_i, DTensor):
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 0.0)
                    else:
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 0.0)
                    p_sens.append(sens_i)

            else:
                # sens is not None
                if not isinstance(sens, list) and not isinstance(sens, tuple):
                    raise TypeError("if the loss is list or tuple, the sens must be None or list or tuple")

                all_float = all(isinstance(item, float) for item in sens)
                if not all_float:
                    raise TypeError("if sens is not None, it should be list of float or tuple of float")

                if len(sens) != len(loss_value):
                    raise TypeError(f"the len of loss is {len(loss_value)}, but the len of sens is {len(sens)}")

                # NOTE(review): the user-supplied sens values are validated
                # above but never used below — every output gets a constant
                # 1.0 (or 1/repeat_num) sensitivity. Verify whether sens
                # values were meant to scale sens_i here.
                for _, loss_i in enumerate(loss_value):
                    if isinstance(loss_i, DTensor):
                        repeat_num = loss_i.layout.repeat_num()
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 1.0 / repeat_num)
                    else:
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 1.0)
                    p_sens.append(sens_i)

        else:
            # loss is tensor
            if sens is not None:
                raise TypeError(f"the fn only have one output, the sens must be None, but it is {sens}")
            if isinstance(loss_value, DTensor):
                repeat_num = loss_value.layout.repeat_num()
                p_sens = ops.fill(ops.DType()(loss_value), loss_value.local_shape, 1.0 / repeat_num)

            else:
                p_sens = ops.fill(ops.DType()(loss_value), loss_value.shape, 1.0)

        grads = grad_fn(fn, weights)(*args, **kwargs, sens=p_sens)
        return loss_value, grads

    return wrapper