Coverage for hyper_parallel / core / shard / local_func.py: 83%
48 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""shard"""
import functools
import queue
from typing import Callable, Optional, Tuple

from hyper_parallel.core.dtensor import DTensor
from hyper_parallel.core.layout import DeviceMesh
from hyper_parallel.core.placement_types import Placement
from hyper_parallel.platform import get_platform
23platform = get_platform()
24Tensor = platform.Tensor
26def custom_shard(
27 func: Callable,
28 device_mesh: DeviceMesh,
29 out_placements: Tuple[Tuple[Placement, ...], ...],
30 in_placements: Optional[Tuple[Optional[Tuple[Placement, ...]], ...]] = None,
31 redistribute_inputs: bool = True,
32) -> Callable:
33 """
34 Wraps a function to handle distributed tensor conversions.
36 Args:
37 func (Callable): The function to be wrapped.
38 device_mesh (DeviceMesh): The device mesh for sharding.
39 out_placements (Tuple[Tuple[Placement, ...], ...]): Placements for each output tensor.
40 in_placements (Optional[Tuple[Optional[Tuple[Placement, ...]], ...]], optional):
41 Placements for each input argument. None entries indicate non-tensor inputs.
42 redistribute_inputs (bool): Whether to redistribute inputs to required placements.
44 Returns:
45 Callable: Wrapped function that handles distributed tensors.
47 Examples:
48 >>> mesh = DeviceMesh("npu", (2, 2), nesh_dim_names=("dp", "tp"))
49 >>> @custom_shard(
50 ... device_mesh=mesh,
51 ... out_placements=((Shard(0), Replicate()),),
52 ... in_placements=((Shard(0), Replicate()), (Replicate(), Shard(1)))
53 ... )
54 ... def my_func(x, y):
55 ... return x + y
56 """
57 def wrapped(*args, **kwargs):
58 if in_placements is not None:
59 assert len(in_placements) == len(args), (
60 f"in_placements length {len(in_placements)} does not match "
61 f"the number of input args {len(args)}!"
62 )
64 local_args = []
65 contain_distributed_arg = False
67 args_layout = queue.Queue(len(args))
68 for i, arg in enumerate(args):
69 if isinstance(arg, DTensor):
70 if in_placements is None:
71 raise RuntimeError("Found Tensor input but in_placements is None")
73 required_in_placement = in_placements[i]
74 if required_in_placement is None:
75 raise TypeError(
76 f"Tensor input at position {i} requires Placement, "
77 "but corresponding in_placements entry is None!"
78 )
80 if redistribute_inputs:
81 arg = arg.redistribute(device_mesh, required_in_placement)
83 args_layout.put(arg.layout)
84 local_tensor = arg.to_local()
85 local_args.append(local_tensor)
86 contain_distributed_arg = True
88 else:
89 if in_placements is not None and in_placements[i] is not None:
90 raise TypeError(
91 f"Non-DTensor input at position {i} requires None in_placements, "
92 f"but received {in_placements[i]}!"
93 )
94 local_args.append(arg)
96 out = func(*local_args, **kwargs)
98 if not contain_distributed_arg:
99 return out
101 out_is_tuple = isinstance(out, tuple)
102 out_tuple = (out,) if not out_is_tuple else out
104 assert len(out_tuple) == len(out_placements), (
105 f"Output count {len(out_tuple)} does not match "
106 f"out_placements count {len(out_placements)}!"
107 )
109 dist_output = []
110 for item, out_placement in zip(out_tuple, out_placements):
111 if isinstance(item, Tensor):
112 if out_placement is None:
113 raise TypeError(
114 "Tensor output requires non-None out_placements!"
115 )
116 dist_output.append(
117 DTensor.from_local(item, device_mesh=device_mesh, placements=out_placement)
118 )
119 else:
120 if out_placement is not None:
121 raise TypeError(
122 f"Non-tensor output requires None out_placements, got {out_placement}!"
123 )
124 dist_output.append(item)
126 return dist_output[0] if not out_is_tuple else tuple(dist_output)
128 return wrapped