Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/custom_shard.py: 18%
50 statements
# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""shard"""
import queue
from typing import Callable, Tuple, Optional
from hyper_parallel.core.dtensor.layout import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.placement_types import Placement
from hyper_parallel.platform import get_platform
platform = get_platform()
Tensor = platform.Tensor


def custom_shard(
    func: Callable,
    device_mesh: DeviceMesh,
    out_placements: Tuple[Tuple[Placement, ...], ...],
    in_placements: Optional[Tuple[Optional[Tuple[Placement, ...]], ...]] = None,
    redistribute_inputs: bool = True,
) -> Callable:
34 """
35 Wraps a function to handle distributed tensor conversions.
37 Args:
38 func (Callable): The function to be wrapped.
39 device_mesh (DeviceMesh): The device mesh for sharding.
40 out_placements (Tuple[Tuple[Placement, ...], ...]): Placements for each output tensor.
41 in_placements (Optional[Tuple[Optional[Tuple[Placement, ...]], ...]], optional):
42 Placements for each input argument. None entries indicate non-tensor inputs.
43 redistribute_inputs (bool): Whether to redistribute inputs to required placements.
45 Returns:
46 Callable: Wrapped function that handles distributed tensors.
48 Examples:
49 >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
50 >>> @custom_shard(
51 ... device_mesh=mesh,
52 ... out_placements=((Shard(0), Replicate()),),
53 ... in_placements=((Shard(0), Replicate()), (Replicate(), Shard(1)))
54 ... )
55 ... def my_func(x, y):
56 ... return x + y
57 """
    def wrapped(*args, **kwargs):
        if in_placements is not None:
            if len(in_placements) != len(args):
                raise ValueError(
                    f"in_placements length {len(in_placements)} does not match "
                    f"the number of input args {len(args)}!"
                )

        local_args = []
        contain_distributed_arg = False

        # Unwrap DTensor inputs: optionally redistribute them to the required
        # placements, record their layouts, and hand the local shards to func.
        args_layout = queue.Queue(len(args))
        for i, arg in enumerate(args):
            if isinstance(arg, DTensor):
                if in_placements is None:
                    raise RuntimeError("Found DTensor input but in_placements is None")

                required_in_placement = in_placements[i]
                if required_in_placement is None:
                    raise TypeError(
                        f"DTensor input at position {i} requires a Placement, "
                        "but the corresponding in_placements entry is None!"
                    )

                if redistribute_inputs:
                    arg = arg.redistribute(device_mesh, required_in_placement)

                args_layout.put(arg.layout)
                local_tensor = arg.to_local()
                local_args.append(local_tensor)
                contain_distributed_arg = True

            else:
                if in_placements is not None and in_placements[i] is not None:
                    raise TypeError(
                        f"Non-DTensor input at position {i} requires None in_placements, "
                        f"but received {in_placements[i]}!"
                    )
                local_args.append(arg)

        out = func(*local_args, **kwargs)

        if not contain_distributed_arg:
            return out

        out_is_tuple = isinstance(out, tuple)
        out_tuple = (out,) if not out_is_tuple else out

        if len(out_tuple) != len(out_placements):
            raise ValueError(
                f"Output count {len(out_tuple)} does not match "
                f"out_placements count {len(out_placements)}!"
            )
        # Re-wrap tensor outputs as DTensors with the declared out_placements;
        # non-tensor outputs pass through unchanged.
        dist_output = []
        for item, out_placement in zip(out_tuple, out_placements):
            if isinstance(item, Tensor):
                if out_placement is None:
                    raise TypeError(
                        "Tensor output requires non-None out_placements!"
                    )
                dist_output.append(
                    DTensor.from_local(item, device_mesh=device_mesh, placements=out_placement)
                )
            else:
                if out_placement is not None:
                    raise TypeError(
                        f"Non-tensor output requires None out_placements, got {out_placement}!"
                    )
                dist_output.append(item)

        return dist_output[0] if not out_is_tuple else tuple(dist_output)

    return wrapped
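
# Usage sketch (illustrative only, not part of the original module): it shows
# how ``None`` entries in ``in_placements`` / ``out_placements`` mark
# non-tensor values. ``Shard`` and ``Replicate`` are assumed to be importable
# from ``hyper_parallel.core.dtensor.placement_types`` alongside ``Placement``,
# and ``x`` is assumed to be a DTensor already placed on ``mesh``.
#
#     from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
#
#     mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
#
#     def scale(t, factor):
#         return t * factor, factor          # (tensor, non-tensor) outputs
#
#     sharded_scale = custom_shard(
#         scale,
#         device_mesh=mesh,
#         out_placements=((Shard(0), Replicate()), None),  # None for the scalar output
#         in_placements=((Shard(0), Replicate()), None),   # None for the `factor` input
#     )
#
#     # x is redistributed to (Shard(0), Replicate()), scale() runs on the
#     # local shard, the tensor output is re-wrapped via DTensor.from_local,
#     # and the scalar passes through unchanged.
#     y, factor = sharded_scale(x, 2.0)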