Coverage for hyper_parallel / core / shard / ops / parallel_slice.py: 15%
34 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Slice operator.
17"""
18# pylint: disable=E0402
19from .parallel_ops import DistributedOp
class SliceDistributedOp(DistributedOp):
    """Distributed implementation for the Slice operator."""

    def _is_shard_dim(self, layout):
        """
        Return the shard count for each tensor dimension of ``layout``.

        Args:
            layout: layout object exposing ``mesh_shape`` (the device matrix)
                and ``tensor_map`` (per tensor dimension, the device-matrix
                axis (or axes) it is sharded on; ``-1`` means not sharded).

        Returns:
            list[int]: shard count per tensor dimension; ``1`` means the
            dimension is replicated (not split).
        """
        shard_dim = []
        dev_mat = layout.mesh_shape
        tensor_map = layout.tensor_map
        mat_len = len(dev_mat)  # hoisted loop invariant
        for dev_idx in tensor_map:
            if isinstance(dev_idx, (tuple, list)):
                # A dimension mapped to several device-matrix axes is sharded
                # by the product of those axes' sizes.
                shard_num = 1
                for idx in dev_idx:
                    # tensor_map indexes the device matrix in reverse order,
                    # hence dev_mat[mat_len - idx - 1]. Axes of size 1
                    # contribute no splitting and are skipped.
                    if idx != -1 and dev_mat[mat_len - idx - 1] != 1:
                        shard_num *= dev_mat[mat_len - idx - 1]
                shard_dim.append(shard_num)
            else:
                if dev_idx != -1 and dev_mat[mat_len - dev_idx - 1] != 1:
                    shard_dim.append(dev_mat[mat_len - dev_idx - 1])
                else:
                    shard_dim.append(1)
        return shard_dim

    def _check_layout(self, layout, begin, end, shape):
        """
        Check that ``layout`` is valid for slicing and return the shard counts.

        A dimension that is sharded must be fully fetched by the slice
        (``end - begin`` equals the dimension's global size, or the requested
        size is ``-1`` meaning "to the end").

        Args:
            layout (tuple): one-element tuple holding the input layout.
            begin (sequence[int]): slice start per dimension.
            end (sequence[int]): slice end per dimension.
            shape (sequence[int]): global shape of the sliced tensor.

        Returns:
            list[int]: shard count per dimension (see ``_is_shard_dim``).

        Raises:
            ValueError: if ``layout`` is not a one-element tuple, or a sharded
                dimension is only partially fetched.
        """
        if len(layout) != 1:
            raise ValueError(f"Layout must be a tuple of length 1, but got {len(layout)}")
        layout = layout[0]
        shard_dim = self._is_shard_dim(layout)
        for i, _ in enumerate(begin):
            if (shard_dim[i] != 1 and end[i] - begin[i] != shape[i]) and shape[i] != -1:
                raise ValueError(
                    f"Slice: When a dimension({i}) is not fully fetched, the dimension can not be split now, "
                    f"the begin is {begin}, the end is {end}, the shape is {shape}, layout is {layout.to_dict()}")
        return shard_dim

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for slice operator. The shard dim must be fully fetched.

        Args:
            layouts (tuple): one-element tuple with the layout of input x.
            extra_args (sequence): ``(begin, end, global_shape)`` of the slice.

        Returns:
            layout (Layout): the output layout (same as the input layout).
            new_begin (tuple): begin rescaled to the local shard.
            new_end (tuple): end rescaled to the local shard.
        """
        begin = extra_args[0]
        end = extra_args[1]
        global_shape = extra_args[2]
        shard_dim = self._check_layout(layouts, begin, end, global_shape)
        # Each sharded (fully fetched) dimension is divided evenly across
        # devices, so local slice bounds shrink by the shard count.
        new_begin = tuple(begin[i] // shard_dim[i] for i in range(len(begin)))
        new_end = tuple(end[i] // shard_dim[i] for i in range(len(end)))
        return layouts[0], new_begin, new_end