Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/tensor_parallel/style.py: 90%
266 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel styles for declarative tensor-parallel module sharding.

Provides :class:`ParallelStyle` (ABC) and concrete implementations
:class:`ColwiseParallel`, :class:`RowwiseParallel`, :class:`SequenceParallel`,
:class:`PrepareModuleInput`, :class:`PrepareModuleInputOutput`, and
:class:`PrepareModuleOutput` aligned with ``torch.distributed.tensor.parallel.style``.
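
A typical MLP sharding, combining the two styles described below (a minimal
sketch; ``Model`` and the submodule names ``fc1``/``fc2`` are illustrative)::

    >>> from hyper_parallel import parallelize_module, ColwiseParallel, RowwiseParallel, init_device_mesh
    >>> m = Model(...)
    >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
    >>> parallelize_module(m, tp_mesh, {"fc1": ColwiseParallel(), "fc2": RowwiseParallel()})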
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Union

from hyper_parallel.core.dtensor.device_mesh import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import (
    DTensor,
    distribute_module,
    distribute_tensor,
    _distribute_module_iter_params,
    _distribute_module_new_parameter,
    _distribute_module_param_source,
    _distribute_module_set_param,
)
from hyper_parallel.core.dtensor.placement_types import Partial, Placement, Replicate, Shard
from hyper_parallel.platform import get_platform

platform = get_platform()
Module = platform.Module

__all__ = [
    "ParallelStyle",
    "ColwiseParallel",
    "RowwiseParallel",
    "SequenceParallel",
    "PrepareModuleInput",
    "PrepareModuleInputOutput",
    "PrepareModuleOutput",
]


class ParallelStyle(ABC):
    """Abstract base class for parallel styles applied to nn.Module submodules.

    Subclasses implement ``apply`` to wrap a module with the desired
    parallel communication behaviour (e.g. all-to-all for context parallel).

    ``src_data_rank`` mirrors PyTorch's tensor-parallel contract: it can be set by
    :func:`parallelize_module` for styles that scatter/broadcast global tensors.
    HyperParallel styles may ignore it until they integrate ``distribute_tensor``.
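
    A minimal sketch of a custom subclass (illustrative only; the concrete styles
    in this module shard parameters and attach I/O hooks via ``distribute_module``)::

        >>> class PassthroughStyle(ParallelStyle):
        ...     def apply(self, module, device_mesh):
        ...         # Illustrative no-op: a real style would distribute the
        ...         # module's parameters and register forward hooks here.
        ...         return module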
    """

    src_data_rank: Optional[int] = 0

    @abstractmethod
    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Apply this parallel style to *module* in-place and return it.

        Args:
            module: The submodule to be parallelised.
            device_mesh: The device mesh describing the cluster topology.

        Returns:
            The (possibly wrapped) module with parallelism applied.
        """


class ColwiseParallel(ParallelStyle):
    """Partition a compatible module in a column-wise fashion.

    Currently supports Linear and Embedding modules (framework-agnostic via
    ``platform.is_linear_module`` / ``platform.is_embedding_module``).
    Compose with :class:`RowwiseParallel` to shard MLP or Attention blocks.

    Keyword Args:
        input_layouts (Placement, optional):
            DTensor layout for the module input. Used to annotate the input
            tensor as a DTensor. Defaults to ``Replicate()``.
        output_layouts (Placement, optional):
            Desired DTensor layout of the module output. Defaults to
            ``Shard(-1)`` (sharded on the last dimension).
        use_local_output (bool, optional):
            If ``True`` (default), convert the output DTensor back to a local
            tensor via ``to_local()``.

    Returns:
        A :class:`ParallelStyle` that applies column-wise sharding.

    Example::

        >>> from hyper_parallel import parallelize_module, ColwiseParallel, init_device_mesh
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
        >>> parallelize_module(m, tp_mesh, {"linear1": ColwiseParallel()})
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        self.input_layouts: Tuple[Placement, ...] = (input_layouts or Replicate(),)
        self.output_layouts: Tuple[Placement, ...] = (output_layouts or Shard(-1),)
        self.desired_input_layouts: Tuple[Placement, ...] = (Replicate(),)
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"input_layouts={self.input_layouts}, "
            f"output_layouts={self.output_layouts}, "
            f"use_local_output={self.use_local_output})"
        )

    @staticmethod
    def _prepare_input_fn(
        input_layouts: Tuple[Placement, ...],
        desired_input_layouts: Tuple[Placement, ...],
        inputs: Any,
        device_mesh: DeviceMesh,
    ) -> Any:
        """Annotate or redistribute the first positional input."""
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts,
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                device_mesh, desired_input_layouts,
            )
        return input_tensor

    def _partition_linear_fn(self, module: Any, device_mesh: DeviceMesh) -> None:
        """Shard Linear weight/bias along ``Shard(0)`` (column-wise)."""
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            dt = distribute_tensor(src, device_mesh, [Shard(0)])
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)

    def _partition_embedding_fn(self, module: Any, device_mesh: DeviceMesh) -> None:
        """Shard Embedding weight along ``Shard(1)`` (column-wise)."""
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            dt = distribute_tensor(src, device_mesh, [Shard(1)])
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)

    @staticmethod
    def _prepare_output_fn(
        output_layouts: Tuple[Placement, ...],
        use_local_output: bool,
        outputs: Any,
        device_mesh: DeviceMesh,
    ) -> Any:
        """Redistribute output to desired layout and optionally convert to local."""
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(device_mesh, output_layouts)
        if use_local_output:
            return outputs.to_local()
        return outputs

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Apply column-wise parallelism to *module*.

        Args:
            module: A Linear or Embedding module to be sharded.
            device_mesh: 1-D device mesh for tensor parallelism.

        Returns:
            The module with distributed parameters and I/O hooks attached.

        Raises:
            NotImplementedError: If *module* is not a supported type.
        """
        if platform.is_linear_module(module):

            def partition_fn(submodule_path, submodule, device_mesh):
                self._partition_linear_fn(submodule, device_mesh)

        elif platform.is_embedding_module(module):

            def partition_fn(submodule_path, submodule, device_mesh):
                self._partition_embedding_fn(submodule, device_mesh)

        else:
            raise NotImplementedError(
                "ColwiseParallel currently only supports Linear and Embedding modules!"
            )

        def input_fn(forward_module, forward_inputs, device_mesh):
            return self._prepare_input_fn(
                self.input_layouts,
                self.desired_input_layouts,
                forward_inputs,
                device_mesh,
            )

        def output_fn(forward_module, forward_outputs, device_mesh):
            return self._prepare_output_fn(
                self.output_layouts,
                self.use_local_output,
                forward_outputs,
                device_mesh,
            )

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            input_fn,
            output_fn,
        )


class RowwiseParallel(ParallelStyle):
    """Partition a compatible module in a row-wise fashion.

    Currently supports Linear and Embedding modules (framework-agnostic via
    ``platform.is_linear_module`` / ``platform.is_embedding_module``).
    Compose with :class:`ColwiseParallel` to shard MLP or Attention blocks.

    Keyword Args:
        input_layouts (Placement, optional):
            DTensor layout for the module input. Defaults to ``Shard(-1)``
            (sharded on the last dimension).
        output_layouts (Placement, optional):
            Desired DTensor layout of the module output. Defaults to
            ``Replicate()`` (all-reduce / reduce-scatter from partial).
        use_local_output (bool, optional):
            If ``True`` (default), convert the output DTensor back to a local
            tensor via ``to_local()``.

    Returns:
        A :class:`ParallelStyle` that applies row-wise sharding.

    Example::

        >>> from hyper_parallel import parallelize_module, RowwiseParallel, init_device_mesh
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
        >>> parallelize_module(m, tp_mesh, {"linear2": RowwiseParallel()})
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        self.input_layouts: Tuple[Placement, ...] = (input_layouts or Shard(-1),)
        self.output_layouts: Tuple[Placement, ...] = (output_layouts or Replicate(),)
        self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),)
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"input_layouts={self.input_layouts}, "
            f"output_layouts={self.output_layouts}, "
            f"use_local_output={self.use_local_output})"
        )

    @staticmethod
    def _prepare_input_fn(
        input_layouts: Tuple[Placement, ...],
        desired_input_layouts: Tuple[Placement, ...],
        inputs: Any,
        device_mesh: DeviceMesh,
    ) -> Any:
        """Annotate or redistribute the first positional input."""
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(
                input_tensor, device_mesh, input_layouts,
            )

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(
                device_mesh, desired_input_layouts,
            )
        return input_tensor

    def _partition_linear_fn(self, module: Any, device_mesh: DeviceMesh) -> None:
        """Shard Linear weight along ``Shard(1)`` (row-wise); bias to ``Replicate()``."""
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            placement = [Shard(1)] if key == "weight" else [Replicate()]
            dt = distribute_tensor(src, device_mesh, placement)
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)

    def _partition_embedding_fn(self, module: Any, device_mesh: DeviceMesh) -> None:
        """Shard Embedding weight along ``Shard(0)`` (row-wise)."""
        for key, param in _distribute_module_iter_params(module):
            if param is None:
                continue
            src = _distribute_module_param_source(param)
            requires_grad = bool(getattr(param, "requires_grad", True))
            dt = distribute_tensor(src, device_mesh, [Shard(0)])
            new_param = _distribute_module_new_parameter(key, dt, requires_grad)
            _distribute_module_set_param(module, key, new_param)

    @staticmethod
    def _prepare_output_fn(
        output_layouts: Tuple[Placement, ...],
        use_local_output: bool,
        outputs: Any,
        device_mesh: DeviceMesh,
        module: Optional[Module] = None,
    ) -> Any:
        """Redistribute partial output and optionally convert to local."""
        if not isinstance(outputs, DTensor):
            # ``nn.Embedding.forward`` returns a plain tensor even when weight is sharded;
            # treat the local values as partial along the TP mesh (sum) before redistributing.
            if module is not None and platform.is_embedding_module(module):
                outputs = DTensor.from_local(outputs, device_mesh, [Partial("sum")])
            else:
                raise TypeError(
                    "RowwiseParallel expects a DTensor from Linear outputs; "
                    f"got {type(outputs)}. If this is an unsupported module, extend I/O hooks."
                )
        if tuple(outputs.placements) != tuple(output_layouts):
            outputs = outputs.redistribute(device_mesh, output_layouts)
        if use_local_output:
            return outputs.to_local()
        return outputs

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Apply row-wise parallelism to *module*.

        Args:
            module: A Linear or Embedding module to be sharded.
            device_mesh: 1-D device mesh for tensor parallelism.

        Returns:
            The module with distributed parameters and I/O hooks attached.

        Raises:
            NotImplementedError: If *module* is not a supported type.
        """
        if platform.is_linear_module(module):

            def partition_fn(submodule_path, submodule, device_mesh):
                self._partition_linear_fn(submodule, device_mesh)

            self.desired_input_layouts = (Shard(-1),)
        elif platform.is_embedding_module(module):

            def partition_fn(submodule_path, submodule, device_mesh):
                self._partition_embedding_fn(submodule, device_mesh)

            self.desired_input_layouts = (Replicate(),)
        else:
            raise NotImplementedError(
                "RowwiseParallel currently only supports Linear and Embedding modules!"
            )

        def input_fn(forward_module, forward_inputs, device_mesh):
            return self._prepare_input_fn(
                self.input_layouts,
                self.desired_input_layouts,
                forward_inputs,
                device_mesh,
            )

        def output_fn(forward_module, forward_outputs, device_mesh):
            return self._prepare_output_fn(
                self.output_layouts,
                self.use_local_output,
                forward_outputs,
                device_mesh,
                forward_module,
            )

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            input_fn,
            output_fn,
        )


class SequenceParallel(ParallelStyle):
    """Replicate module parameters and run forward with the sequence axis sharded.

    Matches ``torch.distributed.tensor.parallel.SequenceParallel``: activations are
    sharded on the sequence dimension while weights stay fully replicated. Typical
    targets are normalization and dropout layers used after row-wise / scatter
    projections in tensor-parallel transformers (`Reducing Activation Recomputation
    in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__).

    If the first positional input is a plain tensor, it is treated as the local
    shard along ``sequence_dim`` and wrapped as a :class:`DTensor`. If it is already
    a :class:`DTensor` but not sharded on that dimension, it is redistributed.

    Keyword Args:
        sequence_dim (int, optional):
            Tensor dimension index for the sequence axis (e.g. ``1`` for ``(B, S, H)``).
            Default: ``1``.
        use_local_output (bool, optional):
            If ``True``, return a local tensor via ``to_local()``; otherwise keep a
            :class:`DTensor`. Default: ``False`` (PyTorch default).

    Note:
        Like PyTorch, this assumes sensible defaults for norm weights (e.g. ones).
        Custom initializations should be broadcast so every rank agrees before or
        after parallelization.

    Example::

        >>> from hyper_parallel import parallelize_module, SequenceParallel, init_device_mesh
        >>> m = Model(...)
        >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
        >>> parallelize_module(m, tp_mesh, {"norm": SequenceParallel()})
    """

    def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False) -> None:
        super().__init__()
        self.sequence_sharding: Tuple[Placement, ...] = (Shard(sequence_dim),)
        self.use_local_output = use_local_output

    def __repr__(self) -> str:
        dim = self.sequence_sharding[0].dim
        return (
            f"{self.__class__.__name__}("
            f"sequence_dim={dim}, "
            f"use_local_output={self.use_local_output})"
        )

    @staticmethod
    def _prepare_input_fn(
        sequence_sharding: Tuple[Placement, ...],
        mod: Module,
        inputs: Any,
        device_mesh: DeviceMesh,
    ) -> Any:
        """Ensure the first input is a :class:`DTensor` sharded on the sequence dim."""
        input_tensor = inputs[0]
        if isinstance(input_tensor, DTensor):
            if tuple(input_tensor.placements) != tuple(sequence_sharding):
                input_tensor = input_tensor.redistribute(device_mesh, sequence_sharding)
            return input_tensor
        if platform.is_tensor(input_tensor):
            return DTensor.from_local(input_tensor, device_mesh, sequence_sharding)
        raise ValueError(
            f"expecting input of {mod} to be a tensor or DTensor, but got {type(input_tensor)}"
        )

    @staticmethod
    def _prepare_output_fn(use_local_output: bool, outputs: Any) -> Any:
        if use_local_output:
            return outputs.to_local()
        return outputs

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        """Apply sequence-parallel hooks and replicate parameters via ``distribute_module``.

        Args:
            module: Submodule to parallelize (for example ``LayerNorm`` or ``Dropout``).
            device_mesh: One-dimensional tensor-parallel device mesh.

        Returns:
            The same ``module`` instance with forward hooks attached and parameters
            converted to replicated DTensors where applicable.
        """

        def partition_fn(_submodule_path, _submodule, _mesh):
            return None

        def input_fn(forward_module, forward_inputs, mesh):
            return self._prepare_input_fn(
                self.sequence_sharding,
                forward_module,
                forward_inputs,
                mesh,
            )

        def output_fn(_forward_module, forward_outputs, _mesh):
            return self._prepare_output_fn(self.use_local_output, forward_outputs)

        return distribute_module(
            module,
            device_mesh,
            partition_fn,
            input_fn,
            output_fn,
        )


class PrepareModuleInput(ParallelStyle):
    """Prepare module forward *args* (and optional *kwargs*) as :class:`DTensor` layouts.

    At forward time, converts each annotated positional (or keyword) tensor from local
    to :class:`DTensor` using ``input_layouts``, then redistributes to
    ``desired_input_layouts`` when they differ. ``None`` in a layout tuple means
    “leave this input unchanged”.

    Mirrors ``torch.distributed.tensor.parallel.style.PrepareModuleInput``.

    Keyword Args:
        input_layouts: Placements per positional arg, or a single :class:`Placement`
            wrapped as a one-tuple. ``None`` entries skip conversion for that arg.
        desired_input_layouts: Target placements; must match ``input_layouts`` length.
        input_kwarg_layouts: Optional mapping kwarg name → placement for conversion.
        desired_input_kwarg_layouts: Target placements for those kwargs (same keys).
        use_local_output: If ``True``, convert prepared inputs back to local tensors
            before the module runs (PyTorch names this flag ``use_local_output`` on
            :class:`PrepareModuleInput`).
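
    Example (a minimal sketch; the submodule name ``attn`` and its two positional
    inputs are illustrative)::

        >>> from hyper_parallel import parallelize_module, PrepareModuleInput, init_device_mesh
        >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
        >>> parallelize_module(
        ...     m,
        ...     tp_mesh,
        ...     {"attn": PrepareModuleInput(
        ...         input_layouts=(Shard(0), None),
        ...         desired_input_layouts=(Replicate(), None),
        ...     )},
        ... )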
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Union[Placement, Tuple[Optional[Placement], ...]]] = None,
        desired_input_layouts: Optional[
            Union[Placement, Tuple[Optional[Placement], ...]]
        ] = None,
        input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        use_local_output: bool = False,
    ) -> None:
        super().__init__()
        self.input_layouts = (
            (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
        )
        self.desired_input_layouts = (
            (desired_input_layouts,)
            if isinstance(desired_input_layouts, Placement)
            else desired_input_layouts
        )
        self.use_local_output = use_local_output
        if self.input_layouts is not None:
            if self.desired_input_layouts is None:
                raise AssertionError("desired module inputs should not be None!")
            if len(self.input_layouts) != len(self.desired_input_layouts):
                raise AssertionError(
                    "input_layouts and desired_input_layouts should have same length!"
                )
        self.with_kwargs = input_kwarg_layouts is not None
        self.input_kwarg_layouts = input_kwarg_layouts or {}
        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
        if self.with_kwargs:
            if len(self.input_kwarg_layouts) != len(self.desired_input_kwarg_layouts):
                raise AssertionError(
                    "input_kwarg_layouts and desired_input_kwarg_layouts should have "
                    "same length!"
                )

    def _prepare_input_arg(
        self,
        input_obj: Any,
        mesh: DeviceMesh,
        input_layout: Optional[Placement],
        desired_layout: Optional[Placement],
    ) -> Any:
        """Convert one input to DTensor, redistribute if needed, optionally to_local."""
        if input_layout is not None:
            if isinstance(input_obj, DTensor):
                dt_inp = input_obj
            else:
                if not platform.is_tensor(input_obj):
                    raise AssertionError("expecting input to be a framework tensor!")
                dt_inp = DTensor.from_local(input_obj, mesh, (input_layout,))

            if desired_layout is not None and input_layout != desired_layout:
                dt_inp = dt_inp.redistribute(mesh, (desired_layout,))

            return dt_inp.to_local() if self.use_local_output else dt_inp
        return input_obj

    def _prepare_input_fn(self, inputs: Any, device_mesh: DeviceMesh) -> Any:
        """Prepare positional ``inputs`` tuple per ``input_layouts`` / ``desired_input_layouts``."""
        if self.input_layouts is None:
            return inputs
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        if len(inputs) != len(self.input_layouts):
            raise ValueError("module inputs and input_layouts should have same length!")
        if self.desired_input_layouts is None:
            raise AssertionError("desired module inputs should not be None!")
        prepared_inputs = [
            self._prepare_input_arg(inp, device_mesh, il, dl)
            for inp, il, dl in zip(inputs, self.input_layouts, self.desired_input_layouts)
        ]
        return tuple(prepared_inputs)

    def _prepare_input_kwarg_fn(
        self,
        inputs: Any,
        kwarg_inputs: Dict[str, Any],
        device_mesh: DeviceMesh,
    ) -> Tuple[Any, Dict[str, Any]]:
        """Prepare positional and keyword tensor inputs; returns ``(args, kwargs)`` for the hook."""
        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
        prepared_kwarg_inputs: Dict[str, Any] = {}
        for kwarg_key in kwarg_inputs:
            kwarg_val = kwarg_inputs[kwarg_key]
            input_layout = self.input_kwarg_layouts.get(kwarg_key)
            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)
            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
                kwarg_val, device_mesh, input_layout, desired_input_layout
            )
        return (prepared_arg_inputs, prepared_kwarg_inputs)

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        if self.with_kwargs:

            def _pre_hook(_mod, inputs, kwargs):
                return self._prepare_input_kwarg_fn(inputs, kwargs, device_mesh)

            platform.register_forward_pre_hook(
                module, _pre_hook, prepend=False, with_kwargs=True,
            )
        else:

            def _pre_hook(_mod, inputs):
                return self._prepare_input_fn(inputs, device_mesh)

            platform.register_forward_pre_hook(module, _pre_hook, prepend=False)
        return module

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"input_layouts={self.input_layouts}, "
            f"desired_input_layouts={self.desired_input_layouts}, "
            f"input_kwarg_layouts={self.input_kwarg_layouts}, "
            f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, "
            f"use_local_output={self.use_local_output})"
        )


class PrepareModuleOutput(ParallelStyle):
    """Prepare module forward outputs as :class:`DTensor` and redistribute layouts.

    Registers a forward hook that treats each return value like
    ``torch.distributed.tensor.parallel.style.PrepareModuleOutput``: optional
    ``None`` slots in ``output_layouts`` pass that output through unchanged.

    Keyword Args:
        output_layouts: Current or assumed placement per output tensor.
        desired_output_layouts: Target placements; length must match ``output_layouts``.
        use_local_output: If ``True`` (default), return local shards after redistribution.
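
    Example (a minimal sketch; the submodule name ``block`` is illustrative)::

        >>> from hyper_parallel import parallelize_module, PrepareModuleOutput, init_device_mesh
        >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
        >>> parallelize_module(
        ...     m,
        ...     tp_mesh,
        ...     {"block": PrepareModuleOutput(
        ...         output_layouts=Replicate(),
        ...         desired_output_layouts=Shard(1),
        ...     )},
        ... )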
    """

    def __init__(
        self,
        *,
        output_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        desired_output_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        self.output_layouts = (
            (output_layouts,) if isinstance(output_layouts, Placement) else output_layouts
        )
        self.desired_output_layouts = (
            (desired_output_layouts,)
            if isinstance(desired_output_layouts, Placement)
            else desired_output_layouts
        )
        self.use_local_output = use_local_output
        if len(self.output_layouts) != len(self.desired_output_layouts):
            raise AssertionError(
                "output_layouts and desired_output_layouts should have same length!"
            )

    def _prepare_out_fn(self, outputs: Any, device_mesh: DeviceMesh) -> Any:
        """Redistribute each output tensor per ``output_layouts`` / ``desired_output_layouts``."""
        prepared_outputs: list = []
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        if len(outputs) != len(self.output_layouts):
            raise ValueError("module outputs and output_layouts should have same length!")
        for out, out_layout, desired_out_layout in zip(
            outputs, self.output_layouts, self.desired_output_layouts,
        ):
            if out_layout is not None:
                if isinstance(out, DTensor):
                    dt_out = out
                else:
                    dt_out = DTensor.from_local(out, device_mesh, (out_layout,))
                if out_layout != desired_out_layout:
                    dt_out = dt_out.redistribute(device_mesh, (desired_out_layout,))
                prepared_outputs.append(
                    dt_out.to_local() if self.use_local_output else dt_out
                )
            else:
                prepared_outputs.append(out)
        if len(prepared_outputs) == 1:
            return prepared_outputs[0]
        return tuple(prepared_outputs)

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:

        def _hook(_mod, _inputs, outputs):
            return self._prepare_out_fn(outputs, device_mesh)

        module.register_forward_hook(_hook)
        return module

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"output_layouts={self.output_layouts}, "
            f"desired_output_layouts={self.desired_output_layouts}, "
            f"use_local_output={self.use_local_output})"
        )


class PrepareModuleInputOutput(ParallelStyle):
    """Combine :class:`PrepareModuleInput` and :class:`PrepareModuleOutput` on one module.

    Same keyword arguments as the two styles, with ``use_local_input`` mapping to
    ``PrepareModuleInput(..., use_local_output=use_local_input)`` for PyTorch parity.
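
    Example (a minimal sketch; the submodule name ``block`` is illustrative)::

        >>> from hyper_parallel import parallelize_module, PrepareModuleInputOutput, init_device_mesh
        >>> tp_mesh = init_device_mesh("npu", (8,), mesh_dim_names=("tp",))
        >>> parallelize_module(
        ...     m,
        ...     tp_mesh,
        ...     {"block": PrepareModuleInputOutput(
        ...         input_layouts=Shard(0),
        ...         desired_input_layouts=Replicate(),
        ...         output_layouts=Replicate(),
        ...         desired_output_layouts=Shard(0),
        ...     )},
        ... )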
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Union[Placement, Tuple[Optional[Placement], ...]]] = None,
        desired_input_layouts: Optional[
            Union[Placement, Tuple[Optional[Placement], ...]]
        ] = None,
        input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
        use_local_input: bool = False,
        output_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        desired_output_layouts: Union[Placement, Tuple[Optional[Placement], ...]],
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        self.prepare_module_input = PrepareModuleInput(
            input_layouts=input_layouts,
            desired_input_layouts=desired_input_layouts,
            input_kwarg_layouts=input_kwarg_layouts,
            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
            use_local_output=use_local_input,
        )
        self.prepare_module_output = PrepareModuleOutput(
            output_layouts=output_layouts,
            desired_output_layouts=desired_output_layouts,
            use_local_output=use_local_output,
        )

    def apply(self, module: Module, device_mesh: DeviceMesh) -> Module:
        self.prepare_module_input.apply(module, device_mesh)
        self.prepare_module_output.apply(module, device_mesh)
        return module

    def __repr__(self) -> str:
        p_in = self.prepare_module_input
        p_out = self.prepare_module_output
        return (
            f"{self.__class__.__name__}("
            f"input_layouts={p_in.input_layouts}, "
            f"desired_input_layouts={p_in.desired_input_layouts}, "
            f"input_kwarg_layouts={p_in.input_kwarg_layouts}, "
            f"desired_input_kwarg_layouts={p_in.desired_input_kwarg_layouts}, "
            f"use_local_input={p_in.use_local_output}, "
            f"output_layouts={p_out.output_layouts}, "
            f"desired_output_layouts={p_out.desired_output_layouts}, "
            f"use_local_output={p_out.use_local_output})"
        )