Coverage for hyper_parallel / core / shard / ops / parallel_ops.py: 67%

18 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed operator implementation. 

17""" 

18 

19from .parallel_ops_register import register_distributed_op 

20 

class DistributedOp:
    """
    Base class for distributed operator implementations.

    This class provides default implementations for distributed operators.
    Subclasses should override methods as needed for specific operators.

    Args:
        op_name (str): Name of the operator to register.
    """
    def __init__(self, op_name):
        self.op_name = op_name
        # When False, ``infer_layout`` rejects inputs whose layout reports
        # Partial status (see ``_check_partial_inputs``).
        self._allow_partial_inputs = False
        # Register only after all instance state is set, so the registered
        # object is never handed out half-initialized.
        register_distributed_op(op_name, self)

    def _check_partial_inputs(self, layouts):
        """
        Check if any input layout has partial status and raise an error if not allowed.

        This method can be called by subclasses to enforce that partial inputs
        are not supported for a particular operator. Subclasses that support
        partial inputs should not call this method.

        Args:
            layouts (tuple | None): Layouts of the input tensors. ``None`` entries
                are skipped, and a ``None`` or empty ``layouts`` passes trivially.

        Raises:
            ValueError: If any input layout has partial status.
        """
        # ``infer_layout`` accepts a missing ``layouts``; iterate an empty tuple
        # in that case rather than failing with a TypeError from enumerate(None).
        for i, layout in enumerate(layouts or ()):
            if layout is not None and layout.is_partial():
                raise ValueError(
                    f"For {self.op_name}, input {i} with {layout} has Partial status which is not allowed. "
                    "Should be without Partial status for this operation."
                )

    # pylint: disable=W0613
    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts based on input layouts.

        Default implementation returns the first input layout, which suits
        element-wise operations. Subclasses can override this method to provide
        custom layout inference logic.

        Args:
            layouts (tuple | None): Layouts of the input tensors.
            extra_args (dict): Operator-specific additional arguments. Unused by
                this default implementation; NOTE(review): subclasses (e.g.
                reductions) presumably read keys such as ``dim``/``keepdim`` --
                confirm per operator.

        Returns:
            tuple | None: A one-element tuple holding the first input layout, or
            ``None`` when ``layouts`` is ``None`` or empty.

        Raises:
            ValueError: If partial inputs are not allowed for this operator and
                any input layout has Partial status.
        """
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        if not layouts:
            return None
        return (layouts[0],)

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """
        Get the expand implementation for the operator.

        The base class provides none; subclasses may override this hook to
        supply one.

        Args:
            func: The operator callable. NOTE(review): exact type not visible here.
            output_layout: Output layout(s); presumably the result of
                ``infer_layout`` -- confirm against callers.
            layouts (tuple): Layouts of the input tensors.
            extra_args (dict): Operator-specific additional arguments.

        Returns:
            None: No expand implementation by default.
        """
        return None