# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed operator implementation.
"""

from typing import Callable, Optional

from .parallel_ops_register import register_distributed_op


class DistributedOp:
    """
    Base class for distributed operator implementations.

    This class provides default implementations for distributed operators.
    Subclasses should override methods as needed for specific operators.

    Args:
        op_name (str): Name of the operator to register.
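
    Example:
        A minimal sketch of a subclass; instantiating it once registers it
        under the operator name "add", since registration happens in
        ``__init__``::

            class AddOp(DistributedOp):
                def __init__(self):
                    super().__init__("add")
                    self._allow_partial_inputs = True  # element-wise add can keep Partial inputs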

33 """ 

34 def __init__(self, op_name): 

35 self.op_name = op_name 

36 register_distributed_op(op_name, self) 

37 self._allow_partial_inputs = False 

38 

39 def _check_partial_inputs(self, layouts): 

40 """ 

41 Check if any input layout has partial status and raise an error if not allowed. 

42 

43 This method can be called by subclasses to enforce that partial inputs 

44 are not supported for a particular operator. Subclasses that support 

45 partial inputs should not call this method. 

46 

47 Args: 

48 layouts (tuple): Layouts of input tensor. 

49 

50 Raises: 

51 ValueError: If any input layout has partial status. 
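
        Example:
            A sketch of a subclass rejecting Partial inputs in its own layout
            inference::

                def infer_layout(self, layouts, extra_args=None):
                    self._check_partial_inputs(layouts)  # raises ValueError on Partial input
                    return (layouts[0],)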

52 """ 

53 for i, layout in enumerate(layouts): 

54 if layout is not None and layout.is_partial(): 

55 raise ValueError( 

56 f"For {self.op_name}, input {i} with {layout} has Partial status which is not allowed. " 

57 f"Should be without Partial status for this operation." 

58 ) 

59 

60 # pylint: disable=W0613 

61 def preprocess(self, args: tuple, kwargs: dict) -> Optional[tuple]: 

62 """ 

63 Unified preprocessing: parameter parsing + to_local + cache_values construction. 

64 

65 Subclasses override this to participate in the new dispatch flow. 

66 

67 Args: 

68 args (tuple): Positional arguments passed to the operator call. 

69 kwargs (dict): Keyword arguments passed to the operator call. 

70 

71 Returns: 

72 None: Fall back to legacy dispatch (default). 

73 tuple: (local_args, local_kwargs, cache_values) 

74 - local_args: Local tensor positional arguments (DTensors already to_local'd). 

75 - local_kwargs: Local tensor keyword arguments (DTensors already to_local'd). 

76 - cache_values: Values affecting layout inference (fixed order). 

77 Contains Layout objects (with compact_str) and raw values (int, bool, tuple, etc.). 
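
        Example:
            A sketch of an override, assuming the DTensor inputs expose
            ``to_local()`` and a ``layout`` attribute::

                def preprocess(self, args, kwargs):
                    x, dim = args[0], args[1]
                    cache_values = (x.layout, dim)  # values that drive layout inference
                    return ((x.to_local(), dim), kwargs, cache_values)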

78 """ 

79 return None 

80 

81 # pylint: disable=W0613 

82 def infer_layout(self, layouts: tuple, extra_args: Optional[tuple] = None) -> Optional[tuple]: 

83 """ 

84 Infer output layouts based on input layouts. 

85 

86 Default implementation returns the first input layout for element-wise operations. 

87 Subclasses can override this method to provide custom layout inference logic. 

88 

89 Args: 

90 layouts (tuple): Layouts of input tensor. 

91 extra_args (list): Additional arguments (dim, keepdim, etc.). 

92 

93 Returns: 

94 tuple: Layouts for output tensors. 
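
        Example:
            A sketch of an override for an operator whose output is always
            replicated; ``make_replicated_layout`` is a hypothetical helper::

                def infer_layout(self, layouts, extra_args=None):
                    self._check_partial_inputs(layouts)
                    return (make_replicated_layout(layouts[0].mesh),)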

95 """ 

96 # Check partial inputs 

97 if not self._allow_partial_inputs: 

98 self._check_partial_inputs(layouts) 

99 

100 if layouts: 

101 return (layouts[0],) 

102 return None 

103 

104 # pylint: disable=W0613 

105 def get_expand_impl(self, func: Optional[callable], infer_result: tuple, layouts: tuple, 

106 extra_args: Optional[tuple] = None) -> Optional[callable]: 

107 """ 

108 Get expand implementation for the operator. 

109 

110 Args: 

111 func (Optional[callable]): The underlying operator function. 

112 infer_result (tuple): Result returned by infer_layout (output_layouts, extra_info). 

113 layouts (tuple): Input layouts passed to layout inference. 

114 extra_args (Optional[tuple]): Additional arguments for layout inference. 

115 

116 Returns: 

117 Optional[callable]: A closure that wraps the operator call with extra logic, 

118 or None if no expansion is needed. 
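
        Example:
            A sketch of an override that appends extra logic after the local
            call; ``resolve_partial`` is a hypothetical helper that finalizes
            a Partial output::

                def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
                    def expanded(*args, **kwargs):
                        out = func(*args, **kwargs)
                        return resolve_partial(out)
                    return expanded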

119 """ 

120 return None 

121 

122 def wrap_output(self, py_output, output_layouts): 

123 """Wrap local outputs into DTensors according to inferred layouts. 

124 

125 Subclasses may override this when a specific operator needs custom 

126 packing semantics for certain output slots. 
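
        Example:
            A sketch of wrapping a two-output operator, assuming ``layout_a``
            and ``layout_b`` were produced by ``infer_layout``::

                dt_a, dt_b = self.wrap_output((local_a, local_b), (layout_a, layout_b))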

127 """ 

128 # pylint: disable=C0415 

129 from hyper_parallel.core.dtensor.dtensor import DTensor 

130 

131 if isinstance(py_output, (tuple, list)): 

132 if len(py_output) != len(output_layouts): 

133 raise RuntimeError( 

134 f"Output tuple size ({len(py_output)}) " 

135 f"does not match layout tuple size ({len(output_layouts)})") 

136 return tuple( 

137 DTensor.from_local(item, layout.mesh, layout.alias_placements) 

138 for item, layout in zip(py_output, output_layouts) 

139 ) 

140 

141 if isinstance(output_layouts, (tuple, list)): 

142 if len(output_layouts) != 1: 

143 raise RuntimeError( 

144 f"Scalar output expects a single layout, but got {len(output_layouts)} layouts" 

145 ) 

146 output_layout = output_layouts[0] 

147 else: 

148 output_layout = output_layouts 

149 

150 return DTensor.from_local( 

151 py_output, output_layout.mesh, output_layout.alias_placements 

152 )