Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_reduce.py: 86%
159 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for Reduce operator.
17"""
from copy import deepcopy
from typing import Sequence, Union, Tuple, List

from hyper_parallel.core.dtensor.layout import Layout
from hyper_parallel.platform import get_platform
from .parallel_ops import DistributedOp

platform = get_platform()
Tensor = platform.Tensor

StrOrTuple = Union[str, Tuple["StrOrTuple", ...], List["StrOrTuple"]]

class ReduceExtDistributedOpBase(DistributedOp):
    """
    Base class for distributed reduce operators.

    Args:
        op_name (str): Name of the operator to register.
        partial_type (list): Reduction types used to mark the reduced device axes
            as partial (e.g. ``["sum"]``). Defaults to ``["sum"]``.
    """

    def __init__(self, op_name, partial_type=None):
        super().__init__(op_name)
        if partial_type is None:
            partial_type = ["sum"]
        self.partial_type = partial_type

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for a reduce operator.

        Args:
            layouts (tuple): Layouts of the input tensors.
            extra_args (tuple): Additional positional arguments ``(dim, keepdim)``.

        Returns:
            Layout: Layout for the output tensor.
        """
        if not layouts:
            raise ValueError(f"{self.__class__.__name__} requires at least one input layout")

        x_layout = layouts[0]

        if x_layout.mesh_shape is None:
            raise ValueError("The input layout's mesh_shape cannot be None.")

        # Unpack extra_args as [dim, keepdim].
        if not extra_args:
            dim = None
            keepdim = False
        elif len(extra_args) == 1:
            dim = None
            keepdim = extra_args[0]
        else:
            dim, keepdim = extra_args

        if isinstance(dim, Tensor):
            raise TypeError(
                "The `dim` argument should not be a `Tensor`. Instead, use one of the following types: "
                "`None`, `int`, `tuple[int]`, or `list[int]`."
            )

        # Infer the output shape based on dim and keepdim.
        output_layout = self._infer_output_layout(x_layout, dim, keepdim)

        return output_layout
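
    # Illustrative summary of the ``extra_args`` convention handled by
    # infer_layout above (inferred from the parsing logic in this method, not
    # from any dispatcher documentation):
    #   extra_args=None or ()        -> dim=None, keepdim=False (reduce all axes)
    #   extra_args=(True,)           -> dim=None, keepdim=True
    #   extra_args=((0, 1), False)   -> dim=(0, 1), keepdim=False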

    def _infer_output_layout(self, x_layout, dim, keepdim):
        """Infer the output layout for a reduce operator."""
        # Case 1: dim is None, meaning reduce all dimensions.
        if dim is None:
            return self._handle_all_axis_reduce(x_layout, keepdim)

        # Case 2: dim is an int, tuple, or list, with keepdim True or False.
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        reduce_alias, x_map = self.replace_axis_with_none(dim, x_layout, keepdim)
        output_layout = output_layout(*x_map)
        self._apply_partial(output_layout, reduce_alias)
        return output_layout

    def _handle_all_axis_reduce(self, x_layout, keepdim):
        """Handle the case where dim is None, meaning reduce all dimensions."""
        layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )

        if not keepdim:
            output_layout = layout()
        else:
            tensor_map = tuple(["None"] * len(x_layout.alias_tensor_map))
            output_layout = layout(*tensor_map)

        self._apply_partial(output_layout, x_layout.alias_tensor_map)
        return output_layout
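
    # Sketch of the all-axis case above (hypothetical aliases, for illustration
    # only): a tensor map ("dp", "mp") reduced with keepdim=False produces the
    # scalar layout ``layout()``, while keepdim=True produces
    # ``layout("None", "None")``; in both cases every aliased device axis is
    # marked partial via _apply_partial.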

    def replace_axis_with_none(self, dim, x_layout, keepdim):
        """Replace or drop the reduced dimensions depending on keepdim."""
        if not isinstance(dim, (tuple, list)):
            dim = [dim]
        else:
            dim = list(dim)

        rank = len(x_layout.alias_tensor_map)
        for i, axis_id in enumerate(dim):
            if not isinstance(axis_id, int):
                raise ValueError(f"Invalid reduce axis index {axis_id} at position {i}.")
            if axis_id < 0:
                dim[i] = rank + axis_id
            if dim[i] >= rank or dim[i] < 0:
                raise ValueError(f"Invalid reduce axis index {axis_id} at position {i}.")

        alias_tensor_map = x_layout.alias_tensor_map
        reduce_alias = [alias_tensor_map[axis_id] for axis_id in dim if
                        alias_tensor_map[axis_id] is not None and alias_tensor_map[axis_id] != "None"]
        reduce_alias = self._flatten_aliases(reduce_alias)

        if keepdim:
            return self._replace_keepdim(alias_tensor_map, reduce_alias)
        return self._replace_dropdim(alias_tensor_map, reduce_alias, dim)
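
    # Illustrative walk-through of replace_axis_with_none (hypothetical
    # aliases): with alias_tensor_map = ("dp", "mp", "None") and dim=-2, the
    # axis is first normalised to 1 and reduce_alias becomes ["mp"];
    # keepdim=True then yields the map ("dp", "None", "None"), while
    # keepdim=False yields ("dp", "None").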

    def _flatten_aliases(self, reduce_alias):
        """Flatten reduce_alias into a list of atomic alias strings."""
        flat = []
        for alias in reduce_alias:
            if isinstance(alias, (tuple, list)):
                flat.extend(alias)
            else:
                flat.append(alias)
        return flat
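
    # Example of the flattening above (hypothetical aliases):
    #   ["dp", ("mp", "cp")] -> ["dp", "mp", "cp"]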

    def _replace_keepdim(self, alias_tensor_map, reduce_alias):
        """With keepdim=True, replace reduced aliases with 'None'."""
        new_alias_map = []
        for alias in alias_tensor_map:
            if isinstance(alias, (tuple, list)):
                new_alias = tuple("None" if item in reduce_alias else item for item in alias)
                new_alias_map.append(new_alias)
            else:
                if alias in reduce_alias:
                    new_alias_map.append("None")
                else:
                    new_alias_map.append(alias)
        new_alias_map = self._compact_tensor_map(new_alias_map)
        return reduce_alias, tuple(new_alias_map)
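
    # Example of the keepdim branch (hypothetical aliases): with
    # reduce_alias=["mp"], the map (("dp", "mp"), "cp") becomes
    # (("dp", "None"), "cp").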

    def _replace_dropdim(self, alias_tensor_map, reduce_alias, dim):
        """With keepdim=False, drop the reduced dimensions from the tensor map."""
        new_alias_map = []
        for i, alias in enumerate(alias_tensor_map):
            if i in dim:
                continue
            if isinstance(alias, (tuple, list)):
                new_alias = tuple(item for item in alias if item not in reduce_alias)
                if new_alias:
                    new_alias_map.append(new_alias)
            else:
                if alias in reduce_alias:
                    continue
                new_alias_map.append(alias)
        new_alias_map = self._compact_tensor_map(new_alias_map)
        return reduce_alias, tuple(new_alias_map)
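
    # Example of the dropdim branch (hypothetical aliases): with
    # reduce_alias=["mp", "cp"] and dim=[1], the map ("dp", ("mp", "cp"))
    # becomes ("dp",) because the reduced dimension is removed entirely.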

    def _compact_tensor_map(self, alias_map: Sequence[StrOrTuple]) -> Tuple[StrOrTuple, ...]:
        """Compact nested tensor-map entries: unwrap singleton tuples and fold all-'None' tuples into 'None'."""

        def _compress(elem: StrOrTuple) -> StrOrTuple:
            if isinstance(elem, (list, tuple)):
                compressed = tuple(_compress(e) for e in elem)
                if len(compressed) == 1:
                    return compressed[0]
                if all(x == 'None' for x in compressed):
                    return 'None'
                return compressed
            return elem

        return tuple(_compress(elem) for elem in alias_map)
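
    # Examples of the compaction above (illustrative inputs only):
    #   ("dp", ("mp",))          -> ("dp", "mp")    # singleton tuple unwrapped
    #   ("dp", ("None", "None")) -> ("dp", "None")  # all-'None' tuple folded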

    def _apply_partial(self, out_layout, alias):
        """Recursively mark the given alias (string, tuple, or list) as partial on the output layout."""
        if alias == "None":
            return
        if isinstance(alias, (tuple, list)):
            for elem in alias:
                self._apply_partial(out_layout, elem)
        else:
            for ops in self.partial_type:
                out_layout.set_partial_by_dev_axis(alias, ops)
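
    # Example of _apply_partial (hypothetical aliases): with
    # partial_type=["sum"], an alias of ("dp", ("mp", "None")) marks "dp" and
    # "mp" as sum-partial on the output layout and skips the "None" entries.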


class SumExtDistributedOp(ReduceExtDistributedOpBase):
    """Distributed implementation for the SumExt operator."""

    def __init__(self, op_name="SumExt"):
        super().__init__(op_name, partial_type=["sum"])


class MeanExtDistributedOp(ReduceExtDistributedOpBase):
    """Distributed implementation for the MeanExt operator."""

    def __init__(self, op_name="MeanExt"):
        super().__init__(op_name, partial_type=["avg"])


class ReduceMaxDistributedOp(ReduceExtDistributedOpBase):
    """Distributed implementation for the ReduceMax operator."""

    def __init__(self, op_name="ReduceMax"):
        super().__init__(op_name, partial_type=["max"])


class ProdExtDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation for the ProdExt operator (product of all elements or along a dim).
    Compatible with torch.prod arguments.
    """

    def __init__(self, op_name="prod"):
        super().__init__(op_name, partial_type=["prod"])


class AllExtDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation for the All operator.
    Returns True if all elements of the input (or all elements along the given dim) evaluate to True.
    """

    def __init__(self, op_name="all"):
        super().__init__(op_name, partial_type=["all"])


class MaxDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation for the PyTorch-style Max operator.

    Supports three PyTorch behaviors:
    1. torch.max(input) -> global reduction (returns a single Tensor)
    2. torch.max(input, dim, keepdim=False) -> dimension reduction (returns (values, indices))
    3. torch.max(input, other) -> element-wise max (returns a single Tensor)
    """

    def __init__(self, op_name="max"):
        super().__init__(op_name, partial_type=["max"])

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for torch.max.
        """
        # Filter out None layouts (corresponding to non-tensor args like dim, keepdim).
        valid_layouts = [layout for layout in layouts if layout is not None]

        if not valid_layouts:
            raise ValueError("MaxDistributedOp requires at least one input layout")

        # Case 1: Element-wise max (e.g., torch.max(a, b)).
        if len(valid_layouts) > 1:
            # Element-wise max returns a single tensor, so return a single Layout object.
            return valid_layouts[0]

        # Case 2 & 3: Reduction max.
        x_layout = valid_layouts[0]
        if x_layout.mesh_shape is None:
            raise ValueError("The input layout's mesh_shape cannot be None.")

        dim = None
        keepdim = False

        if extra_args:
            dim = extra_args[0]
            if len(extra_args) > 1:
                keepdim = extra_args[1]

        if isinstance(dim, (Tensor, str)):
            raise TypeError(
                "The `dim` argument should not be a `Tensor` or a `str`. Instead, use one of the following types: "
                "`None`, `int`, `tuple[int]`, or `list[int]`."
            )

        values_layout = self._infer_output_layout(x_layout, dim, keepdim)

        if dim is None:
            # torch.max(input) -> single Tensor.
            # OpDispatcher logic:
            #   if isinstance(py_output, tuple): ...
            #   else: DTensor.from_local(py_output, output_layout.mesh, ...)
            # So here output_layout MUST be a Layout object, not a tuple.
            return values_layout

        # torch.max(input, dim) -> (values, indices).
        # OpDispatcher expects a tuple of layouts.
        indices_layout = deepcopy(values_layout)
        return (values_layout, indices_layout)
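
# Illustrative summary of the return conventions implemented by
# MaxDistributedOp.infer_layout above (argument shapes assumed to follow
# torch.max semantics):
#   torch.max(x)         -> a single values layout
#   torch.max(x, other)  -> the layout of the first tensor input
#   torch.max(x, dim=1)  -> a (values_layout, indices_layout) tuple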


class MinDistributedOp(MaxDistributedOp):
    """
    Distributed implementation for the PyTorch-style Min operator.

    Supports three PyTorch behaviors:
    1. torch.min(input) -> global reduction (returns a single Tensor)
    2. torch.min(input, dim, keepdim=False) -> dimension reduction (returns (values, indices))
    3. torch.min(input, other) -> element-wise min (returns a single Tensor)
    """

    def __init__(self, op_name="min"):
        # Reuse the MaxDistributedOp initialization.
        super().__init__(op_name=op_name)
        # Override partial_type to use "min" instead of "max" for the underlying communication.
        self.partial_type = ["min"]