Coverage for hyper_parallel/core/shard/ops/parallel_one_hot_ext.py: 12%
169 statements
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for OneHotExt operator.
17"""
19# pylint: disable=import-outside-toplevel
20from hyper_parallel.core.layout import Layout
21from hyper_parallel.core.placement_types import Shard, Replicate
22from hyper_parallel.platform import get_platform
23from .parallel_ops import DistributedOp
25platform = get_platform()


class OneHotExtDistributedOp(DistributedOp):
    """Distributed implementation for OneHotExt operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for OneHotExt.

        Args:
            layouts (tuple): Tuple containing input layouts.
            extra_args (tuple): Additional arguments containing [num_classes, on_value, off_value, axis].

        Returns:
            Layout: Output layout with one-hot dimension inserted at specified axis.
        """
        if not layouts:
            return None

        indices_layout = layouts[0]
        if indices_layout is None or indices_layout.mesh_shape is None:
            raise ValueError(f"{self.op_name}: indices layout cannot be None")

        if indices_layout.is_partial():
            raise ValueError(
                f"{self.op_name}: indices cannot be in partial state. "
                f"Indices must contain complete index values for OneHot operation."
            )

        num_classes = self._get_num_classes(extra_args)
        self._validate_num_classes(num_classes)

        axis = self._get_axis(extra_args)

        in_tensor_map = indices_layout.tensor_map
        if not in_tensor_map:
            raise ValueError(f"{self.op_name}: indices tensor_map is empty")

        self._validate_multi_dim_restriction(in_tensor_map, axis, indices_layout)
        self._validate_inputs_layouts(layouts)

        out_tensor_map = self._infer_output_tensor_map(in_tensor_map, axis)
        out_layout = self._create_layout_from_tensor_map(indices_layout, out_tensor_map)

        out_placements = self._tensor_map_to_placements(indices_layout, out_tensor_map)
        out_layout.set_placements(out_placements)

        return out_layout
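
    # Illustrative trace (mesh axis names here are assumed, not taken from a real
    # config): for 2-D indices whose tensor_map is (1, -1), i.e. only the batch
    # dimension is sharded, the inferred output tensor_map is (1, -1, -1): the
    # existing sharding is kept and the new one-hot class dimension stays replicated.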

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """Get expanded implementation for OneHotExt operator."""
        import mindspore as ms
        from mindspore import ops, Tensor

        del output_layout

        indices_layout = layouts[0] if layouts else None
        if indices_layout is None:
            return None

        sharded_axes = self._get_sharded_axes(indices_layout)
        if not sharded_axes:
            return None

        original_op = func
        reduce_max = ops.ReduceMax(keep_dims=False)

        def expanded_one_hot(indices, num_classes, on_value, off_value, axis):
            self._validate_num_classes(num_classes)
            self._validate_indices_dtype(indices)

            if num_classes != -1:
                return original_op(indices, num_classes, on_value, off_value, axis)

            local_max = reduce_max(indices, ())
            if not isinstance(local_max, Tensor):
                local_max = Tensor(local_max, ms.int64)

            local_max_host = int(local_max.asnumpy())
            if local_max_host > 2147483647:
                raise ValueError(
                    f"{self.op_name}: indices max value {local_max_host} exceeds int32 range"
                )

            zero_dim = local_max.ndim == 0
            local_max_i32 = ops.cast(local_max, ms.int32)

            if zero_dim:
                local_max_i32 = ops.expand_dims(local_max_i32, 0)

            global_max_i32 = local_max_i32
            for axis_name in sharded_axes:
                group = indices_layout.get_comm_group_by_axis(axis_name)
                global_max_i32 = platform.differentiable_all_reduce(
                    global_max_i32, "max", group
                )

            if zero_dim:
                global_max_i32 = ops.squeeze(global_max_i32, 0)

            depth = int(global_max_i32.asnumpy()) + 1
            return original_op(indices, depth, on_value, off_value, axis)

        return expanded_one_hot
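
    # Worked example of the dynamic-depth path above (numbers are illustrative):
    # when num_classes is -1 and indices are sharded across two ranks whose local
    # maxima are 6 and 9, each rank reduces to its local max, the "max" all-reduce
    # over the sharded mesh axes returns 9 everywhere, and depth = 9 + 1 = 10 is
    # passed to the original one-hot op on every rank.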

    def _get_num_classes(self, extra_args):
        """Extract num_classes from extra arguments."""
        if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 1:
            num_classes = extra_args[0]
            if isinstance(num_classes, int):
                return num_classes
        return -1

    def _validate_num_classes(self, num_classes):
        """Validate num_classes parameter."""
        if not isinstance(num_classes, int):
            raise TypeError(
                f"{self.op_name}: num_classes must be int, but got {type(num_classes).__name__}"
            )
        if num_classes < -1:
            raise ValueError(
                f"{self.op_name}: num_classes must be >= -1, but got {num_classes}"
            )

    def _validate_indices_dtype(self, indices):
        """Validate indices dtype."""
        import mindspore as ms

        if indices.dtype != ms.int64:
            raise TypeError(
                f"{self.op_name}: indices dtype must be int64, but got {indices.dtype}"
            )

    def _get_sharded_axes(self, layout):
        """Get all device axes that are used for sharding."""
        sharded_axes = set()

        if layout is None or layout.alias_tensor_map is None:
            return []

        for dim_alias in layout.alias_tensor_map:
            if dim_alias == "None":
                continue

            if isinstance(dim_alias, tuple):
                for axis_name in dim_alias:
                    if axis_name != "None":
                        sharded_axes.add(axis_name)
            else:
                sharded_axes.add(dim_alias)

        return list(sharded_axes)
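
    # Illustrative example (axis names assumed): an alias_tensor_map of
    # ("dp", ("cp", "mp"), "None") yields ["dp", "cp", "mp"] (in arbitrary set
    # order), since the string "None" marks an unsharded tensor dimension.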

    def _get_axis(self, extra_args):
        """Extract axis parameter from extra arguments."""
        if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 4:
            axis = extra_args[3]
            if isinstance(axis, int):
                return self._validate_axis(axis)
        return -1

    def _validate_axis(self, axis):
        """Validate axis parameter."""
        if not isinstance(axis, int):
            raise TypeError(
                f"{self.op_name}: axis must be int, but got {type(axis).__name__}"
            )

        if axis > 1 or axis < -1:
            raise ValueError(f"{self.op_name}: axis {axis} is out of range [-1, 1]")

        return axis

    def _validate_multi_dim_restriction(self, in_tensor_map, axis, indices_layout):
        """Validate restriction for multi-dimensional inputs."""
        in_rank = len(in_tensor_map)
        if in_rank <= 1:
            return

        if axis != -1:
            raise ValueError(
                f"{self.op_name}: when input dimension is > 1, axis must be -1, but got {axis}"
            )

        alias_map = indices_layout.alias_tensor_map
        for i in range(1, len(alias_map)):
            if alias_map[i] != "None":
                raise ValueError(
                    f"{self.op_name}: when input dimension is > 1, strategy must be data parallel, "
                    f"but dimension {i} is sharded on '{alias_map[i]}'"
                )
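
    # Illustrative example (axis names assumed): for 2-D indices an alias map of
    # ("dp", "None") passes, because only the batch dimension is sharded, while
    # ("dp", "mp") raises, because a non-batch dimension is sharded.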

    def _validate_inputs_layouts(self, layouts):
        """Validate that non-indices inputs are fully replicated."""
        for layout in layouts[1:]:
            if layout is None:
                continue
            alias_map = layout.alias_tensor_map
            if alias_map and any(x != "None" for x in alias_map):
                raise ValueError(
                    f"{self.op_name}: non-indices inputs must be replicated, but got {alias_map}"
                )

    def _infer_output_tensor_map(self, in_tensor_map, axis):
        """Infer output tensor map by inserting one-hot dimension at specified axis."""
        in_rank = len(in_tensor_map)

        if axis in (-1, in_rank):
            insert_pos = in_rank
        else:
            insert_pos = axis

        if insert_pos < 0 or insert_pos > in_rank:
            raise ValueError(
                f"{self.op_name}: axis {axis} is out of range for input with rank {in_rank}"
            )

        out_tensor_map = list(in_tensor_map)
        out_tensor_map.insert(insert_pos, -1)
        return tuple(out_tensor_map)
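
    # Illustrative traces: in_tensor_map (1, 0) with axis -1 gives (1, 0, -1);
    # in_tensor_map (0,) with axis 0 gives (-1, 0). The inserted -1 means the new
    # one-hot class dimension is replicated across the device mesh.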

    def _create_layout_from_tensor_map(self, base_layout, out_tensor_map):
        """Create output layout from tensor map."""
        out_layout = Layout(
            mesh_shape=base_layout.mesh_shape,
            alias_name=base_layout.alias_name,
            rank_list=base_layout.rank_list,
        )

        out_layout.set_tensor_map(out_tensor_map)
        out_layout.set_alias_tensor_map(
            self._tensor_map_to_alias_tensor_map(base_layout, out_tensor_map)
        )
        out_layout.update_compact_str()
        return out_layout

    def _tensor_map_to_alias_tensor_map(self, base_layout, tensor_map):
        """Convert numeric tensor map to alias tensor map."""
        alias_name = base_layout.alias_name
        alias_tensor_map = []

        for dim in tensor_map:
            if dim == -1:
                alias_tensor_map.append("None")
                continue

            if isinstance(dim, tuple):
                names = tuple(
                    alias_name[len(alias_name) - 1 - d] for d in dim if d != -1
                )
                alias_tensor_map.append(names if names else "None")
                continue

            alias_tensor_map.append(alias_name[len(alias_name) - 1 - dim])

        return tuple(alias_tensor_map)
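
    # Illustrative example (axis names assumed): with alias_name ("dp", "mp"),
    # the numeric entry 1 maps to "dp" and 0 maps to "mp", because tensor_map
    # values index mesh dimensions starting from the last one; tensor_map
    # (1, -1, -1) therefore becomes ("dp", "None", "None").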

    def _tensor_map_to_placements(self, base_layout, tensor_map):
        """
        Convert tensor_map to placements.

        Args:
            base_layout: Base layout to get mesh dimension info.
            tensor_map: Tensor map to convert.

        Returns:
            tuple: Placements tuple (Shard/Replicate for each mesh dimension).
        """
        mesh_ndim = len(base_layout.mesh_shape)
        placements = []

        for mesh_dim_idx in range(mesh_ndim):
            is_sharded = False

            for tensor_dim_idx, tensor_dim_map in enumerate(tensor_map):
                if tensor_dim_map == -1:
                    continue

                if isinstance(tensor_dim_map, tuple):
                    if mesh_dim_idx in tensor_dim_map:
                        placements.append(Shard(tensor_dim_idx))
                        is_sharded = True
                        break
                elif tensor_dim_map == mesh_dim_idx:
                    placements.append(Shard(tensor_dim_idx))
                    is_sharded = True
                    break

            if not is_sharded:
                placements.append(Replicate())

        return tuple(placements)
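
    # Illustrative trace: on a 2-D mesh with tensor_map (1, -1, -1), mesh
    # dimension 0 matches no tensor dimension and becomes Replicate(), while mesh
    # dimension 1 is referenced by tensor dimension 0 and becomes Shard(0),
    # giving the placements (Replicate(), Shard(0)).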