Coverage for hyper_parallel / core / shard / ops / parallel_ops.py: 67%
18 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed operator implementation.
17"""
19from .parallel_ops_register import register_distributed_op
class DistributedOp:
    """
    Base class for distributed operator implementations.

    This class provides default implementations for distributed operators.
    Subclasses should override methods as needed for specific operators.

    Constructing an instance registers it by calling
    ``register_distributed_op(op_name, self)``.

    Args:
        op_name (str): Name of the operator to register.
    """

    def __init__(self, op_name):
        self.op_name = op_name
        # Finish initializing every attribute *before* handing ``self`` to the
        # registry, so anything that looks the op up during or right after
        # registration never sees a half-constructed instance.
        self._allow_partial_inputs = False
        register_distributed_op(op_name, self)

    def _check_partial_inputs(self, layouts):
        """
        Check if any input layout has partial status and raise an error if not allowed.

        This method can be called by subclasses to enforce that partial inputs
        are not supported for a particular operator. Subclasses that support
        partial inputs should not call this method.

        Args:
            layouts (tuple or None): Layouts of input tensors. ``None`` entries
                are skipped; ``layouts=None`` is treated as having no inputs.

        Raises:
            ValueError: If any input layout has partial status.
        """
        # infer_layout explicitly tolerates a missing ``layouts``; mirror that
        # here instead of failing with TypeError from enumerate(None).
        if layouts is None:
            return
        for i, layout in enumerate(layouts):
            if layout is not None and layout.is_partial():
                raise ValueError(
                    f"For {self.op_name}, input {i} with {layout} has Partial status which is not allowed. "
                    f"Should be without Partial status for this operation."
                )

    # pylint: disable=W0613
    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts based on input layouts.

        Default implementation returns the first input layout, which suits
        element-wise operations. Subclasses can override this method to
        provide custom layout inference logic.

        Unless ``self._allow_partial_inputs`` is True, inputs with Partial
        status are rejected.

        Args:
            layouts (tuple or None): Layouts of input tensors.
            extra_args (dict): Operator-specific additional arguments (for
                example ``dim``/``keepdim`` for reductions). Unused by this
                default implementation.

        Returns:
            tuple or None: A one-element tuple holding the first input
            layout, or None when ``layouts`` is empty or None.

        Raises:
            ValueError: If partial inputs are not allowed and any input
                layout has Partial status.
        """
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        if not layouts:
            return None
        return (layouts[0],)

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """
        Get expand implementation for the operator.

        The default implementation provides none and always returns None;
        subclasses may override it to supply one.

        Args:
            func: The operator callable (presumably the original op; confirm
                against callers).
            output_layout: Output layout(s), presumably as produced by
                ``infer_layout``.
            layouts (tuple): Layouts of input tensors.
            extra_args (dict): Operator-specific additional arguments.

        Returns:
            None: No expand implementation by default.
        """
        return None