Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_ops.py: 52%
33 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed operator implementation.
17"""
from typing import Callable, Optional

from .parallel_ops_register import register_distributed_op
class DistributedOp:
    """
    Base class for distributed operator implementations.

    This class provides default implementations for distributed operators.
    Subclasses should override methods as needed for specific operators.

    Args:
        op_name (str): Name of the operator to register.
    """

    def __init__(self, op_name):
        self.op_name = op_name
        # Partial inputs are rejected by default (see ``infer_layout``). A
        # subclass that accepts them must set this to True *after* calling
        # ``super().__init__``, because it is reset here.
        self._allow_partial_inputs = False
        # Register last so the registry never receives a partially
        # initialised instance: every attribute above is already in place.
        register_distributed_op(op_name, self)

    def _check_partial_inputs(self, layouts):
        """
        Check if any input layout has partial status and raise an error if not allowed.

        This method can be called by subclasses to enforce that partial inputs
        are not supported for a particular operator. Subclasses that support
        partial inputs should not call this method.

        Args:
            layouts (Optional[tuple]): Layouts of input tensors. ``None``
                entries are skipped; a ``None`` or empty ``layouts`` checks
                nothing.

        Raises:
            ValueError: If any input layout has partial status.
        """
        # ``layouts or ()`` lets callers pass None, matching the
        # ``if layouts:`` guard in ``infer_layout``.
        for i, layout in enumerate(layouts or ()):
            if layout is not None and layout.is_partial():
                raise ValueError(
                    f"For {self.op_name}, input {i} with {layout} has Partial status which is not allowed. "
                    f"Should be without Partial status for this operation."
                )

    # pylint: disable=W0613
    def preprocess(self, args: tuple, kwargs: dict) -> Optional[tuple]:
        """
        Unified preprocessing: parameter parsing + to_local + cache_values construction.

        Subclasses override this to participate in the new dispatch flow.

        Args:
            args (tuple): Positional arguments passed to the operator call.
            kwargs (dict): Keyword arguments passed to the operator call.

        Returns:
            None: Fall back to legacy dispatch (default).
            tuple: (local_args, local_kwargs, cache_values)
                - local_args: Local tensor positional arguments (DTensors already to_local'd).
                - local_kwargs: Local tensor keyword arguments (DTensors already to_local'd).
                - cache_values: Values affecting layout inference (fixed order).
                  Contains Layout objects (with compact_str) and raw values (int, bool, tuple, etc.).
        """
        return None

    # pylint: disable=W0613
    def infer_layout(self, layouts: Optional[tuple], extra_args: Optional[tuple] = None) -> Optional[tuple]:
        """
        Infer output layouts based on input layouts.

        Default implementation returns the first input layout for element-wise operations.
        Subclasses can override this method to provide custom layout inference logic.

        Args:
            layouts (Optional[tuple]): Layouts of input tensors; may be empty or None.
            extra_args (Optional[tuple]): Additional arguments (dim, keepdim, etc.).
                Unused by this default implementation.

        Returns:
            Optional[tuple]: A one-element tuple holding ``layouts[0]``, or None
            when ``layouts`` is empty or None.

        Raises:
            ValueError: If partial inputs are not allowed and any input layout
                has Partial status.
        """
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        if layouts:
            return (layouts[0],)
        return None

    # pylint: disable=W0613
    def get_expand_impl(self, func: Optional[Callable], infer_result: tuple, layouts: tuple,
                        extra_args: Optional[tuple] = None) -> Optional[Callable]:
        """
        Get expand implementation for the operator.

        Args:
            func (Optional[Callable]): The underlying operator function.
            infer_result (tuple): Result returned by ``infer_layout``. It is
                described upstream as ``(output_layouts, extra_info)``; note
                that the default ``infer_layout`` returns a one-element tuple.
                NOTE(review): confirm the expected shape.
            layouts (tuple): Input layouts passed to layout inference.
            extra_args (Optional[tuple]): Additional arguments for layout inference.

        Returns:
            Optional[Callable]: A closure that wraps the operator call with extra logic,
            or None if no expansion is needed. The base class always returns None.
        """
        return None

    def wrap_output(self, py_output, output_layouts):
        """Wrap local outputs into DTensors according to inferred layouts.

        Subclasses may override this when a specific operator needs custom
        packing semantics for certain output slots.

        Args:
            py_output: Local result of the operator call: a single value, or a
                tuple/list of values.
            output_layouts: When ``py_output`` is a tuple/list, a sequence with
                exactly one layout per item. Otherwise, either a single layout
                or a one-element tuple/list.

        Returns:
            The result of ``DTensor.from_local`` for a single output, or a tuple
            of such results for a tuple/list output (lists come back as tuples).

        Raises:
            RuntimeError: If the number of outputs and layouts disagree.
        """
        # Imported inside the method rather than at module level.
        # pylint: disable=C0415
        from hyper_parallel.core.dtensor.dtensor import DTensor

        if isinstance(py_output, (tuple, list)):
            if len(py_output) != len(output_layouts):
                raise RuntimeError(
                    f"Output tuple size ({len(py_output)}) "
                    f"does not match layout tuple size ({len(output_layouts)})")
            return tuple(
                DTensor.from_local(item, layout.mesh, layout.alias_placements)
                for item, layout in zip(py_output, output_layouts)
            )

        # Single output: accept either a bare layout or a one-element sequence.
        if isinstance(output_layouts, (tuple, list)):
            if len(output_layouts) != 1:
                raise RuntimeError(
                    f"Scalar output expects a single layout, but got {len(output_layouts)} layouts"
                )
            output_layout = output_layouts[0]
        else:
            output_layout = output_layouts

        return DTensor.from_local(
            py_output, output_layout.mesh, output_layout.alias_placements
        )