Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_repeat_interleave.py: 85%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15 

16""" 

17Distributed implementation for RepeatInterleave operator. 

18""" 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class RepeatInterleaveDistributedOp(DistributedOp):
    """Distributed implementation for torch.repeat_interleave."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for RepeatInterleave operator.

        RepeatInterleave: output = repeat_interleave(input, repeats, dim)

        Rules:
            1. ``dim`` is None when it is not specified (or passed as None). The
               input is then flattened, so at most the first input dimension may
               be sharded; the 1-D output keeps that dimension's sharding.
            2. When ``dim`` is given, that dimension MUST be unsharded to keep the
               global repeat_interleave result correct.
            3. When ``dim`` is given, the sharding pattern is unchanged (only the
               size along ``dim`` changes), so the input layout is returned.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple, optional): Contains repeats and dim. Expected:
                extra_args[0] (int or Tensor): Number of repeats, or a tensor of repeats.
                extra_args[1] (int, optional): Dimension to repeat along. Defaults
                    to None, meaning the flattened input is used.
                May be omitted or None entirely, which is treated like an empty tuple.

        Returns:
            Layout: For ``dim=None``, a new 1-D layout on the same mesh as the
            input. Otherwise, the input layout object itself.

        Raises:
            ValueError: If the input layout is missing; if ``dim`` is None and a
                dimension other than the first is sharded; if ``dim`` is out of
                range; or if the chosen ``dim`` is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("repeat_interleave requires a valid input tensor layout.")

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map

        # The signature defaults extra_args to None; treat it like "no extra args"
        # instead of failing on len(None).
        extra_args = extra_args or ()
        dim = extra_args[1] if len(extra_args) >= 2 else None

        if dim is None:
            return self._infer_flatten_layout(input_layout)

        input_dim = len(in_tensor_map)
        # Normalise negative dims, but keep the caller's value for the error message.
        norm_dim = dim + input_dim if dim < 0 else dim
        if not 0 <= norm_dim < input_dim:
            raise ValueError(
                f"Dimension out of range (expected to be in [{-input_dim}, {input_dim}), "
                f"but got {dim})."
            )

        # Repeating along a sharded dim would interleave only the local slice,
        # which differs from the global result.
        if in_tensor_map[norm_dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the chosen dim"
            )

        # Shape changes along `dim` do not affect the sharding pattern.
        return input_layout

    def _infer_flatten_layout(self, input_layout):
        """Build the 1-D output layout for ``dim=None`` (flattened input)."""
        in_tensor_map = input_layout.tensor_map
        sharded_dims = [i for i, shard in enumerate(in_tensor_map) if shard != -1]

        if not sharded_dims:
            output_tensor_map = (-1,)
        elif sharded_dims == [0]:
            # With row-major flattening, each rank's block of dim-0 rows is a
            # contiguous range of the flat output, so that sharding carries over.
            # NOTE(review): this reasoning holds for an int `repeats`; a per-element
            # repeats tensor would itself need matching sharding -- confirm upstream.
            output_tensor_map = (in_tensor_map[0],)
        else:
            # Sharding on any other dim would scramble the flattened order.
            raise ValueError(
                f"Operation {self.op_name}: Cannot flatten tensor when dim=None."
            )

        aliases = input_layout.alias_name
        output_map = tuple(self._idx_to_alias(idx, aliases) for idx in output_tensor_map)

        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list,
        )
        return output_layout(*output_map)

    @staticmethod
    def _idx_to_alias(idx, aliases):
        """Map a tensor_map entry to its mesh alias name; -1 maps to "None"."""
        if idx == -1:
            return "None"
        # tensor_map indices count mesh axes from the last alias backwards.
        return aliases[len(aliases) - idx - 1]