Coverage for hyper_parallel / core / shard / ops / parallel_reduce.py: 85%
155 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Reduce operator.
17"""
19from copy import deepcopy
20from typing import Sequence, Union, Tuple, List
21from hyper_parallel.core.layout import Layout
22from hyper_parallel.platform import get_platform
23from .parallel_ops import DistributedOp
24platform = get_platform()
25Tensor = platform.Tensor
28StrOrTuple = Union[str, Tuple["StrOrTuple", ...], List["StrOrTuple"]]
class ReduceExtDistributedOpBase(DistributedOp):
    """
    Base class for distributed reduce operators (sum / mean / max / prod / all).

    Subclasses select the collective combination semantics for sharded axes
    via ``partial_type``.

    Args:
        op_name (str): Name of the operator to register.
        partial_type (list): Reduction kinds marked as "partial" on the output
            layout (e.g. ``["sum"]``, ``["max"]``). Defaults to ``["sum"]``.
    """

    def __init__(self, op_name, partial_type=None):
        super().__init__(op_name)
        # Default is applied here rather than in the signature to avoid a
        # shared mutable default argument.
        if partial_type is None:
            partial_type = ["sum"]
        self.partial_type = partial_type

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for a reduce operator.

        Args:
            layouts (tuple): Layouts of the input tensors; only the first is used.
            extra_args (sequence): Optional (dim, keepdim). Empty means a full
                reduction with keepdim=False; a single element is keepdim only.

        Returns:
            Layout: Layout for the output tensor.

        Raises:
            ValueError: If no layout is given or the layout has no mesh shape.
            TypeError: If `dim` is passed as a Tensor.
        """
        if not layouts:
            raise ValueError(f"{self.__class__.__name__} requires at least one input layout")

        x_layout = layouts[0]

        if x_layout.mesh_shape is None:
            raise ValueError("Input layouts cannot be None.")

        # extra_args packs [dim, keepdim]; a lone extra arg is keepdim only.
        if not extra_args:
            dim = None
            keepdim = False
        elif len(extra_args) == 1:
            dim = None
            keepdim = extra_args[0]
        else:
            dim, keepdim = extra_args

        if isinstance(dim, Tensor):
            raise TypeError(
                "The `dim` argument should not be a `Tensor`. Instead, use one of the following types: "
                "`None`, `int`, `tuple[int]`, or `list[int]`."
            )

        # Infer the output layout based on dim and keepdim.
        return self._infer_output_layout(x_layout, dim, keepdim)

    def _infer_output_layout(self, x_layout, dim, keepdim):
        """Infer output layout for the reduce operator."""
        # Case 1: dim is None -> reduce over all dimensions.
        # NOTE(review): an empty tuple/list dim falls through to the per-dim
        # path below and reduces nothing — confirm this matches the intended
        # semantics for dim=().
        if dim is None:
            return self._handle_all_axis_reduce(x_layout, keepdim)

        # Case 2: dim is an int, tuple, or list, with keepdim True or False.
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        # Drop (or mask with "None") the reduced axes in the tensor map.
        reduce_alias, x_map = self.replace_axis_with_none(dim, x_layout, keepdim)
        output_layout = output_layout(*x_map)
        # Device axes that carried reduced tensor dims become "partial":
        # their shards must still be combined with the operator's reduction.
        self._apply_partial(output_layout, reduce_alias)
        return output_layout

    def _handle_all_axis_reduce(self, x_layout, keepdim):
        """Handle dim=None: reduce over all dimensions of the input."""
        layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )

        if not keepdim:
            # Output is a scalar: empty tensor map.
            output_layout = layout()
        else:
            # Every dimension is kept with size 1 and unsharded.
            tensor_map = tuple(["None"] * len(x_layout.alias_tensor_map))
            output_layout = layout(*tensor_map)

        # Every device axis that sharded the input now holds a partial result.
        self._apply_partial(output_layout, x_layout.alias_tensor_map)
        return output_layout

    def replace_axis_with_none(self, dim, x_layout, keepdim):
        """Replace (keepdim) or drop (no keepdim) the reduced dimensions.

        Returns:
            tuple: (reduce_alias, new_tensor_map) where reduce_alias is the
            flat list of device-axis aliases that were reduced away.
        """
        if not isinstance(dim, (tuple, list)):
            dim = [dim]
        else:
            dim = list(dim)

        rank = len(x_layout.alias_tensor_map)
        for i, axis_id in enumerate(dim):
            # Normalize negative indices, then validate type and bounds.
            if axis_id < 0:
                dim[i] = rank + axis_id
            if not isinstance(axis_id, int) or dim[i] >= rank or dim[i] < 0:
                raise ValueError(f"Invalid reduce axis index {axis_id} at position {i}.")

        alias_tensor_map = x_layout.alias_tensor_map
        # Collect the device-axis aliases on the reduced dims; both Python
        # None and the literal "None" denote an unsharded dim.
        reduce_alias = [alias_tensor_map[axis_id] for axis_id in dim if
                        alias_tensor_map[axis_id] is not None and alias_tensor_map[axis_id] != "None"]
        reduce_alias = self._flatten_aliases(reduce_alias)

        if keepdim:
            return self._replace_keepdim(alias_tensor_map, reduce_alias)
        return self._replace_dropdim(alias_tensor_map, reduce_alias, dim)

    def _flatten_aliases(self, reduce_alias):
        """Flatten reduce_alias into a flat list of atomic alias strings."""
        flat = []
        for alias in reduce_alias:
            if isinstance(alias, (tuple, list)):
                flat.extend(alias)
            else:
                flat.append(alias)
        return flat

    def _replace_keepdim(self, alias_tensor_map, reduce_alias):
        """keepdim=True: replace each reduced alias with 'None' in place."""
        new_alias_map = []
        for alias in alias_tensor_map:
            if isinstance(alias, (tuple, list)):
                # Composite alias: mask only the reduced members.
                new_alias = tuple("None" if item in reduce_alias else item for item in alias)
                new_alias_map.append(new_alias)
            else:
                if alias in reduce_alias:
                    new_alias_map.append("None")
                else:
                    new_alias_map.append(alias)
        new_alias_map = self._compact_tensor_map(new_alias_map)
        return reduce_alias, tuple(new_alias_map)

    def _replace_dropdim(self, alias_tensor_map, reduce_alias, dim):
        """keepdim=False: remove the reduced dims from the tensor map."""
        new_alias_map = []
        for i, alias in enumerate(alias_tensor_map):
            if i in dim:
                continue
            if isinstance(alias, (tuple, list)):
                # Composite alias: keep only the non-reduced members.
                new_alias = tuple(item for item in alias if item not in reduce_alias)
                if new_alias:
                    new_alias_map.append(new_alias)
            else:
                if alias in reduce_alias:
                    continue
                new_alias_map.append(alias)
        new_alias_map = self._compact_tensor_map(new_alias_map)
        return reduce_alias, tuple(new_alias_map)

    def _compact_tensor_map(self, alias_map: Sequence[StrOrTuple]) -> Tuple[StrOrTuple, ...]:
        """Simplify the tensor map: unwrap singleton tuples and collapse
        all-'None' tuples to a single 'None'."""

        def _compress(elem: StrOrTuple) -> StrOrTuple:
            if isinstance(elem, (list, tuple)):
                compressed = tuple(_compress(e) for e in elem)
                if len(compressed) == 1:
                    return compressed[0]
                if all(x == 'None' for x in compressed):
                    return 'None'
                return compressed
            return elem

        return tuple(_compress(elem) for elem in alias_map)

    def _apply_partial(self, out_layout, alias):
        """Mark every device axis in `alias` (string, tuple, or list) as
        partial for each of the operator's reduction kinds."""
        # Both Python None and the literal "None" denote an unsharded axis;
        # neither contributes a partial reduction. (replace_axis_with_none
        # filters both forms, so this guard keeps the two paths consistent.)
        if alias is None or alias == "None":
            return
        if isinstance(alias, (tuple, list)):
            for elem in alias:
                self._apply_partial(out_layout, elem)
        else:
            for ops in self.partial_type:
                out_layout.set_partial_by_dev_axis(alias, ops)
class SumExtDistributedOp(ReduceExtDistributedOpBase):
    """Distributed implementation for the SumExt operator."""

    def __init__(self, op_name="SumExt"):
        # Sharded partial sums combine with an all-reduce sum.
        super().__init__(op_name, ["sum"])
class MeanExtDistributedOp(ReduceExtDistributedOpBase):
    """Distributed implementation for the MeanExt operator."""

    def __init__(self, op_name="MeanExt"):
        # Sharded partial means combine with an averaging all-reduce.
        super().__init__(op_name, ["avg"])
class ReduceMaxDistributedOp(ReduceExtDistributedOpBase):
    """Distributed implementation for the ReduceMax operator."""

    def __init__(self, op_name="ReduceMax"):
        # Sharded partial maxima combine with a max all-reduce.
        super().__init__(op_name, ["max"])
class ProdExtDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed ProdExt operator: product of all elements, or along a dim.
    Argument handling is compatible with torch.prod.
    """

    def __init__(self, op_name="prod"):
        # Sharded partial products combine multiplicatively.
        super().__init__(op_name, ["prod"])
class AllExtDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation for the All operator.

    Reduces with logical AND: the result is True only if every element of the
    input (or every element along `dim`) evaluates to True. Sharded partial
    results combine with an "all" reduction across device axes.
    """

    def __init__(self, op_name="all"):
        super().__init__(op_name, partial_type=["all"])
class MaxDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation of the PyTorch-style max operator.

    Covers the three torch.max call forms:
    1. torch.max(input) -> global reduction (single Tensor)
    2. torch.max(input, dim, keepdim=False) -> reduction over a dim
       (returns a (values, indices) pair)
    3. torch.max(input, other) -> element-wise max (single Tensor)
    """

    def __init__(self, op_name="max"):
        super().__init__(op_name, partial_type=["max"])

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout(s) for torch.max.

        Returns a single Layout for the global and element-wise forms, and a
        (values_layout, indices_layout) pair for the dim-reduction form.
        """
        # Non-tensor positional args (dim, keepdim) arrive as None layouts;
        # keep only the layouts that belong to tensor inputs.
        tensor_layouts = [layout for layout in layouts if layout is not None]

        if not tensor_layouts:
            raise ValueError("MaxDistributedOp requires at least one input layout")

        # Element-wise form torch.max(a, b): a single tensor comes out, so a
        # single Layout object (the first operand's) is returned.
        if len(tensor_layouts) > 1:
            return tensor_layouts[0]

        # Reduction forms (global or over a dim).
        x_layout = tensor_layouts[0]
        if x_layout.mesh_shape is None:
            raise ValueError("Input layouts cannot be None.")

        dim, keepdim = None, False
        if extra_args:
            dim = extra_args[0]
            if len(extra_args) > 1:
                keepdim = extra_args[1]

        if isinstance(dim, Tensor):
            raise TypeError(
                "The `dim` argument should not be a `Tensor`. Instead, use one of the following types: "
                "`None`, `int`, `tuple[int]`, or `list[int]`."
            )

        values_layout = self._infer_output_layout(x_layout, dim, keepdim)

        # torch.max(input) returns one tensor; the dispatcher wraps a single
        # (non-tuple) layout with DTensor.from_local, so a Layout object —
        # not a tuple — must be returned here.
        if dim is None:
            return values_layout

        # torch.max(input, dim) returns (values, indices); the dispatcher
        # expects one layout per output, and indices share the values' layout.
        return values_layout, deepcopy(values_layout)