Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_histc_ext.py: 95%

37 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for HistcExt operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from hyper_parallel.platform import get_platform 

21from .parallel_ops import DistributedOp 

22 

# Backend platform handle, resolved once when this module is imported.
# NOTE(review): not referenced by the code visible in this module; possibly kept
# for other modules that import it from here — confirm before removing.
platform = get_platform()

24 

25 

class HistcExtDistributedOp(DistributedOp):
    """
    Distributed implementation for HistcExt operator.

    HistcExt computes the histogram of a tensor. In distributed setting:
    - Each device computes a local histogram
    - Local histograms are aggregated using AllReduce(SUM)
    - Output is always replicated (1D tensor with shape (bins,))

    Summing local histograms is only correct when every device uses the same
    bin edges, i.e. when an explicit range with ``min != max`` is given. With
    ``min == max`` the histogram range is taken from the (local) data, so a
    sharded input is rejected in that case.
    """

    def __init__(self, op_name="HistcExt"):
        super().__init__(op_name)

    @staticmethod
    def _sharded_mesh_axes(alias_tensor_map):
        """Return the mesh-axis aliases that shard any tensor dimension.

        Each entry of ``alias_tensor_map`` describes one tensor dimension. An
        entry of ``None`` or the string ``"None"`` means "not sharded"; any other
        entry names the mesh axis (or, presumably for multi-axis sharding, a
        tuple/list of mesh axes — TODO confirm) that the dimension is split over.

        Args:
            alias_tensor_map (Iterable | None): Per-tensor-dimension alias map.

        Returns:
            list: Distinct mesh-axis aliases in first-seen order.
        """
        axes = []
        for entry in alias_tensor_map or ():
            candidates = entry if isinstance(entry, (tuple, list)) else (entry,)
            for alias in candidates:
                if alias is None or alias == "None":
                    continue
                if alias not in axes:
                    axes.append(alias)
        return axes

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for HistcExt operator.

        Args:
            layouts (tuple): Layouts of the inputs; only ``layouts[0]`` (the
                input tensor) is used.
            extra_args (tuple | None): Optional ``(bins, min, max)`` parameters.
                Missing entries default to ``bins=100``, ``min=0``, ``max=0``;
                ``None`` is treated as an empty tuple.

        Returns:
            Layout: Layout for the 1D output histogram tensor. It is not sharded
            along any mesh axis and is marked partial ("sum") on every mesh axis
            that shards the input.

        Raises:
            ValueError: If no input layout is given, the input layout or its
                mesh shape is None, ``bins``/``min``/``max`` are invalid, or the
                input is sharded while ``min == max``.
        """
        if not layouts:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one input layout, "
                f"got 0"
            )
        x_layout = layouts[0]
        if x_layout is None or x_layout.mesh_shape is None:
            raise ValueError("Input layout cannot be None.")

        args = tuple(extra_args) if extra_args is not None else ()
        bins = args[0] if len(args) > 0 else 100
        min_val = args[1] if len(args) > 1 else 0
        max_val = args[2] if len(args) > 2 else 0

        if not isinstance(bins, int):
            raise ValueError(f"bins must be an integer, got {type(bins)}")
        if bins <= 0:
            raise ValueError(f"bins must be a positive integer, got {bins}")
        if not isinstance(min_val, (int, float)):
            raise ValueError(f"min must be a number, got {type(min_val)}")
        if not isinstance(max_val, (int, float)):
            raise ValueError(f"max must be a number, got {type(max_val)}")
        if min_val > max_val:
            raise ValueError(f"min must be less than or equal to max, got min={min_val}, max={max_val}")

        # A device holding only a slice of the input counts only that slice, so
        # the output is a partial sum over exactly the mesh axes that shard the
        # input. Read those axes from the input's per-tensor-dim map; pairing
        # them positionally with the mesh axis names would pick the wrong axes.
        partial_axes = self._sharded_mesh_axes(x_layout.alias_tensor_map)

        # Following torch.histc semantics (assumed for the backend kernel —
        # confirm), min == max means "use the data's own min/max". Each shard
        # would then bin over a different local range and the summed histogram
        # would be wrong, so refuse instead of silently producing it.
        if partial_axes and min_val == max_val:
            raise ValueError(
                f"{self.__class__.__name__} does not support a sharded input when "
                f"min == max (got min={min_val}, max={max_val}): the histogram "
                f"range would come from each shard's local data, so per-device "
                f"bins would not line up. Pass an explicit range with min < max "
                f"or replicate the input."
            )

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        # Output is a 1D histogram with shape (bins,), not split over any mesh axis.
        output_layout.set_tensor_map((-1,))

        for alias in partial_axes:
            output_layout.set_partial_by_dev_axis(alias, "sum")

        # Refresh the layout's derived views so they reflect the tensor map and
        # any partial axes set above (done for the replicated case too).
        output_layout._alias_tensor_map = output_layout._build_readable_tensor_map()  # pylint: disable=protected-access
        output_layout.tensor_map_to_placement()
        output_layout.update_compact_str()

        return output_layout