Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/tensor_parallel/style.py: 90%

266 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Parallel styles for declarative tensor-parallel module sharding. 

16 

17Provides :class:`ParallelStyle` (ABC) and concrete implementations 

18:class:`ColwiseParallel`, :class:`RowwiseParallel`, :class:`SequenceParallel`, 

19:class:`PrepareModuleInput`, :class:`PrepareModuleInputOutput`, and 

20:class:`PrepareModuleOutput` aligned with ``torch.distributed.tensor.parallel.style``. 

21""" 

22from abc import ABC, abstractmethod 

23from typing import Any, Dict, Optional, Tuple, Union 

24 

25from hyper_parallel.core.dtensor.device_mesh import DeviceMesh 

26from hyper_parallel.core.dtensor.dtensor import ( 

27 DTensor, 

28 distribute_module, 

29 distribute_tensor, 

30 _distribute_module_iter_params, 

31 _distribute_module_new_parameter, 

32 _distribute_module_param_source, 

33 _distribute_module_set_param, 

34) 

35from hyper_parallel.core.dtensor.placement_types import Partial, Placement, Replicate, Shard 

36from hyper_parallel.platform import get_platform 

37 

38platform = get_platform() 

39Module = platform.Module 

40 

41__all__ = [ 

42 "ParallelStyle", 

43 "ColwiseParallel", 

44 "RowwiseParallel", 

45 "SequenceParallel", 

46 "PrepareModuleInput", 

47 "PrepareModuleInputOutput", 

48 "PrepareModuleOutput", 

49] 

50 

51 

52class ParallelStyle(ABC): 

53 """Abstract base class for parallel styles applied to nn.Module submodules. 

54 

55 Subclasses implement ``apply`` to wrap a module with the desired 

56 parallel communication behaviour (e.g. all-to-all for context parallel). 

57 

58 ``src_data_rank`` mirrors PyTorch's tensor-parallel contract: it can be set by 

59 :func:`parallelize_module` for styles that scatter/broadcast global tensors. 

60 HyperParallel styles may ignore it until they integrate ``distribute_tensor``. 

61 """ 

62 

63 src_data_rank: Optional[int] = 0 

64 

65 @abstractmethod 

66 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

67 """Apply this parallel style to *module* in-place and return it. 

68 

69 Args: 

70 module: The submodule to be parallelised. 

71 device_mesh: The device mesh describing the cluster topology. 

72 

73 Returns: 

74 The (possibly wrapped) module with parallelism applied. 

75 """ 

76 

77 
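# A minimal sketch of the ``ParallelStyle`` contract: a hypothetical style that simply
# replicates every parameter and leaves inputs/outputs untouched. It only reuses the
# ``distribute_module`` helpers imported above; the pass-through hooks assume
# ``distribute_module`` tolerates hooks that return their argument unchanged (the
# ``SequenceParallel`` output hook below behaves the same way when ``use_local_output``
# is ``False``).


class _ReplicateParamsSketch(ParallelStyle):
    """Hypothetical example style: replicate all parameters of the target module."""

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        def partition_fn(_submodule_path, submodule, mesh):
            # Convert every plain parameter into a replicated DTensor parameter.
            for key, param in _distribute_module_iter_params(submodule):
                if param is None:
                    continue
                src = _distribute_module_param_source(param)
                dt = distribute_tensor(src, mesh, [Replicate()])
                new_param = _distribute_module_new_parameter(
                    key, dt, bool(getattr(param, "requires_grad", True))
                )
                _distribute_module_set_param(submodule, key, new_param)

        def input_fn(_forward_module, forward_inputs, _mesh):
            return forward_inputs  # no input conversion

        def output_fn(_forward_module, forward_outputs, _mesh):
            return forward_outputs  # no output conversion

        return distribute_module(module, device_mesh, partition_fn, input_fn, output_fn)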

78class ColwiseParallel(ParallelStyle): 

79 """Partition a compatible module in a column-wise fashion. 

80 

81 Currently supports Linear and Embedding modules (framework-agnostic via 

82 ``platform.is_linear_module`` / ``platform.is_embedding_module``). 

83 Compose with :class:`RowwiseParallel` to shard MLP or Attention blocks. 

84 

85 Keyword Args: 

86 input_layouts (Placement, optional): 

87 DTensor layout for the module input. Used to annotate the input 

88 tensor as a DTensor. Defaults to ``Replicate()``. 

89 output_layouts (Placement, optional): 

90 Desired DTensor layout of the module output. Defaults to 

91 ``Shard(-1)`` (sharded on the last dimension). 

92 use_local_output (bool, optional): 

93 If ``True`` (default), convert the output DTensor back to a local 

94 tensor via ``to_local()``. 

95 

96 Returns: 

97 A :class:`ParallelStyle` that applies column-wise sharding. 

98 

99 Example:: 

100 

101 >>> from hyper_parallel import parallelize_module, ColwiseParallel, init_device_mesh 

102 >>> m = Model(...) 

103 >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",)) 

104 >>> parallelize_module(m, tp_mesh, {"linear1": ColwiseParallel()}) 

105 """ 

106 

107 def __init__( 

108 self, 

109 *, 

110 input_layouts: Optional[Placement] = None, 

111 output_layouts: Optional[Placement] = None, 

112 use_local_output: bool = True, 

113 ) -> None: 

114 super().__init__() 

115 self.input_layouts: Tuple[Placement, ...] = (input_layouts or Replicate(),) 

116 self.output_layouts: Tuple[Placement, ...] = (output_layouts or Shard(-1),) 

117 self.desired_input_layouts: Tuple[Placement, ...] = (Replicate(),) 

118 self.use_local_output = use_local_output 

119 

120 def __repr__(self) -> str: 

121 return ( 

122 f"{self.__class__.__name__}(" 

123 f"input_layouts={self.input_layouts}, " 

124 f"output_layouts={self.output_layouts}, " 

125 f"use_local_output={self.use_local_output})" 

126 ) 

127 

128 @staticmethod 

129 def _prepare_input_fn( 

130 input_layouts: Tuple[Placement, ...], 

131 desired_input_layouts: Tuple[Placement, ...], 

132 inputs: Any, 

133 device_mesh: DeviceMesh, 

134 ) -> Any: 

135 """Annotate or redistribute the first positional input.""" 

136 input_tensor = inputs[0] 

137 if not isinstance(input_tensor, DTensor): 

138 input_tensor = DTensor.from_local( 

139 input_tensor, device_mesh, input_layouts, 

140 ) 

141 

142 if input_layouts != desired_input_layouts: 

143 input_tensor = input_tensor.redistribute( 

144 device_mesh, desired_input_layouts, 

145 ) 

146 return input_tensor 

147 

148 def _partition_linear_fn(self, module: Any, device_mesh: DeviceMesh) -> None: 

149 """Shard Linear weight/bias along ``Shard(0)`` (column-wise).""" 

150 for key, param in _distribute_module_iter_params(module): 

151 if param is None: 

152 continue 

153 src = _distribute_module_param_source(param) 

154 requires_grad = bool(getattr(param, "requires_grad", True)) 

155 dt = distribute_tensor(src, device_mesh, [Shard(0)]) 

156 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

157 _distribute_module_set_param(module, key, new_param) 

158 

159 def _partition_embedding_fn(self, module: Any, device_mesh: DeviceMesh) -> None: 

160 """Shard Embedding weight along ``Shard(1)`` (column-wise).""" 

161 for key, param in _distribute_module_iter_params(module): 

162 if param is None: 

163 continue 

164 src = _distribute_module_param_source(param) 

165 requires_grad = bool(getattr(param, "requires_grad", True)) 

166 dt = distribute_tensor(src, device_mesh, [Shard(1)]) 

167 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

168 _distribute_module_set_param(module, key, new_param) 

169 

170 @staticmethod 

171 def _prepare_output_fn( 

172 output_layouts: Tuple[Placement, ...], 

173 use_local_output: bool, 

174 outputs: Any, 

175 device_mesh: DeviceMesh, 

176 ) -> Any: 

177 """Redistribute output to desired layout and optionally convert to local.""" 

178 if tuple(outputs.placements) != tuple(output_layouts): 

179 outputs = outputs.redistribute(device_mesh, output_layouts) 

180 if use_local_output: 

181 return outputs.to_local() 

182 return outputs 

183 

184 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

185 """Apply column-wise parallelism to *module*. 

186 

187 Args: 

188 module: A Linear or Embedding module to be sharded. 

189 device_mesh: 1-D device mesh for tensor parallelism. 

190 

191 Returns: 

192 The module with distributed parameters and I/O hooks attached. 

193 

194 Raises: 

195 NotImplementedError: If *module* is not a supported type. 

196 """ 

197 if platform.is_linear_module(module): 

198 

199 def partition_fn(submodule_path, submodule, device_mesh): 

200 self._partition_linear_fn(submodule, device_mesh) 

201 

202 elif platform.is_embedding_module(module): 

203 

204 def partition_fn(submodule_path, submodule, device_mesh): 

205 self._partition_embedding_fn(submodule, device_mesh) 

206 

207 else: 

208 raise NotImplementedError( 

209 "ColwiseParallel currently only supports Linear and Embedding modules!" 

210 ) 

211 

212 def input_fn(forward_module, forward_inputs, device_mesh): 

213 return self._prepare_input_fn( 

214 self.input_layouts, 

215 self.desired_input_layouts, 

216 forward_inputs, 

217 device_mesh, 

218 ) 

219 

220 def output_fn(forward_module, forward_outputs, device_mesh): 

221 return self._prepare_output_fn( 

222 self.output_layouts, 

223 self.use_local_output, 

224 forward_outputs, 

225 device_mesh, 

226 ) 

227 

228 return distribute_module( 

229 module, 

230 device_mesh, 

231 partition_fn, 

232 input_fn, 

233 output_fn, 

234 ) 

235 

236 
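# A minimal usage sketch for ColwiseParallel on its own, following the docstring example
# above. The module name "linear1" and the helper name are illustrative; the key point is
# that with the weight sharded as Shard(0), each of the tp ranks computes its own slice of
# the output features, so the forward result is globally Shard(-1).


def _example_colwise_only(model, tp_mesh):
    """Sketch: shard ``model.linear1`` column-wise and keep the output as a DTensor."""
    from hyper_parallel import parallelize_module  # import path as in the docstring example

    plan = {
        "linear1": ColwiseParallel(
            input_layouts=Replicate(),   # the caller passes a replicated activation
            output_layouts=Shard(-1),    # keep the default feature-sharded layout
            use_local_output=False,      # hand a DTensor (not a local shard) to the caller
        ),
    }
    return parallelize_module(model, tp_mesh, plan)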

237class RowwiseParallel(ParallelStyle): 

238 """Partition a compatible module in a row-wise fashion. 

239 

240 Currently supports Linear and Embedding modules (framework-agnostic via 

241 ``platform.is_linear_module`` / ``platform.is_embedding_module``). 

242 Compose with :class:`ColwiseParallel` to shard MLP or Attention blocks. 

243 

244 Keyword Args: 

245 input_layouts (Placement, optional): 

246 DTensor layout for the module input. Defaults to ``Shard(-1)`` 

247 (sharded on the last dimension). 

248 output_layouts (Placement, optional): 

249 Desired DTensor layout of the module output. Defaults to 

250 ``Replicate()`` (all-reduce / reduce-scatter from partial). 

251 use_local_output (bool, optional): 

252 If ``True`` (default), convert the output DTensor back to a local 

253 tensor via ``to_local()``. 

254 

255 Returns: 

256 A :class:`ParallelStyle` that applies row-wise sharding. 

257 

258 Example:: 

259 >>> from hyper_parallel import parallelize_module, RowwiseParallel, init_device_mesh 

260 >>> m = Model(...) 

261 >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",)) 

262 >>> parallelize_module(m, tp_mesh, {"linear2": RowwiseParallel()}) 

263 """ 

264 

265 def __init__( 

266 self, 

267 *, 

268 input_layouts: Optional[Placement] = None, 

269 output_layouts: Optional[Placement] = None, 

270 use_local_output: bool = True, 

271 ) -> None: 

272 super().__init__() 

273 self.input_layouts: Tuple[Placement, ...] = (input_layouts or Shard(-1),) 

274 self.output_layouts: Tuple[Placement, ...] = (output_layouts or Replicate(),) 

275 self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),) 

276 self.use_local_output = use_local_output 

277 

278 def __repr__(self) -> str: 

279 return ( 

280 f"{self.__class__.__name__}(" 

281 f"input_layouts={self.input_layouts}, " 

282 f"output_layouts={self.output_layouts}, " 

283 f"use_local_output={self.use_local_output})" 

284 ) 

285 

286 @staticmethod 

287 def _prepare_input_fn( 

288 input_layouts: Tuple[Placement, ...], 

289 desired_input_layouts: Tuple[Placement, ...], 

290 inputs: Any, 

291 device_mesh: DeviceMesh, 

292 ) -> Any: 

293 """Annotate or redistribute the first positional input.""" 

294 input_tensor = inputs[0] 

295 if not isinstance(input_tensor, DTensor): 

296 input_tensor = DTensor.from_local( 

297 input_tensor, device_mesh, input_layouts, 

298 ) 

299 

300 if input_layouts != desired_input_layouts: 

301 input_tensor = input_tensor.redistribute( 

302 device_mesh, desired_input_layouts, 

303 ) 

304 return input_tensor 

305 

306 def _partition_linear_fn(self, module: Any, device_mesh: DeviceMesh) -> None: 

307 """Shard Linear weight along ``Shard(1)`` (row-wise); bias to ``Replicate()``.""" 

308 for key, param in _distribute_module_iter_params(module): 

309 if param is None: 

310 continue 

311 src = _distribute_module_param_source(param) 

312 requires_grad = bool(getattr(param, "requires_grad", True)) 

313 placement = [Shard(1)] if key == "weight" else [Replicate()] 

314 dt = distribute_tensor(src, device_mesh, placement) 

315 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

316 _distribute_module_set_param(module, key, new_param) 

317 

318 def _partition_embedding_fn(self, module: Any, device_mesh: DeviceMesh) -> None: 

319 """Shard Embedding weight along ``Shard(0)`` (row-wise).""" 

320 for key, param in _distribute_module_iter_params(module): 

321 if param is None: 

322 continue 

323 src = _distribute_module_param_source(param) 

324 requires_grad = bool(getattr(param, "requires_grad", True)) 

325 dt = distribute_tensor(src, device_mesh, [Shard(0)]) 

326 new_param = _distribute_module_new_parameter(key, dt, requires_grad) 

327 _distribute_module_set_param(module, key, new_param) 

328 

329 @staticmethod 

330 def _prepare_output_fn( 

331 output_layouts: Tuple[Placement, ...], 

332 use_local_output: bool, 

333 outputs: Any, 

334 device_mesh: DeviceMesh, 

335 module: Optional[Module] = None, 

336 ) -> Any: 

337 """Redistribute partial output and optionally convert to local.""" 

338 if not isinstance(outputs, DTensor): 

339 # ``nn.Embedding.forward`` returns a plain tensor even when weight is sharded; 

340 # treat the local values as partial along the TP mesh (sum) before redistributing. 

341 if module is not None and platform.is_embedding_module(module): 

342 outputs = DTensor.from_local(outputs, device_mesh, [Partial("sum")]) 

343 else: 

344 raise TypeError( 

345 "RowwiseParallel expects a DTensor from Linear outputs; " 

346 f"got {type(outputs)}. If this is an unsupported module, extend I/O hooks." 

347 ) 

348 if tuple(outputs.placements) != tuple(output_layouts): 

349 outputs = outputs.redistribute(device_mesh, output_layouts) 

350 if use_local_output: 

351 return outputs.to_local() 

352 return outputs 

353 

354 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

355 """Apply row-wise parallelism to *module*. 

356 

357 Args: 

358 module: A Linear or Embedding module to be sharded. 

359 device_mesh: 1-D device mesh for tensor parallelism. 

360 

361 Returns: 

362 The module with distributed parameters and I/O hooks attached. 

363 

364 Raises: 

365 NotImplementedError: If *module* is not a supported type. 

366 """ 

367 if platform.is_linear_module(module): 

368 

369 def partition_fn(submodule_path, submodule, device_mesh): 

370 self._partition_linear_fn(submodule, device_mesh) 

371 

372 self.desired_input_layouts = (Shard(-1),) 

373 elif platform.is_embedding_module(module): 

374 

375 def partition_fn(submodule_path, submodule, device_mesh): 

376 self._partition_embedding_fn(submodule, device_mesh) 

377 

378 self.desired_input_layouts = (Replicate(),) 

379 else: 

380 raise NotImplementedError( 

381 "RowwiseParallel currently only supports Linear and Embedding modules!" 

382 ) 

383 

384 def input_fn(forward_module, forward_inputs, device_mesh): 

385 return self._prepare_input_fn( 

386 self.input_layouts, 

387 self.desired_input_layouts, 

388 forward_inputs, 

389 device_mesh, 

390 ) 

391 

392 def output_fn(forward_module, forward_outputs, device_mesh): 

393 return self._prepare_output_fn( 

394 self.output_layouts, 

395 self.use_local_output, 

396 forward_outputs, 

397 device_mesh, 

398 forward_module, 

399 ) 

400 

401 return distribute_module( 

402 module, 

403 device_mesh, 

404 partition_fn, 

405 input_fn, 

406 output_fn, 

407 ) 

408 

409 
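# A sketch of the canonical ColwiseParallel -> RowwiseParallel pairing for an MLP block.
# Submodule names ("mlp.up_proj", "mlp.down_proj") are illustrative, and nested dotted
# paths are assumed to be accepted by ``parallelize_module`` as in the PyTorch API this
# file mirrors. The defaults compose: Colwise returns each rank's Shard(-1) slice as a
# local tensor, Rowwise re-wraps that local activation as Shard(-1), multiplies it against
# its Shard(1) weight slice to get a Partial result, and the default Replicate() output
# layout reduces it back to the full activation.


def _example_mlp_tp_plan(model, tp_mesh):
    """Sketch: tensor-parallel plan for a two-layer MLP using the style defaults."""
    from hyper_parallel import parallelize_module  # import path as in the docstring examples

    plan = {
        "mlp.up_proj": ColwiseParallel(),    # weight Shard(0): output features sharded
        "mlp.down_proj": RowwiseParallel(),  # weight Shard(1): partial sums, then reduce
    }
    return parallelize_module(model, tp_mesh, plan)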

410class SequenceParallel(ParallelStyle): 

411 """Replicate module parameters and run forward with the sequence axis sharded. 

412 

413 Matches ``torch.distributed.tensor.parallel.SequenceParallel``: activations are 

414 sharded on the sequence dimension while weights stay fully replicated. Typical 

415 targets are normalization and dropout layers used after row-wise / scatter 

416 projections in tensor-parallel transformers (`Reducing Activation Recomputation 

417 in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__). 

418 

419 If the first positional input is a plain tensor, it is treated as the local 

420 shard along ``sequence_dim`` and wrapped as a :class:`DTensor`. If it is already 

421 a :class:`DTensor` but not sharded on that dimension, it is redistributed. 

422 

423 Keyword Args: 

424 sequence_dim (int, optional): 

425 Tensor dimension index for the sequence axis (e.g. ``1`` for ``(B, S, H)``). 

426 Default: ``1``. 

427 use_local_output (bool, optional): 

428 If ``True``, return a local tensor via ``to_local()``; otherwise keep a 

429 :class:`DTensor`. Default: ``False`` (PyTorch default). 

430 

431 Note: 

432 Like PyTorch, this assumes sensible defaults for norm weights (e.g. ones). 

433 Custom initializations should be broadcast so every rank agrees before or 

434 after parallelization. 

435 

436 Example:: 

437 

438 >>> from hyper_parallel import parallelize_module, SequenceParallel, init_device_mesh 

439 >>> m = Model(...) 

440 >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",)) 

441 >>> parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}) 

442 """ 

443 

444 def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False) -> None: 

445 super().__init__() 

446 self.sequence_sharding: Tuple[Placement, ...] = (Shard(sequence_dim),) 

447 self.use_local_output = use_local_output 

448 

449 def __repr__(self) -> str: 

450 dim = self.sequence_sharding[0].dim 

451 return ( 

452 f"{self.__class__.__name__}(" 

453 f"sequence_dim={dim}, " 

454 f"use_local_output={self.use_local_output})" 

455 ) 

456 

457 @staticmethod 

458 def _prepare_input_fn( 

459 sequence_sharding: Tuple[Placement, ...], 

460 mod: Module, 

461 inputs: Any, 

462 device_mesh: DeviceMesh, 

463 ) -> Any: 

464 """Ensure the first input is a :class:`DTensor` sharded on the sequence dim.""" 

465 input_tensor = inputs[0] 

466 if isinstance(input_tensor, DTensor): 

467 if tuple(input_tensor.placements) != tuple(sequence_sharding): 

468 input_tensor = input_tensor.redistribute(device_mesh, sequence_sharding) 

469 return input_tensor 

470 if platform.is_tensor(input_tensor): 

471 return DTensor.from_local(input_tensor, device_mesh, sequence_sharding) 

472 raise ValueError( 

473 f"expecting input of {mod} to be a tensor or DTensor, but got {type(input_tensor)}" 

474 ) 

475 

476 @staticmethod 

477 def _prepare_output_fn(use_local_output: bool, outputs: Any) -> Any: 

478 if use_local_output: 

479 return outputs.to_local() 

480 return outputs 

481 

482 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

483 """Apply sequence-parallel hooks and replicate parameters via ``distribute_module``. 

484 

485 Args: 

486 module: Submodule to parallelize (for example ``LayerNorm`` or ``Dropout``). 

487 device_mesh: One-dimensional tensor-parallel device mesh. 

488 

489 Returns: 

490 The same ``module`` instance with forward hooks attached and parameters 

491 converted to replicated DTensors where applicable. 

492 """ 

493 

494 def partition_fn(_submodule_path, _submodule, _mesh): 

495 return None 

496 

497 def input_fn(forward_module, forward_inputs, mesh): 

498 return self._prepare_input_fn( 

499 self.sequence_sharding, 

500 forward_module, 

501 forward_inputs, 

502 mesh, 

503 ) 

504 

505 def output_fn(_forward_module, forward_outputs, _mesh): 

506 return self._prepare_output_fn(self.use_local_output, forward_outputs) 

507 

508 return distribute_module( 

509 module, 

510 device_mesh, 

511 partition_fn, 

512 input_fn, 

513 output_fn, 

514 ) 

515 

516 
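# A usage sketch for SequenceParallel matching its docstring example. Activations are
# assumed to be laid out as (batch, seq, hidden), so the default sequence_dim=1 applies:
# a plain tensor entering "norm" is wrapped as a Shard(1) DTensor, the norm runs on the
# local sequence shard, and the result stays a DTensor (use_local_output defaults to False).


def _example_sequence_parallel_plan(model, tp_mesh):
    """Sketch: run a norm layer with the sequence axis sharded across the TP mesh."""
    from hyper_parallel import parallelize_module  # import path as in the docstring examples

    plan = {"norm": SequenceParallel(sequence_dim=1)}
    return parallelize_module(model, tp_mesh, plan)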

517class PrepareModuleInput(ParallelStyle): 

518 """Prepare module forward *args* (and optional *kwargs*) as :class:`DTensor` layouts. 

519 

520 At forward time, converts each annotated positional (or keyword) tensor from local 

521 to :class:`DTensor` using ``input_layouts``, then redistributes to 

522 ``desired_input_layouts`` when they differ. ``None`` in a layout tuple means 

523 “leave this input unchanged”. 

524 

525 Mirrors ``torch.distributed.tensor.parallel.style.PrepareModuleInput``. 

526 

527 Keyword Args: 

528 input_layouts: Placements per positional arg, or a single :class:`Placement` 

529 wrapped as a one-tuple. ``None`` entries skip conversion for that arg. 

530 desired_input_layouts: Target placements; must match ``input_layouts`` length. 

531 input_kwarg_layouts: Optional mapping kwarg name → placement for conversion. 

532 desired_input_kwarg_layouts: Target placements for those kwargs (same keys). 

533 use_local_output: If ``True``, convert prepared inputs back to local tensors 

534 before the module runs (PyTorch names this flag ``use_local_output`` on 

535 :class:`PrepareModuleInput`). 

536 """ 

537 

538 def __init__( 

539 self, 

540 *, 

541 input_layouts: Optional[Union[Placement, Tuple[Optional[Placement], ...]]] = None, 

542 desired_input_layouts: Optional[ 

543 Union[Placement, Tuple[Optional[Placement], ...]] 

544 ] = None, 

545 input_kwarg_layouts: Optional[Dict[str, Placement]] = None, 

546 desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, 

547 use_local_output: bool = False, 

548 ) -> None: 

549 super().__init__() 

550 self.input_layouts = ( 

551 (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts 

552 ) 

553 self.desired_input_layouts = ( 

554 (desired_input_layouts,) 

555 if isinstance(desired_input_layouts, Placement) 

556 else desired_input_layouts 

557 ) 

558 self.use_local_output = use_local_output 

559 if self.input_layouts is not None: 

560 if self.desired_input_layouts is None: 

561 raise AssertionError("desired module inputs should not be None!") 

562 if len(self.input_layouts) != len(self.desired_input_layouts): 

563 raise AssertionError( 

564 "input_layouts and desired_input_layouts should have same length!" 

565 ) 

566 self.with_kwargs = input_kwarg_layouts is not None 

567 self.input_kwarg_layouts = input_kwarg_layouts or {} 

568 self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} 

569 if self.with_kwargs: 

570 if len(self.input_kwarg_layouts) != len(self.desired_input_kwarg_layouts): 

571 raise AssertionError( 

572 "input_kwarg_layouts and desired_input_kwarg_layouts should have " 

573 "same length!" 

574 ) 

575 

576 def _prepare_input_arg( 

577 self, 

578 input_obj: Any, 

579 mesh: DeviceMesh, 

580 input_layout: Optional[Placement], 

581 desired_layout: Optional[Placement], 

582 ) -> Any: 

583 """Convert one input to DTensor, redistribute if needed, optionally to_local.""" 

584 if input_layout is not None: 

585 if isinstance(input_obj, DTensor): 

586 dt_inp = input_obj 

587 else: 

588 if not platform.is_tensor(input_obj): 

589 raise AssertionError("expecting input to be a framework tensor!") 

590 dt_inp = DTensor.from_local(input_obj, mesh, (input_layout,)) 

591 

592 if desired_layout is not None and input_layout != desired_layout: 

593 dt_inp = dt_inp.redistribute(mesh, (desired_layout,)) 

594 

595 return dt_inp.to_local() if self.use_local_output else dt_inp 

596 return input_obj 

597 

598 def _prepare_input_fn(self, inputs: Any, device_mesh: DeviceMesh) -> Any: 

599 """Prepare positional ``inputs`` tuple per ``input_layouts`` / ``desired_input_layouts``.""" 

600 if self.input_layouts is None: 

601 return inputs 

602 if not isinstance(inputs, tuple): 

603 inputs = (inputs,) 

604 if len(inputs) != len(self.input_layouts): 

605 raise ValueError("module inputs and input_layouts should have same length!") 

606 if self.desired_input_layouts is None: 

607 raise AssertionError("desired module inputs should not be None!") 

608 prepared_inputs = [ 

609 self._prepare_input_arg(inp, device_mesh, il, dl) 

610 for inp, il, dl in zip(inputs, self.input_layouts, self.desired_input_layouts) 

611 ] 

612 return tuple(prepared_inputs) 

613 

614 def _prepare_input_kwarg_fn( 

615 self, 

616 inputs: Any, 

617 kwarg_inputs: Dict[str, Any], 

618 device_mesh: DeviceMesh, 

619 ) -> Tuple[Any, Dict[str, Any]]: 

620 """Prepare positional and keyword tensor inputs; returns ``(args, kwargs)`` for the hook.""" 

621 prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) 

622 prepared_kwarg_inputs: Dict[str, Any] = {} 

623 for kwarg_key in kwarg_inputs: 

624 kwarg_val = kwarg_inputs[kwarg_key] 

625 input_layout = self.input_kwarg_layouts.get(kwarg_key) 

626 desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) 

627 prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( 

628 kwarg_val, device_mesh, input_layout, desired_input_layout 

629 ) 

630 return (prepared_arg_inputs, prepared_kwarg_inputs) 

631 

632 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

633 if self.with_kwargs: 

634 

635 def _pre_hook(_mod, inputs, kwargs): 

636 return self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh) 

637 

638 platform.register_forward_pre_hook( 

639 module, _pre_hook, prepend=False, with_kwargs=True, 

640 ) 

641 else: 

642 

643 def _pre_hook(_mod, inputs): 

644 return self._prepare_input_fn(inputs, device_mesh) 

645 

646 platform.register_forward_pre_hook(module, _pre_hook, prepend=False) 

647 return module 

648 

649 def __repr__(self) -> str: 

650 return ( 

651 f"{self.__class__.__name__}(" 

652 f"input_layouts={self.input_layouts}, " 

653 f"desired_input_layouts={self.desired_input_layouts}, " 

654 f"input_kwarg_layouts={self.input_kwarg_layouts}, " 

655 f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, " 

656 f"use_local_output={self.use_local_output})" 

657 ) 

658 

659 
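# A usage sketch for PrepareModuleInput. The block and its "attention_mask" kwarg are
# hypothetical; the sketch only shows how positional and keyword layouts are declared.
# The first positional tensor is assumed to arrive sequence-sharded and is redistributed
# to Replicate(), while the mask is merely annotated as already replicated.


def _example_prepare_module_input(block, tp_mesh):
    """Sketch: annotate a block's forward args/kwargs as DTensors before it runs."""
    style = PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
        input_kwarg_layouts={"attention_mask": Replicate()},
        desired_input_kwarg_layouts={"attention_mask": Replicate()},
    )
    return style.apply(block, tp_mesh)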

660class PrepareModuleOutput(ParallelStyle): 

661 """Prepare module forward outputs as :class:`DTensor` and redistribute layouts. 

662 

663 Registers a forward hook that treats each return value like 

664 ``torch.distributed.tensor.parallel.style.PrepareModuleOutput``: optional 

665 ``None`` slots in ``output_layouts`` pass that output through unchanged. 

666 

667 Keyword Args: 

668 output_layouts: Current or assumed placement per output tensor. 

669 desired_output_layouts: Target placements; length must match ``output_layouts``. 

670 use_local_output: If ``True`` (default), return local shards after redistribution. 

671 """ 

672 

673 def __init__( 

674 self, 

675 *, 

676 output_layouts: Union[Placement, Tuple[Optional[Placement], ...]], 

677 desired_output_layouts: Union[Placement, Tuple[Optional[Placement], ...]], 

678 use_local_output: bool = True, 

679 ) -> None: 

680 super().__init__() 

681 self.output_layouts = ( 

682 (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts 

683 ) 

684 self.desired_output_layouts = ( 

685 (desired_output_layouts,) 

686 if isinstance(desired_output_layouts, Placement) 

687 else desired_output_layouts 

688 ) 

689 self.use_local_output = use_local_output 

690 if len(self.output_layouts) != len(self.desired_output_layouts): 

691 raise AssertionError( 

692 "output_layouts and desired_output_layouts should have same length!" 

693 ) 

694 

695 def _prepare_out_fn(self, outputs: Any, device_mesh: DeviceMesh) -> Any: 

696 """Redistribute each output tensor per ``output_layouts`` / ``desired_output_layouts``.""" 

697 prepared_outputs: list = [] 

698 if not isinstance(outputs, tuple): 

699 outputs = (outputs,) 

700 if len(outputs) != len(self.output_layouts): 

701 raise ValueError("module outputs and output_layouts should have same length!") 

702 for out, out_layout, desired_out_layout in zip( 

703 outputs, self.output_layouts, self.desired_output_layouts, 

704 ): 

705 if out_layout is not None: 

706 if isinstance(out, DTensor): 

707 dt_out = out 

708 else: 

709 dt_out = DTensor.from_local(out, device_mesh, (out_layout,)) 

710 if out_layout != desired_out_layout: 

711 dt_out = dt_out.redistribute(device_mesh, (desired_out_layout,)) 

712 prepared_outputs.append( 

713 dt_out.to_local() if self.use_local_output else dt_out 

714 ) 

715 else: 

716 prepared_outputs.append(out) 

717 if len(prepared_outputs) == 1: 

718 return prepared_outputs[0] 

719 return tuple(prepared_outputs) 

720 

721 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

722 

723 def _hook(_mod, _inputs, outputs): 

724 return self._prepare_out_fn(outputs, device_mesh) 

725 

726 module.register_forward_hook(_hook) 

727 return module 

728 

729 def __repr__(self) -> str: 

730 return ( 

731 f"{self.__class__.__name__}(" 

732 f"output_layouts={self.output_layouts}, " 

733 f"desired_output_layouts={self.desired_output_layouts}, " 

734 f"use_local_output={self.use_local_output})" 

735 ) 

736 

737 
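# A usage sketch for PrepareModuleOutput. It assumes the wrapped block returns a single
# activation whose local values form a Shard(-1) DTensor on the TP mesh; the hook then
# redistributes it to Replicate() and, with the default use_local_output=True, hands the
# caller a plain local tensor.


def _example_prepare_module_output(block, tp_mesh):
    """Sketch: gather a block's feature-sharded output back into a full local tensor."""
    style = PrepareModuleOutput(
        output_layouts=Shard(-1),
        desired_output_layouts=Replicate(),
    )
    return style.apply(block, tp_mesh)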

738class PrepareModuleInputOutput(ParallelStyle): 

739 """Combine :class:`PrepareModuleInput` and :class:`PrepareModuleOutput` on one module. 

740 

741 Same keyword arguments as the two styles, with ``use_local_input`` mapping to 

742 ``PrepareModuleInput(..., use_local_output=use_local_input)`` for PyTorch parity. 

743 """ 

744 

745 def __init__( 

746 self, 

747 *, 

748 input_layouts: Optional[Union[Placement, Tuple[Optional[Placement], ...]]] = None, 

749 desired_input_layouts: Optional[ 

750 Union[Placement, Tuple[Optional[Placement], ...]] 

751 ] = None, 

752 input_kwarg_layouts: Optional[Dict[str, Placement]] = None, 

753 desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None, 

754 use_local_input: bool = False, 

755 output_layouts: Union[Placement, Tuple[Optional[Placement], ...]], 

756 desired_output_layouts: Union[Placement, Tuple[Optional[Placement], ...]], 

757 use_local_output: bool = True, 

758 ) -> None: 

759 super().__init__() 

760 self.prepare_module_input = PrepareModuleInput( 

761 input_layouts=input_layouts, 

762 desired_input_layouts=desired_input_layouts, 

763 input_kwarg_layouts=input_kwarg_layouts, 

764 desired_input_kwarg_layouts=desired_input_kwarg_layouts, 

765 use_local_output=use_local_input, 

766 ) 

767 self.prepare_module_output = PrepareModuleOutput( 

768 output_layouts=output_layouts, 

769 desired_output_layouts=desired_output_layouts, 

770 use_local_output=use_local_output, 

771 ) 

772 

773 def apply(self, module: Module, device_mesh: DeviceMesh) -> Module: 

774 self.prepare_module_input.apply(module, device_mesh) 

775 self.prepare_module_output.apply(module, device_mesh) 

776 return module 

777 

778 def __repr__(self) -> str: 

779 p_in = self.prepare_module_input 

780 p_out = self.prepare_module_output 

781 return ( 

782 f"{self.__class__.__name__}(" 

783 f"input_layouts={p_in.input_layouts}, " 

784 f"desired_input_layouts={p_in.desired_input_layouts}, " 

785 f"input_kwarg_layouts={p_in.input_kwarg_layouts}, " 

786 f"desired_input_kwarg_layouts={p_in.desired_input_kwarg_layouts}, " 

787 f"use_local_input={p_in.use_local_output}, " 

788 f"output_layouts={p_out.output_layouts}, " 

789 f"desired_output_layouts={p_out.desired_output_layouts}, " 

790 f"use_local_output={p_out.use_local_output})" 

791 )
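# An end-to-end sketch combining the styles above for one hypothetical transformer layer,
# following the usual tensor-parallel + sequence-parallel recipe from
# torch.distributed.tensor.parallel. All submodule names are assumptions about the model,
# and dotted paths are assumed to be accepted by ``parallelize_module``; adapt the keys to
# the real layer. Norm layers run sequence-sharded, Colwise projections declare their
# input as Shard(1) so it is gathered first, and Rowwise projections scatter back to a
# sequence-sharded activation via output_layouts=Shard(1).


def _example_transformer_layer_plan(layer, tp_mesh):
    """Sketch: tensor + sequence parallel plan for a single transformer layer."""
    from hyper_parallel import parallelize_module  # import path as in the docstring examples

    plan = {
        "input_norm": SequenceParallel(),
        "attention.q_proj": ColwiseParallel(input_layouts=Shard(1)),
        "attention.k_proj": ColwiseParallel(input_layouts=Shard(1)),
        "attention.v_proj": ColwiseParallel(input_layouts=Shard(1)),
        "attention.o_proj": RowwiseParallel(output_layouts=Shard(1)),
        "post_attention_norm": SequenceParallel(),
        "mlp.up_proj": ColwiseParallel(input_layouts=Shard(1)),
        "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
    }
    return parallelize_module(layer, tp_mesh, plan)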