Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_repeat.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Repeat operator. 

17""" 

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class RepeatDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.repeat."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.Tensor.repeat.

        PyTorch semantics:
            - Repeats this tensor along the specified dimensions.
            - If the number of repeat dimensions is larger than the tensor dimensions,
              the tensor is implicitly unsqueezed at the front.
            - The number of repeat dimensions cannot be smaller than the tensor dimensions.
            - Dimensions being repeated (>1 or 0) MUST be unsharded.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple/list): The repeat sizes, either spread over the
                sequence (``(2, 1, 3)``) or packed as a single tuple/list
                (``((2, 1, 3),)``).

        Returns:
            Layout: Output tensor layout with:
                - New prepended dimensions: unsharded (-1)
                - Repeated existing dimensions (size != 1): unsharded (-1)
                - Preserved existing dimensions (size == 1): original sharding preserved

        Raises:
            ValueError: If the input layout or the repeat sizes are missing, if a
                repeat size is negative, if fewer repeat sizes than input
                dimensions are given, or if a sharded dimension would be repeated.
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: repeat requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        repeats = self._parse_repeat_sizes(extra_args)
        output_ndim = len(repeats)

        # Without this check a negative offset below would index `repeats` from
        # the end and either crash or silently produce a layout of the wrong rank.
        if output_ndim < input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: number of repeat sizes ({output_ndim}) cannot be "
                f"smaller than the number of tensor dimensions ({input_ndim})."
            )
        num_new_dims = output_ndim - input_ndim

        # Rule 1: dimensions prepended by the implicit unsqueeze are always unsharded.
        output_map = [-1] * num_new_dims

        # Rule 2: existing dimensions line up with the trailing repeat sizes.
        trailing_repeats = repeats[num_new_dims:]
        for i, (dim_map, repeat_times) in enumerate(zip(in_tensor_map, trailing_repeats)):
            if repeat_times == 1:
                # Not repeated: keep the original sharding of this dimension.
                output_map.append(dim_map)
                continue
            # Repeated (or zeroed) dimensions must not be sharded on input.
            if dim_map != -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot repeat dimension {i} which is sharded. "
                    "Please redistribute (unshard) the tensor along this dimension first."
                )
            # Repeated dimension remains unsharded in output.
            output_map.append(-1)

        alias_name = input_layout.alias_name
        output_alias_map = tuple(
            self._index_to_alias(idx, alias_name) for idx in output_map
        )

        # Build a layout on the same device mesh, then apply the inferred mapping.
        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=alias_name,
            rank_list=input_layout.rank_list,
        )
        return output_layout(*output_alias_map)

    def _parse_repeat_sizes(self, extra_args):
        """Normalize ``extra_args`` into a tuple of non-negative int repeat sizes.

        Raises:
            ValueError: If no repeat sizes are given or any size is negative.
        """
        if not extra_args:
            raise ValueError(
                f"Operation {self.op_name}: repeat requires repeat sizes in extra_args."
            )
        # Accept both repeat(2, 3)-style (sizes spread over extra_args) and
        # repeat((2, 3))-style (a single packed tuple/list) arguments.
        if len(extra_args) == 1 and isinstance(extra_args[0], (tuple, list)):
            sizes = extra_args[0]
        else:
            sizes = extra_args
        repeats = tuple(int(size) for size in sizes)
        if any(size < 0 for size in repeats):
            raise ValueError(
                f"Operation {self.op_name}: repeat sizes must be non-negative, got {repeats}."
            )
        return repeats

    @staticmethod
    def _index_to_alias(idx, aliases):
        """Convert a tensor_map entry back to the alias form passed to a Layout call.

        ``-1`` maps to ``"None"``; a mesh index ``k`` maps to
        ``aliases[len(aliases) - k - 1]``. Tuple/list entries are converted
        element-wise. NOTE(review): this assumes that Layout accepts a tuple of
        aliases for a dimension mapped to several mesh axes; confirm against Layout.
        """
        def single(axis):
            return "None" if axis == -1 else aliases[len(aliases) - axis - 1]

        if isinstance(idx, (tuple, list)):
            return tuple(single(axis) for axis in idx)
        return single(idx)