Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_histc_ext.py: 95% (37 statements)
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for HistcExt operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from hyper_parallel.platform import get_platform
21from .parallel_ops import DistributedOp
23platform = get_platform()
class HistcExtDistributedOp(DistributedOp):
    """
    Distributed implementation for the HistcExt operator.

    HistcExt computes the histogram of a tensor. In a distributed setting:
    - Each device computes a local histogram.
    - Local histograms are aggregated using AllReduce(SUM).
    - The output is always replicated (a 1D tensor with shape (bins,)).
    """

    def __init__(self, op_name="HistcExt"):
        super().__init__(op_name)

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout for the HistcExt operator.

        Args:
            layouts (tuple): Layouts of the input tensor.
            extra_args (tuple): (bins, min, max) parameters.

        Returns:
            Layout: Layout for the output histogram tensor.
        """
        if not layouts:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one input layout, "
                f"got {len(layouts) if layouts else 0}"
            )
        x_layout = layouts[0]
        if x_layout is None or x_layout.mesh_shape is None:
            raise ValueError("Input layout and its mesh_shape cannot be None.")

        bins = extra_args[0] if len(extra_args) > 0 else 100
        min_val = extra_args[1] if len(extra_args) > 1 else 0
        max_val = extra_args[2] if len(extra_args) > 2 else 0
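        # The defaults above appear to mirror torch.histc (bins=100, min=0,
        # max=0); under histc semantics, min == max == 0 typically means the
        # data's own min/max define the range, which would explain why
        # min == max is allowed below. (Interpretation, not stated in this file.)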
        if not isinstance(bins, int):
            raise ValueError(f"bins must be an integer, got {type(bins)}")
        if bins <= 0:
            raise ValueError(f"bins must be a positive integer, got {bins}")
        if not isinstance(min_val, (int, float)):
            raise ValueError(f"min must be a number, got {type(min_val)}")
        if not isinstance(max_val, (int, float)):
            raise ValueError(f"max must be a number, got {type(max_val)}")
        if min_val > max_val:
            raise ValueError(
                f"min must be less than or equal to max, got min={min_val}, max={max_val}"
            )

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_layout.set_tensor_map((-1,))  # Output is a 1D histogram with shape (bins,)

        has_sharding = any(
            tensor_map_val is not None and tensor_map_val != "None"
            for tensor_map_val in x_layout.alias_tensor_map
        )
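        # When any mesh axis shards the input, each device sees only part of
        # the data, so its local histogram holds partial counts; marking the
        # output partial-sum on those axes lets the framework insert the
        # AllReduce(SUM) described in the class docstring.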
        if has_sharding:
            for alias, tensor_map_val in zip(x_layout.alias_name, x_layout.alias_tensor_map):
                if tensor_map_val is not None and tensor_map_val != "None":
                    output_layout.set_partial_by_dev_axis(alias, "sum")
            # Rebuild the readable tensor map after the placements changed.
            # pylint: disable=protected-access
            output_layout._alias_tensor_map = output_layout._build_readable_tensor_map()
            output_layout.tensor_map_to_placement()
            output_layout.update_compact_str()

        return output_layout
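
# Hypothetical usage sketch (construction of the input layout is elided; the
# argument values are illustrative and not taken from this file):
#   op = HistcExtDistributedOp()
#   out_layout = op.infer_layout((x_layout,), extra_args=(100, 0.0, 10.0))
#   # out_layout describes a replicated (bins,) histogram, marked partial-sum
#   # on any mesh axes that sharded the input.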