Coverage for hyper_parallel / core / shard / ops / parallel_split.py: 52%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

Distributed implementation for Split operators.

17""" 

18 

19import math 

20from .parallel_ops import DistributedOp 

21 

22 

class SplitWithSizeDistributedOp(DistributedOp):
    """Distributed implementation for SplitWithSize operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitWithSize operator.

        Rules:
            1. Sharded axis can not be split.

        Args:
            layouts (tuple): Layouts of the inputs; ``layouts[0]`` is the layout of the
                tensor being split and must not be None.
            extra_args (list): split sections, axis, input shape. Only
                ``extra_args[0]`` (the split sections) and ``extra_args[1]`` (the axis,
                negative values allowed) are read here.

        Returns:
            tuple: One copy of the input layout per entry of the split sections.

        Raises:
            ValueError: If the input layout is missing, the axis is out of range, or
                the axis is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split_with_size requires a valid input tensor layout.")
        input_layout = layouts[0]
        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)

        # Normalize a negative axis and range-check it so an invalid axis is reported
        # as a ValueError (matching SplitDistributedOp) rather than an IndexError.
        axis = extra_args[1]
        if axis < 0:
            axis += input_dim
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # A tensor_map entry of -1 marks an unsharded axis; only those may be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        # Every output reuses the input layout unchanged.
        split_sections = extra_args[0]
        return (input_layout,) * len(split_sections)

52 

53 

class SplitWithSizeViewDistributedOp(DistributedOp):
    """Distributed implementation for SplitWithSizeView operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitWithSizeView operator.

        Rules:
            1. Sharded axis can not be split.

        Args:
            layouts (tuple): Layouts of the inputs; ``layouts[0]`` is the layout of the
                tensor being split and must not be None.
            extra_args (list): split sections, axis, input shape. Only
                ``extra_args[0]`` (the split sections) and ``extra_args[1]`` (the axis,
                negative values allowed) are read here.

        Returns:
            tuple: One copy of the input layout per entry of the split sections.

        Raises:
            ValueError: If the input layout is missing, the axis is out of range, or
                the axis is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split_with_size_view requires a valid input tensor layout.")
        input_layout = layouts[0]
        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)

        # Normalize a negative axis and range-check it so an invalid axis is reported
        # as a ValueError (matching SplitDistributedOp) rather than an IndexError.
        axis = extra_args[1]
        if axis < 0:
            axis += input_dim
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # A tensor_map entry of -1 marks an unsharded axis; only those may be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        # Every output view reuses the input layout unchanged.
        split_sections = extra_args[0]
        return (input_layout,) * len(split_sections)

83 

84 

class SplitDistributedOp(DistributedOp):
    """Distributed implementation for Split operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Split operator.

        Rules:
            1. Sharded axis can not be split.
            2. Default: dim = 0 if not specified.

        Args:
            layouts (tuple): Layouts of the inputs; ``layouts[0]`` is the layout of the
                tensor being split and must not be None.
            extra_args (list): split size or sections, axis, input shape. Expected:
                extra_args[0]: split_size (required), either an int chunk size or a
                    list/tuple of section sizes
                extra_args[1]: axis (optional, negative values allowed)
                extra_args[-1][0]: input_shape

        Returns:
            tuple: One copy of the input layout per output chunk.

        Raises:
            ValueError: If the input layout is missing, ``extra_args`` has an
                unexpected length, the axis is out of range or sharded, or an int
                ``split_size`` is not positive.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split requires a valid input tensor layout.")
        input_layout = layouts[0]

        if len(extra_args) == 2:
            split_size, shapes = extra_args
            axis = 0  # default
        elif len(extra_args) == 3:
            split_size, axis, shapes = extra_args
        else:
            raise ValueError(
                "Split ops extra_args requires 'split_size', optionally 'axis', and the input shapes, "
                f"but got {len(extra_args)} items."
            )
        input_shape = shapes[0]

        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)
        if axis < 0:
            axis += input_dim
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # A tensor_map entry of -1 marks an unsharded axis; only those may be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        if isinstance(split_size, int):
            # A zero split size would divide by zero, and a negative one would make the
            # chunk count negative (silently yielding no outputs); reject both.
            if split_size <= 0:
                raise ValueError(f"split_size must be a positive integer, but got {split_size}.")
            output_num = math.ceil(input_shape[axis] / split_size)
        elif isinstance(split_size, (list, tuple)):
            output_num = len(split_size)
        else:
            # Any other split_size type yields a single output layout.
            output_num = 1

        return (input_layout,) * output_num

141 

142 

class SplitTensorDistributedOp(DistributedOp):
    """Distributed implementation for SplitTensor operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitTensor operator.

        Rules:
            1. Sharded axis can not be split.

        Args:
            layouts (tuple): Layouts of the inputs; ``layouts[0]`` is the layout of the
                tensor being split and must not be None.
            extra_args (list): split size, axis, input shape. Expected:
                extra_args[0]: split_size, a positive int chunk size
                extra_args[1]: axis (negative values allowed)
                extra_args[2][0]: input_shape

        Returns:
            tuple: One copy of the input layout per chunk, i.e.
            ``ceil(input_shape[axis] / split_size)`` entries.

        Raises:
            ValueError: If the input layout is missing, the axis is out of range or
                sharded, or ``split_size`` is not positive.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split_tensor requires a valid input tensor layout.")
        input_layout = layouts[0]
        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)

        # Normalize a negative axis and range-check it so an invalid axis is reported
        # as a ValueError (matching SplitDistributedOp) rather than an IndexError.
        axis = extra_args[1]
        if axis < 0:
            axis += input_dim
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # A tensor_map entry of -1 marks an unsharded axis; only those may be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        split_size = extra_args[0]
        # Zero would divide by zero and a negative size would produce a negative
        # chunk count (silently yielding no outputs); reject both explicitly.
        if split_size <= 0:
            raise ValueError(f"split_size must be a positive integer, but got {split_size}.")
        input_shape = extra_args[2][0]

        # Full chunks, plus one trailing partial chunk when the axis size is not
        # divisible by split_size.
        output_num, remainder = divmod(input_shape[axis], split_size)
        if remainder:
            output_num += 1

        return (input_layout,) * output_num

176 

177 

class SplitTensorViewDistributedOp(DistributedOp):
    """Distributed implementation for SplitTensorView operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for SplitTensorView operator.

        Rules:
            1. Sharded axis can not be split.

        Args:
            layouts (tuple): Layouts of the inputs; ``layouts[0]`` is the layout of the
                tensor being split and must not be None.
            extra_args (list): split size, axis, input shape. Expected:
                extra_args[0]: split_size, a positive int chunk size
                extra_args[1]: axis (negative values allowed)
                extra_args[2][0]: input_shape

        Returns:
            tuple: One copy of the input layout per chunk, i.e.
            ``ceil(input_shape[axis] / split_size)`` entries.

        Raises:
            ValueError: If the input layout is missing, the axis is out of range or
                sharded, or ``split_size`` is not positive.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split_tensor_view requires a valid input tensor layout.")
        input_layout = layouts[0]
        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)

        # Normalize a negative axis and range-check it so an invalid axis is reported
        # as a ValueError (matching SplitDistributedOp) rather than an IndexError.
        axis = extra_args[1]
        if axis < 0:
            axis += input_dim
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # A tensor_map entry of -1 marks an unsharded axis; only those may be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        split_size = extra_args[0]
        # Zero would divide by zero and a negative size would produce a negative
        # chunk count (silently yielding no outputs); reject both explicitly.
        if split_size <= 0:
            raise ValueError(f"split_size must be a positive integer, but got {split_size}.")
        input_shape = extra_args[2][0]

        # Full chunks, plus one trailing partial chunk when the axis size is not
        # divisible by split_size.
        output_num, remainder = divmod(input_shape[axis], split_size)
        if remainder:
            output_num += 1

        return (input_layout,) * output_num