Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_split.py: 73%

110 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Split operator. 

17""" 

18 

19import math 

20from .parallel_ops import DistributedOp 

21 

22 

class SplitWithSizeDistributedOp(DistributedOp):
    """Distributed implementation for SplitWithSize operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for SplitWithSize operator.

        Rules:
            1. Sharded axis can not be split.

        Args:
            layouts (list): Layouts of the input tensors; layouts[0] is the
                layout of the tensor being split.
            extra_args (list): Non-tensor arguments. Expected:
                extra_args[0]: split sections (sequence of section sizes)
                extra_args[1]: axis (split dimension)

        Returns:
            tuple: Layouts for output tensors, one per section; each output
            keeps the input layout since the split axis is unsharded.

        Raises:
            ValueError: If the input layout is missing, or the split axis is
                sharded across devices.
        """
        # Validate the input layout up front (consistent with SplitDistributedOp).
        if not layouts or layouts[0] is None:
            raise ValueError("split_with_size requires a valid input tensor layout.")
        input_layout = layouts[0]
        axis = extra_args[1]

        # A tensor-map entry other than -1 marks a sharded (distributed)
        # dimension, which must not be split.
        tensor_map = input_layout.tensor_map
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        # One output per requested section, all sharing the input layout.
        split_sections = extra_args[0]
        output_num = len(split_sections)
        output_layouts = (input_layout,) * output_num
        return output_layouts

52 

53 

class SplitWithSizeViewDistributedOp(DistributedOp):
    """Distributed implementation for SplitWithSizeView operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for SplitWithSizeView operator.

        Rules:
            1. Shared axis can not be split.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape

        Returns:
            tuple: Layouts for output tensors
        """
        in_layout = layouts[0]
        split_axis = extra_args[1]

        # A tensor-map entry other than -1 marks a sharded dimension; such an
        # axis cannot be split.
        if in_layout.tensor_map[split_axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{split_axis}], layout: {in_layout}")

        # Every section of the view keeps the (unsharded-on-axis) input layout.
        sections = extra_args[0]
        return tuple(in_layout for _ in range(len(sections)))

83 

84 

class SplitDistributedOp(DistributedOp):
    """Distributed implementation for Split operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for Split operator.

        Rules:
            1. Shared axis can not be split.
            2. Default: dim = 0 if not specified.

        Args:
            layouts (list): Layouts of the input tensors; layouts[0] is the
                layout of the tensor being split.
            extra_args (list): Non-tensor arguments. Expected:
                extra_args[0]: split_size (required; int chunk size, or a
                    list/tuple of section sizes)
                extra_args[1]: axis (optional; defaults to 0 when absent)
                extra_args[-1][0]: input_shape

        Returns:
            tuple: Layouts for output tensors; each output keeps the input
            layout since the split axis is unsharded.

        Raises:
            ValueError: If the input layout is missing, extra_args has an
                unexpected arity, the axis is out of range, or the split axis
                is sharded.
            TypeError: If split_size is not an int, list, or tuple.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("split requires a valid input tensor layout.")
        input_layout = layouts[0]

        # The dispatcher passes either (split_size, shapes) or
        # (split_size, axis, shapes).
        if len(extra_args) == 2:
            split_size = extra_args[0]
            axis = 0  # default dim
            input_shape = extra_args[1][0]
        elif len(extra_args) == 3:
            split_size = extra_args[0]
            axis = extra_args[1]
            input_shape = extra_args[2][0]
        else:
            # Message corrected: split_size is required; axis is the optional one.
            raise ValueError("Split ops extra_args requires 'split_size' and contains 'axis' optionally.")

        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)
        # Normalize a negative axis into [0, input_dim).
        if axis < 0:
            axis = input_dim + axis
        if not 0 <= axis < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {axis}).")

        # A tensor-map entry other than -1 marks a sharded dimension, which
        # must not be split.
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        if isinstance(split_size, int):
            # Chunk size: the last chunk may be smaller, hence ceil.
            output_num = math.ceil(input_shape[axis] / split_size)
        elif isinstance(split_size, (list, tuple)):
            # Explicit section sizes: one output per section.
            output_num = len(split_size)
        else:
            # Previously fell through silently with output_num = 1; raise
            # explicitly, consistent with TensorSplitDistributedOp.
            raise TypeError("Split: split_size must be an integer, list, or tuple.")

        output_layouts = (input_layout,) * output_num
        return output_layouts

141 

142 

class SplitTensorDistributedOp(DistributedOp):
    """Distributed implementation for SplitTensor operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for SplitTensor operator.

        Rules:
            1. Sharded axis can not be split.

        Args:
            layouts (list): Layouts of the input tensors; layouts[0] is the
                layout of the tensor being split.
            extra_args (list): Non-tensor arguments. Expected:
                extra_args[0]: split_size (int chunk size)
                extra_args[1]: axis (split dimension)
                extra_args[2][0]: input_shape

        Returns:
            tuple: Layouts for output tensors; each output keeps the input
            layout since the split axis is unsharded.

        Raises:
            ValueError: If the split axis is sharded across devices.
        """
        input_layout = layouts[0]
        axis = extra_args[1]

        # A tensor-map entry other than -1 marks a sharded dimension, which
        # must not be split.
        tensor_map = input_layout.tensor_map
        if tensor_map[axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}")

        split_size = extra_args[0]
        input_shape = extra_args[2][0]
        # Ceiling division: the last chunk may be smaller than split_size.
        # (Same computation SplitDistributedOp uses; replaces the manual
        # floor-div + remainder bump.)
        output_num = math.ceil(input_shape[axis] / split_size)

        output_layouts = (input_layout,) * output_num
        return output_layouts

176 

177 

class SplitTensorViewDistributedOp(DistributedOp):
    """Distributed implementation for SplitTensorView operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for SplitTensorView operator.

        Rules:
            1. Shared axis can not be split.

        Args:
            layouts (Layout): Layout of input tensor
            extra_args (list): split size or sections, axis, input shape

        Returns:
            tuple: Layouts for output tensors
        """
        in_layout = layouts[0]
        split_axis = extra_args[1]

        # Refuse to split along a device-sharded dimension (tensor-map entry != -1).
        if in_layout.tensor_map[split_axis] != -1:
            raise ValueError(f"Can not split tensor at sharded axis[{split_axis}], layout: {in_layout}")

        chunk = extra_args[0]
        dim_len = extra_args[2][0][split_axis]
        # Ceiling division via negated floor division: the trailing chunk may
        # be smaller than `chunk`.
        num_outputs = -(-dim_len // chunk)

        return (in_layout,) * num_outputs

211 

212 

class TensorSplitDistributedOp(DistributedOp):
    """Distributed implementation for tensor_split operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for tensor_split operator.

        Rules:
            1. Shared (sharded) axis cannot be split.
            2. Default: dim = 0 if not specified.

        Args:
            layouts (list): Layout of the input tensor.
            extra_args (list): Extracted non-tensor arguments and input shapes. Expected:
                extra_args[0]: indices_or_sections (int, tuple, list, or 1D tensor)
                extra_args[1]: dim (optional, default is 0)
                extra_args[-1]: input_shapes (list containing the shape of the input tensor)

        Returns:
            tuple: Layouts for the output tensors

        Raises:
            ValueError: If the input layout is missing, extra_args has an
                unexpected arity, dim is out of range, an integer section
                count is not positive, or the split axis is sharded.
            TypeError: If indices_or_sections is not an int, list, tuple,
                or 1D tensor.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("tensor_split requires a valid input tensor layout.")

        input_layout = layouts[0]

        # Parse extra_args based on the dispatcher's WithShape suffix rules
        if len(extra_args) == 1:
            indices_or_sections = extra_args[0]
            dim = 0  # default
        elif len(extra_args) == 2:
            indices_or_sections = extra_args[0]
            dim = extra_args[1]
        else:
            raise ValueError("tensor_split ops extra_args requires 'indices_or_sections' and optionally 'dim'.")

        tensor_map = input_layout.tensor_map
        input_dim = len(tensor_map)

        # Handle negative dimensions
        if dim < 0:
            dim = input_dim + dim

        if not 0 <= dim < input_dim:
            raise ValueError(f"Dimension out of range (expected [0, {input_dim}), got {dim}).")

        # Check: shared (sharded) axis cannot be split
        if tensor_map[dim] != -1:
            raise ValueError(f"Cannot perform tensor_split on sharded axis[{dim}], layout: {input_layout}")

        # Calculate the number of output tensors based on PyTorch's tensor_split rules
        if isinstance(indices_or_sections, int):
            # torch.tensor_split requires a positive section count; previously
            # 0 or negative values silently produced an invalid layout tuple.
            if indices_or_sections <= 0:
                raise ValueError(
                    f"tensor_split expected a positive number of sections, but got {indices_or_sections}.")
            output_num = indices_or_sections
        elif isinstance(indices_or_sections, (list, tuple)):
            # k split indices produce k + 1 output tensors.
            output_num = len(indices_or_sections) + 1
        elif hasattr(indices_or_sections, "shape") and len(indices_or_sections.shape) == 1:
            # Handle 1D Tensor case
            output_num = indices_or_sections.shape[0] + 1
        else:
            raise TypeError("tensor_split: indices_or_sections must be an integer, list, tuple, or 1D tensor.")

        output_layouts = (input_layout,) * output_num
        return output_layouts