# Coverage report residue: hyper_parallel/core/shard/ops/parallel_gather.py
# 50% of 131 statements covered (coverage.py v7.13.1, created 2026-03-01 07:33 +0800)

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Gather operator.
"""

from hyper_parallel.core.layout import Layout
from .parallel_ops import DistributedOp


class IndexSelectDistributedOp(DistributedOp):
    """Distributed implementation for Index Select operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Index Select operations.

        Args:
            layouts: Layouts of input tensors; layouts[0] is the params layout
                and layouts[2] is the index layout.
            extra_args: extra_args of input tensors; extra_args[0] is the
                select axis (may be negative).

        Returns:
            Layout: Layout for the output tensor.

        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Check argument counts
        if len(layouts) != 3:
            raise ValueError(f"Gather ops requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 1:
            raise ValueError(f"Gather ops requires 1 extra args, but {len(extra_args)}")

        # Parse layout info
        p_layout, i_layout = layouts[0], layouts[2]
        axis, batch_dims = extra_args[0], 0

        p_tensor_map = p_layout.alias_tensor_map
        i_tensor_map = i_layout.alias_tensor_map

        # Validate the axis range BEFORE indexing p_tensor_map with it, so an
        # out-of-range axis raises the intended ValueError, not IndexError.
        if axis < -len(p_tensor_map) or axis >= len(p_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: dim value is out of valid range"
            )
        # Normalize a negative axis so the slicing below is correct:
        # p_tensor_map[axis + 1:] with axis == -1 would yield the whole map
        # instead of the empty tail.
        if axis < 0:
            axis += len(p_tensor_map)

        # The selected dimension of params must be replicated.
        if p_tensor_map[axis] != "None":
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the axis"
            )

        if len(i_tensor_map) != 1:
            raise ValueError(
                f"Operation {self.op_name}: index is not a one-dimensional Tensor"
            )

        # out map = params[:axis] + index[batch_dims:] + params[axis + 1:]
        output_tensor_map = (
            p_tensor_map[:axis] + i_tensor_map[batch_dims:] + p_tensor_map[axis + 1:]
        )
        # Build the output layout on the same device mesh as the index layout.
        output_layout = Layout(
            mesh_shape=i_layout.mesh_shape,
            alias_name=i_layout.alias_name,
            rank_list=i_layout.rank_list,
        )
        return output_layout(*output_tensor_map)

84 

85 

class GatherDistributedOp(DistributedOp):
    """Distributed implementation for Gather operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for Gather operations.

        Args:
            layouts: Layouts of input tensors; layouts[0] is the params layout
                and layouts[2] is the index layout.
            extra_args: extra_args of input tensors; extra_args[0] is the
                gather axis (may be negative).

        Returns:
            Layout: Layout for the output tensor.

        Raises:
            ValueError: If input layouts are not compatible or have partial status.
        """
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        # Check argument counts
        if len(layouts) != 3:
            raise ValueError(f"Gather ops requires 3 layouts, but {len(layouts)}")
        if len(extra_args) != 1:
            raise ValueError(f"Gather ops requires 1 extra args, but {len(extra_args)}")

        # Parse layout info
        p_layout, i_layout = layouts[0], layouts[2]
        axis = extra_args[0]

        p_tensor_map = p_layout.alias_tensor_map
        i_tensor_map = i_layout.alias_tensor_map

        # Validate the axis range BEFORE indexing p_tensor_map with it, so an
        # out-of-range axis raises the intended ValueError, not IndexError.
        if axis < -len(p_tensor_map) or axis >= len(p_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: dim value is out of valid range"
            )

        # The gathered dimension of params must be replicated.
        if p_tensor_map[axis] != "None":
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on params along the axis"
            )

        if len(p_tensor_map) != len(i_tensor_map):
            raise ValueError(
                f"Operation {self.op_name}: input and index must have the same number of dimensions"
            )

        # Gather output has the shape of index, so it inherits index sharding,
        # built on the same device mesh as the index layout.
        output_layout = Layout(
            mesh_shape=i_layout.mesh_shape,
            alias_name=i_layout.alias_name,
            rank_list=i_layout.rank_list,
        )
        return output_layout(*i_tensor_map)

145 

146 

class GatherNdDistributedOp(DistributedOp):
    """Distributed implementation for GatherNd operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for GatherNd.

        For GatherNd: out.shape = indices.shape[:-1] + input_x.shape[K:], where K = indices.shape[-1].

        This implementation:
        - Inherits sharding from indices[:-1].
        - Allows sharding on input_x trailing dims input_x[K:].
        - Requires input_x[:K] to be replicated ("None") if input_layout is provided.
        - Requires indices[-1] (K dim) to be replicated ("None").

        Output Layout:
            output_tensor_map = indices_tensor_map[:-1] + input_tensor_map[K:]
            If input_layout is None, input trailing dims are treated as replicated ("None").
        """
        input_layout, indices_layout = self._parse_input_layouts(layouts)

        # Static shapes come from extra_args (requires the WithShape variant,
        # see _get_input_shapes); K is the size of the last indices dim.
        input_shape, indices_shape = self._get_input_shapes(extra_args)
        k, trail_rank = self._get_k_and_trailing_rank(input_shape, indices_shape)

        input_tensor_map, indices_tensor_map = self._validate_tensor_maps(
            input_layout, indices_layout, k
        )

        # Output sharding: inherit indices[:-1] + input_x[K:].
        if input_tensor_map is None:
            # No input layout provided: trailing dims are treated as replicated.
            output_tensor_map = tuple(indices_tensor_map[:-1]) + ("None",) * trail_rank
        else:
            output_tensor_map = tuple(indices_tensor_map[:-1]) + tuple(input_tensor_map[k:])

        # Build the output layout on the same device mesh as indices.
        output_layout = Layout(
            mesh_shape=indices_layout.mesh_shape,
            alias_name=indices_layout.alias_name,
            rank_list=indices_layout.rank_list,
        )

        if output_tensor_map:
            output_layout = output_layout(*output_tensor_map)
        else:
            # Empty map (indices rank 1 and K == input rank): fully replicated.
            output_layout = output_layout("None")

        return output_layout

    def _parse_input_layouts(self, layouts):
        """Parse and validate input layouts.

        Returns:
            tuple: (input_layout, indices_layout); input_layout may be None,
            in which case input_x is treated as fully replicated downstream.

        Raises:
            ValueError: On fewer than 2 layouts, extra tensor layouts, or a
                missing/invalid indices layout.
        """
        if len(layouts) < 2:
            raise ValueError(
                f"Operation {self.op_name} requires at least 2 input layouts, but got {len(layouts)}"
            )

        input_layout, indices_layout = layouts[0], layouts[1]

        # Extra inputs are allowed only when they are non-tensor args (layout is None).
        for extra_layout in layouts[2:]:
            if extra_layout is not None:
                raise ValueError(
                    f"Operation {self.op_name} only supports 2 tensor inputs, but got extra tensor layout: "
                    f"{extra_layout}"
                )

        # For GatherNd: input_layout can be None (treated as fully replicated), but indices_layout must exist.
        if indices_layout is None or not hasattr(indices_layout, "alias_tensor_map"):
            raise ValueError(f"Operation {self.op_name}: Indices layout cannot be None")

        return input_layout, indices_layout

    def _validate_tensor_maps(self, input_layout, indices_layout, k):
        """Validate tensor maps constraints for GatherNd.

        Returns:
            tuple: (input_tensor_map or None, indices_tensor_map).

        Raises:
            ValueError: If indices[-1] is split, K exceeds input rank, or any
                indexed input dim [0:K) is split.
        """
        indices_tensor_map = indices_layout.alias_tensor_map

        # Validate: indices tensor_map must exist and last dimension cannot be split.
        if not indices_tensor_map:
            raise ValueError(f"Operation {self.op_name}: indices tensor_map cannot be empty")

        last_axis = indices_tensor_map[-1]
        if not self._is_none_axis(last_axis):
            raise ValueError(
                f"Operation {self.op_name}: The last dimension of indices cannot be split. "
                f"Got indices[-1] = {last_axis}"
            )

        # Validate input only when layout is provided.
        input_tensor_map = None
        if input_layout is not None:
            input_tensor_map = input_layout.alias_tensor_map

            if k > len(input_tensor_map):
                raise ValueError(
                    f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank "
                    f"({len(input_tensor_map)})"
                )

            # Indexed dims [0:K) must be replicated.
            for axis_name in input_tensor_map[:k]:
                if not self._is_none_axis(axis_name):
                    raise ValueError(
                        f"Operation {self.op_name}: input_x cannot be split on indexed dims [0:{k}). "
                        f"These dims must be 'None', but got tensor_map: {input_tensor_map}"
                    )

        return input_tensor_map, indices_tensor_map

    def _get_input_shapes(self, extra_args):
        """Get input and indices shapes from extra_args (WithShape suffix required).

        Expects extra_args[-1] to be a sequence of at least 2 shapes:
        (input_shape, indices_shape).

        Raises:
            ValueError: If shapes are missing, contain None, or indices shape
                is empty.
        """
        input_shapes = None
        # The WithShape infer-layout variant appends the shape tuple last.
        if extra_args and hasattr(extra_args[-1], "__len__") and len(extra_args[-1]) >= 2:
            input_shapes = extra_args[-1]

        if input_shapes is None:
            raise ValueError(
                f"Operation {self.op_name}: missing input_shapes in extra_args. "
                f"Please configure yaml with infer_layout_suffix: WithShape."
            )

        input_shape = input_shapes[0]
        indices_shape = input_shapes[1]
        if input_shape is None or indices_shape is None:
            raise ValueError(f"Operation {self.op_name}: input_shapes contains None: {input_shapes}")

        input_shape = self._normalize_shape(input_shape, "input")
        indices_shape = self._normalize_shape(indices_shape, "indices")

        # indices must have at least one dim, since K = indices_shape[-1].
        if len(indices_shape) < 1:
            raise ValueError(f"Operation {self.op_name}: indices shape invalid: {indices_shape}")

        return input_shape, indices_shape

    def _normalize_shape(self, shape, name):
        """Normalize shape-like object to tuple of int.

        Raises:
            ValueError: If shape is not iterable or has non-integer dims.
        """
        try:
            norm = tuple(shape)
        except TypeError as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape is not iterable: {shape}") from err

        # Coerce each dim to int so downstream arithmetic/comparisons are safe.
        try:
            norm = tuple(int(dim) for dim in norm)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: {name} shape contains non-integer dims: {norm}") from err

        return norm

    def _get_k_and_trailing_rank(self, input_shape, indices_shape):
        """Compute K and trailing rank = len(input_shape) - K, where K is indices_shape[-1].

        Raises:
            ValueError: If K is non-integer, non-positive, or exceeds input rank.
        """
        k = indices_shape[-1]
        try:
            k = int(k)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) is invalid: {k}") from err

        if k <= 0:
            raise ValueError(f"Operation {self.op_name}: indices last dim (K) must be positive, but got {k}")

        trail_rank = len(input_shape) - k
        if trail_rank < 0:
            raise ValueError(
                f"Operation {self.op_name}: indices last dim (K={k}) is larger than input rank ({len(input_shape)})"
            )

        return k, trail_rank

    def _is_none_axis(self, axis_name):
        """
        Check if an axis name represents no sharding.

        A dim may map to a tuple of mesh-axis names (multi-axis sharding);
        it counts as unsharded only when every name in the tuple is "None".
        """
        if axis_name == "None":
            return True

        if isinstance(axis_name, tuple):
            return all(name == "None" for name in axis_name)

        return False