Coverage for hyper_parallel / core / shard / ops / parallel_slice_ext.py: 33%
9 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for SliceExt operator.
17"""
19from .parallel_ops import DistributedOp
class SliceExtDistributedOp(DistributedOp):
    """Distributed implementation for SliceExt operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout for the SliceExt operator.

        Rules:
            1. A sharded axis can not be sliced: the tensor_map entry for the
               axis given by ``extra_args[0]`` must be -1.

        Args:
            layouts (list): Layouts of the operator inputs. Only ``layouts[0]``,
                the layout of the input tensor, is used.
            extra_args (list): Extra operator arguments. Only ``extra_args[0]``,
                the axis being sliced, is read here. Negative values count from
                the end, as with Python sequence indexing.

        Returns:
            Layout: The input tensor's layout, returned unchanged.

        Raises:
            IndexError: If ``extra_args[0]`` is outside the range of the input
                layout's tensor_map.
            ValueError: If the axis being sliced is sharded.
        """
        input_layout = layouts[0]
        axis = extra_args[0]
        tensor_map = input_layout.tensor_map
        rank = len(tensor_map)
        # Fail with a descriptive message instead of a bare "index out of range".
        # The accepted range mirrors Python sequence indexing, so every axis that
        # indexed successfully before (including negative ones) still does.
        if not -rank <= axis < rank:
            raise IndexError(
                f"Slice axis {axis} is out of range for tensor_map of length {rank}, "
                f"layout: {input_layout}"
            )
        # A sharded axis can not be sliced: any tensor_map entry other than -1
        # is treated as sharded.
        # NOTE(review): the message below says "split" although this op slices;
        # it is kept verbatim in case callers or tests match on it.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")
        return input_layout