Coverage for hyper_parallel / core / shard / ops / parallel_slice.py: 15%

34 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Slice operator. 

17""" 

18# pylint: disable=E0402 

19from .parallel_ops import DistributedOp 

20 

21 

class SliceDistributedOp(DistributedOp):
    """Distributed implementation for Slice operator."""

    def _is_shard_dim(self, layout):
        """
        Return the number of shards along each tensor dimension.

        Args:
            layout (Layout): Layout exposing ``mesh_shape`` and ``tensor_map``.

        Returns:
            list[int]: One shard count per entry of ``layout.tensor_map``. A value
            of 1 means the dimension is not split.
        """
        dev_mat = layout.mesh_shape
        dev_num = len(dev_mat)

        def _axis_size(dev_idx):
            # tensor_map entries index mesh_shape counting from the right;
            # -1 means the tensor dimension is not mapped to any mesh axis.
            if dev_idx == -1:
                return 1
            return dev_mat[dev_num - dev_idx - 1]

        shard_dim = []
        for dev_idx in layout.tensor_map:
            if isinstance(dev_idx, (tuple, list)):
                # A dimension mapped onto several mesh axes is split by the
                # product of those axis sizes.
                shard_num = 1
                for idx in dev_idx:
                    shard_num *= _axis_size(idx)
                shard_dim.append(shard_num)
            else:
                shard_dim.append(_axis_size(dev_idx))
        return shard_dim

    def _check_layout(self, layout, begin, end, shape):
        """
        Validate that every split dimension is sliced in full.

        Args:
            layout (tuple): A 1-tuple holding the input's Layout.
            begin (Sequence[int]): Slice start for each dimension.
            end (Sequence[int]): Slice end for each dimension.
            shape (Sequence[int]): Global input shape; a dimension equal to -1
                is exempt from the full-fetch check.

        Returns:
            list[int]: Shard count per dimension, as computed by ``_is_shard_dim``.

        Raises:
            ValueError: If ``layout`` does not hold exactly one element, or if a
                split dimension is only partially fetched.
        """
        if len(layout) != 1:
            raise ValueError(f"Layout must be a tuple of length 1, but got {len(layout)}")
        layout = layout[0]
        shard_dim = self._is_shard_dim(layout)
        for i, start in enumerate(begin):
            # Partial slices are rejected only on split dimensions (and only when
            # the global size is known, i.e. not -1).
            if (shard_dim[i] != 1 and end[i] - start != shape[i]) and shape[i] != -1:
                raise ValueError(
                    f"Slice: When a dimension({i}) is not fully fetched, the dimension can not be split now, "
                    f"the begin is {begin}, the end is {end}, the shape is {shape}, layout is {layout.to_dict()}")
        return shard_dim

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for slice operator. The shard dim must be fully fetched.

        Args:
            layouts (tuple): A 1-tuple holding the Layout of input x.
            extra_args (Sequence): ``(begin, end, global_shape)``; only the first
                three elements are read.

        Returns:
            tuple: ``(layout, new_begin, new_end)`` where ``layout`` is the input
            layout reused for the output, and ``new_begin`` / ``new_end`` are the
            slice bounds divided by the per-dimension shard count.

        Raises:
            ValueError: Propagated from ``_check_layout`` when ``layouts`` is not a
                1-tuple or a split dimension is only partially fetched.
        """
        begin = extra_args[0]
        end = extra_args[1]
        global_shape = extra_args[2]
        shard_dim = self._check_layout(layouts, begin, end, global_shape)
        # For split dims, _check_layout has ensured end - begin equals the global
        # size (unless that size is -1), so dividing by the shard count yields
        # the local bounds; unsplit dims have a shard count of 1 and pass through.
        new_begin = tuple(start // shard_dim[i] for i, start in enumerate(begin))
        new_end = tuple(stop // shard_dim[i] for i, stop in enumerate(end))
        return layouts[0], new_begin, new_end