Coverage for hyper_parallel / core / shard / ops / parallel_one_hot_ext.py: 12%

169 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for OneHotExt operator. 

17""" 

18 

19# pylint: disable=import-outside-toplevel 

20from hyper_parallel.core.layout import Layout 

21from hyper_parallel.core.placement_types import Shard, Replicate 

22from hyper_parallel.platform import get_platform 

23from .parallel_ops import DistributedOp 

24 

# Platform abstraction providing collective-communication primitives
# (used below for differentiable all-reduce across device groups).
# NOTE(review): this module-level name shadows the stdlib `platform` module;
# renaming would touch all uses in this file — flagging only.
platform = get_platform()

26 

27 

class OneHotExtDistributedOp(DistributedOp):
    """Distributed implementation for the OneHotExt operator.

    Responsibilities:
      * ``infer_layout``: derive the output sharding layout from the indices
        layout by inserting a replicated one-hot dimension at ``axis``.
      * ``get_expand_impl``: wrap the original operator so that when
        ``num_classes == -1`` (auto depth), the depth is computed as the
        global maximum index + 1 across all sharded device axes.
    """

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for OneHotExt.

        Args:
            layouts (tuple): Tuple containing input layouts; element 0 is the
                indices layout, the rest (if any) must be fully replicated.
            extra_args (tuple): Additional arguments containing
                [num_classes, on_value, off_value, axis].

        Returns:
            Layout: Output layout with the one-hot dimension inserted at the
                specified axis (always unsharded/replicated on that new dim),
                or None when no input layouts are given.

        Raises:
            ValueError: If the indices layout is None/partial/empty, or any
                validation of num_classes/axis/sharding restrictions fails.
        """
        if not layouts:
            return None

        indices_layout = layouts[0]
        if indices_layout is None or indices_layout.mesh_shape is None:
            raise ValueError(f"{self.op_name}: indices layout cannot be None")

        # A partial (pending-reduce) indices tensor would hold incomplete
        # values, which cannot be interpreted as class indices.
        if indices_layout.is_partial():
            raise ValueError(
                f"{self.op_name}: indices cannot be in partial state. "
                f"Indices must contain complete index values for OneHot operation."
            )

        num_classes = self._get_num_classes(extra_args)
        self._validate_num_classes(num_classes)

        axis = self._get_axis(extra_args)

        in_tensor_map = indices_layout.tensor_map
        if not in_tensor_map:
            raise ValueError(f"{self.op_name}: indices tensor_map is empty")

        self._validate_multi_dim_restriction(in_tensor_map, axis, indices_layout)
        self._validate_inputs_layouts(layouts)

        # Output keeps the indices sharding and gains one replicated dim
        # (mapped to -1) at the insert position.
        out_tensor_map = self._infer_output_tensor_map(in_tensor_map, axis)
        out_layout = self._create_layout_from_tensor_map(indices_layout, out_tensor_map)

        out_placements = self._tensor_map_to_placements(indices_layout, out_tensor_map)
        out_layout.set_placements(out_placements)

        return out_layout

    def get_expand_impl(self, func, output_layout, layouts, extra_args):
        """Get expanded implementation for OneHotExt operator.

        Returns a wrapper callable that resolves ``num_classes == -1`` to a
        globally consistent depth across sharded devices, or None when no
        expansion is needed (no indices layout, or nothing is sharded).
        """
        # Imported lazily so this module can load without MindSpore present
        # (matches the file-level pylint: disable=import-outside-toplevel).
        import mindspore as ms
        from mindspore import ops, Tensor

        del output_layout  # unused; expansion depends only on input layouts

        indices_layout = layouts[0] if layouts else None
        if indices_layout is None:
            return None

        sharded_axes = self._get_sharded_axes(indices_layout)
        if not sharded_axes:
            # Fully replicated indices: the local max already equals the
            # global max, so the original op needs no wrapping.
            return None

        original_op = func
        reduce_max = ops.ReduceMax(keep_dims=False)

        def expanded_one_hot(indices, num_classes, on_value, off_value, axis):
            # Same call signature as the original op; validates then either
            # delegates directly or computes a global depth first.
            self._validate_num_classes(num_classes)
            self._validate_indices_dtype(indices)

            if num_classes != -1:
                # Explicit depth: no cross-device agreement needed.
                return original_op(indices, num_classes, on_value, off_value, axis)

            # Auto depth: reduce over all dims to get this shard's max index.
            local_max = reduce_max(indices, ())
            if not isinstance(local_max, Tensor):
                local_max = Tensor(local_max, ms.int64)

            # Host-side check; .asnumpy() forces a device->host sync here.
            local_max_host = int(local_max.asnumpy())
            if local_max_host > 2147483647:  # INT32_MAX: all-reduce below runs in int32
                raise ValueError(
                    f"{self.op_name}: indices max value {local_max_host} exceeds int32 range"
                )

            zero_dim = local_max.ndim == 0
            local_max_i32 = ops.cast(local_max, ms.int32)

            # NOTE(review): presumably the all-reduce requires a rank>=1
            # tensor, hence the expand/squeeze around it — confirm against
            # platform.differentiable_all_reduce's contract.
            if zero_dim:
                local_max_i32 = ops.expand_dims(local_max_i32, 0)

            # Max-reduce over every sharded device axis so all ranks agree
            # on the same depth.
            global_max_i32 = local_max_i32
            for axis_name in sharded_axes:
                group = indices_layout.get_comm_group_by_axis(axis_name)
                global_max_i32 = platform.differentiable_all_reduce(
                    global_max_i32, "max", group
                )

            if zero_dim:
                global_max_i32 = ops.squeeze(global_max_i32, 0)

            # Depth = global max index + 1 (indices assumed non-negative —
            # TODO confirm; negative indices are not validated here).
            depth = int(global_max_i32.asnumpy()) + 1
            return original_op(indices, depth, on_value, off_value, axis)

        return expanded_one_hot

    def _get_num_classes(self, extra_args):
        """Extract num_classes from extra arguments.

        Returns extra_args[0] when it is an int; otherwise -1 (auto depth).
        """
        if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 1:
            num_classes = extra_args[0]
            if isinstance(num_classes, int):
                return num_classes
        return -1

    def _validate_num_classes(self, num_classes):
        """Validate num_classes parameter.

        Raises:
            TypeError: If num_classes is not an int.
            ValueError: If num_classes < -1 (-1 means auto-infer depth).
        """
        if not isinstance(num_classes, int):
            raise TypeError(
                f"{self.op_name}: num_classes must be int, but got {type(num_classes).__name__}"
            )
        if num_classes < -1:
            raise ValueError(
                f"{self.op_name}: num_classes must be >= -1, but got {num_classes}"
            )

    def _validate_indices_dtype(self, indices):
        """Validate indices dtype (must be int64).

        Raises:
            TypeError: If indices.dtype is not mindspore.int64.
        """
        import mindspore as ms

        if indices.dtype != ms.int64:
            raise TypeError(
                f"{self.op_name}: indices dtype must be int64, but got {indices.dtype}"
            )

    def _get_sharded_axes(self, layout):
        """Get all device axes that are used for sharding.

        Walks the alias tensor map (entries are an axis name, a tuple of
        axis names, or the string "None" for an unsharded dim) and returns
        the distinct axis names, as a list in unspecified order.
        """
        sharded_axes = set()

        if layout is None or layout.alias_tensor_map is None:
            return []

        for dim_alias in layout.alias_tensor_map:
            if dim_alias == "None":
                continue

            if isinstance(dim_alias, tuple):
                # A dim sharded over multiple mesh axes lists them as a tuple.
                for axis_name in dim_alias:
                    if axis_name != "None":
                        sharded_axes.add(axis_name)
            else:
                sharded_axes.add(dim_alias)

        return list(sharded_axes)

    def _get_axis(self, extra_args):
        """Extract axis parameter from extra arguments.

        Returns the validated axis when extra_args[3] is an int; otherwise
        -1 (append the one-hot dim last).
        NOTE(review): a non-int axis value is silently mapped to -1 here,
        while _validate_axis would raise TypeError — the TypeError branch is
        unreachable via this path; confirm which behavior is intended.
        """
        if isinstance(extra_args, (list, tuple)) and len(extra_args) >= 4:
            axis = extra_args[3]
            if isinstance(axis, int):
                return self._validate_axis(axis)
        return -1

    def _validate_axis(self, axis):
        """Validate axis parameter (must be an int in [-1, 1])."""
        if not isinstance(axis, int):
            raise TypeError(
                f"{self.op_name}: axis must be int, but got {type(axis).__name__}"
            )

        if axis > 1 or axis < -1:
            raise ValueError(f"{self.op_name}: axis {axis} is out of range[-1, 1]")

        return axis

    def _validate_multi_dim_restriction(self, in_tensor_map, axis, indices_layout):
        """Validate restriction for multi-dimensional inputs.

        For rank > 1 indices: axis must be -1 and only the leading (batch)
        dimension may be sharded (pure data parallelism).
        """
        in_rank = len(in_tensor_map)
        if in_rank <= 1:
            return

        if axis != -1:
            raise ValueError(
                f"{self.op_name}: when input dimension is > 1, axis must be -1, but got {axis}"
            )

        # Dimensions after the first must be unsharded ("None" alias).
        alias_map = indices_layout.alias_tensor_map
        for i in range(1, len(alias_map)):
            if alias_map[i] != "None":
                raise ValueError(
                    f"{self.op_name}: when input dimension is > 1, strategy must be data parallel, "
                    f"but dimension {i} is sharded on '{alias_map[i]}'"
                )

    def _validate_inputs_layouts(self, layouts):
        """Validate that non-indices inputs are fully replicated.

        on_value/off_value are scalars shared by every shard, so any
        sharding on them would be inconsistent.
        """
        for layout in layouts[1:]:
            if layout is None:
                continue
            alias_map = layout.alias_tensor_map
            if alias_map and any(x != "None" for x in alias_map):
                raise ValueError(
                    f"{self.op_name}: non-indices inputs must be replicated, but got {alias_map}"
                )

    def _infer_output_tensor_map(self, in_tensor_map, axis):
        """Infer output tensor map by inserting the one-hot dimension.

        The new dimension is always mapped to -1 (replicated). axis == -1
        or axis == in_rank both mean "append last".
        """
        in_rank = len(in_tensor_map)

        if axis in (-1, in_rank):
            insert_pos = in_rank
        else:
            insert_pos = axis

        # Defensive re-check; _validate_axis already bounds axis to [-1, 1].
        if insert_pos < 0 or insert_pos > in_rank:
            raise ValueError(
                f"{self.op_name}: axis {axis} is out of range for input with rank {in_rank}"
            )

        out_tensor_map = list(in_tensor_map)
        out_tensor_map.insert(insert_pos, -1)
        return tuple(out_tensor_map)

    def _create_layout_from_tensor_map(self, base_layout, out_tensor_map):
        """Create output layout from tensor map.

        Clones the mesh description (shape, axis aliases, rank list) from
        base_layout and installs the new tensor map plus its alias form.
        """
        out_layout = Layout(
            mesh_shape=base_layout.mesh_shape,
            alias_name=base_layout.alias_name,
            rank_list=base_layout.rank_list,
        )

        out_layout.set_tensor_map(out_tensor_map)
        out_layout.set_alias_tensor_map(
            self._tensor_map_to_alias_tensor_map(base_layout, out_tensor_map)
        )
        out_layout.update_compact_str()
        return out_layout

    def _tensor_map_to_alias_tensor_map(self, base_layout, tensor_map):
        """Convert numeric tensor map to alias tensor map.

        A tensor_map entry d indexes mesh axes from the END of alias_name
        (alias_name[len - 1 - d]); -1 means unsharded and becomes "None";
        tuple entries (multi-axis sharding) map element-wise.
        """
        alias_name = base_layout.alias_name
        alias_tensor_map = []

        for dim in tensor_map:
            if dim == -1:
                alias_tensor_map.append("None")
                continue

            if isinstance(dim, tuple):
                names = tuple(
                    alias_name[len(alias_name) - 1 - d] for d in dim if d != -1
                )
                # An all -1 tuple degenerates to an unsharded dim.
                alias_tensor_map.append(names if names else "None")
                continue

            alias_tensor_map.append(alias_name[len(alias_name) - 1 - dim])

        return tuple(alias_tensor_map)

    def _tensor_map_to_placements(self, base_layout, tensor_map):
        """
        Convert tensor_map to placements.

        Args:
            base_layout: Base layout to get mesh dimension info
            tensor_map: Tensor map to convert

        Returns:
            tuple: Placements tuple (Shard/Replicate for each mesh dimension)

        NOTE(review): this matches tensor_map values DIRECTLY against the
        mesh dimension index, whereas _tensor_map_to_alias_tensor_map treats
        values as indexing mesh axes from the end (len-1-d). Verify against
        the Layout/placement contract that these two conventions are meant
        to differ.
        """
        mesh_ndim = len(base_layout.mesh_shape)
        placements = []

        # For each mesh dim, find the (first) tensor dim sharded on it.
        for mesh_dim_idx in range(mesh_ndim):
            is_sharded = False

            for tensor_dim_idx, tensor_dim_map in enumerate(tensor_map):
                if tensor_dim_map == -1:
                    continue

                if isinstance(tensor_dim_map, tuple):
                    if mesh_dim_idx in tensor_dim_map:
                        placements.append(Shard(tensor_dim_idx))
                        is_sharded = True
                        break
                elif tensor_dim_map == mesh_dim_idx:
                    placements.append(Shard(tensor_dim_idx))
                    is_sharded = True
                    break

            if not is_sharded:
                placements.append(Replicate())

        return tuple(placements)