Coverage for hyper_parallel/core/shard/ops/parallel_repeat_interleave.py: 85%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15 

16""" 

17Distributed implementation for RepeatInterleave operator. 

18""" 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

class RepeatInterleaveDistributedOp(DistributedOp):
    """Distributed implementation for torch.repeat_interleave."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for the RepeatInterleave operator.

        RepeatInterleave: output = repeat_interleave(input, repeats, dim)

        Rules:
            1. ``dim`` defaults to None when not specified, matching
               torch.repeat_interleave: the input is flattened and a
               1-D output layout is returned.
            2. The dimension ``dim`` MUST be unsharded to ensure global
               repeat_interleave correctness.
            3. Otherwise the output layout is the same as the input
               layout; only the shape changes.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Contains repeats and dim. Expected:
                extra_args[0] (int or Tensor): Number of repeats or Tensor.
                extra_args[1] (int, optional): Dimension to repeat along.
                    Defaults to None (flatten the input).

        Returns:
            Layout: The inferred output layout.

        Raises:
            ValueError: If the input layout is missing, ``dim`` is out of
                range, flattening a multi-dim-sharded tensor is requested,
                or the chosen dimension is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("repeat_interleave requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        # The dimension along which to repeat values. By default, use the
        # flattened input array, and return a flat output array.
        dim = None
        # extra_args is documented as optional, so guard against None before len().
        if extra_args is not None and len(extra_args) >= 2 and extra_args[1] is not None:
            dim = extra_args[1]

        if dim is None:
            return self._infer_flattened_layout(input_layout, in_tensor_map)

        input_dim = len(in_tensor_map)
        # Normalize a negative dim (e.g. -1 means the last dimension).
        if dim < 0:
            dim = input_dim + dim
        # Check if dimension is within valid range.
        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected to be in [0, {input_dim}), but got {dim}).")
        # The chosen dim must NOT be sharded.
        if in_tensor_map[dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )
        # Output layout same as input layout (shape change does not affect
        # the sharding pattern).
        return input_layout

    def _infer_flattened_layout(self, input_layout, in_tensor_map):
        """Build the 1-D output layout for the ``dim is None`` (flatten) case.

        Flattening is only legal when the input is fully replicated or
        sharded solely on its first dimension; any other sharding would
        interleave flattened elements across devices.
        """
        sharded_dims = [i for i, shard in enumerate(in_tensor_map) if shard != -1]
        if not sharded_dims:
            # Fully replicated input -> replicated 1-D output.
            output_tensor_map = [-1]
        elif sharded_dims == [0]:
            # Only the first dimension is sharded, so the flattened tensor
            # keeps that sharding on its single dimension.
            # (sharded_dims == [0] already implies in_tensor_map[0] != -1.)
            output_tensor_map = [in_tensor_map[0]]
        else:
            # Other dims must NOT be sharded.
            raise ValueError(
                f"Operation {self.op_name}: Cannot flatten tensor when dim=None."
            )

        def idx_to_alias(idx, aliases):
            # Map a mesh-dimension index back to its alias name; -1 means
            # unsharded, spelled as the string "None" in Layout calls.
            if idx == -1:
                return "None"
            # NOTE(review): tensor_map values appear to index alias_name from
            # the end — confirm against Layout's convention.
            return aliases[len(aliases) - idx - 1]

        output_map = tuple(
            idx_to_alias(idx, input_layout.alias_name) for idx in output_tensor_map
        )
        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        return output_layout(*output_map)