Coverage for hyper_parallel / core / shard / ops / parallel_reduce.py: 85%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Reduce operator. 

17""" 

18 

19from copy import deepcopy 

20from typing import Sequence, Union, Tuple, List 

21from hyper_parallel.core.layout import Layout 

22from hyper_parallel.platform import get_platform 

23from .parallel_ops import DistributedOp 

# Resolve the active backend platform once at import time; the exposed
# Tensor is the platform's native tensor class (used only for isinstance
# checks on the `dim` argument below).
platform = get_platform()
Tensor = platform.Tensor


# Recursive alias for one entry of a layout tensor map: either a device-axis
# alias string (e.g. "dp", or the literal "None" for an unsharded dim) or a
# nested tuple/list of such entries.
StrOrTuple = Union[str, Tuple["StrOrTuple", ...], List["StrOrTuple"]]

29 

30 

class ReduceExtDistributedOpBase(DistributedOp):
    """
    Base class for distributed reduce operators.

    Infers the output sharding layout of a reduce-style operator from the
    input layout and the (dim, keepdim) arguments, and marks the device axes
    that sharded the reduced dimensions as "partial" so a collective can
    later combine the per-device results.

    Args:
        op_name (str): Name of the operator to register.
        partial_type (list): List of the operator for allreduce.
    """

    def __init__(self, op_name, partial_type=None):
        super().__init__(op_name)
        # None default avoids sharing one mutable list across instances;
        # "sum" is the default partial-combine type.
        if partial_type is None:
            partial_type = ["sum"]
        self.partial_type = partial_type

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for reduce operator.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (dict): Additional arguments (dim, keepdim).

        Returns:
            tuple: Layout for output tensor.

        Raises:
            ValueError: If no layout is given or the layout has no mesh shape.
            TypeError: If `dim` is passed as a Tensor.
        """
        if not layouts:
            raise ValueError(f"{self.__class__.__name__} requires at least one input layout")

        x_layout = layouts[0]

        if x_layout.mesh_shape is None:
            raise ValueError("Input layouts cannot be None.")

        # extra_args is expected to hold [dim, keepdim]; both optional.
        if not extra_args:
            dim = None
            keepdim = False
        elif len(extra_args) == 1:
            # NOTE(review): a single extra argument is interpreted as keepdim,
            # not dim — confirm this matches the caller's argument packing.
            dim = None
            keepdim = extra_args[0]
        else:
            dim, keepdim = extra_args

        if isinstance(dim, Tensor):
            raise TypeError(
                "The `dim` argument should not be a `Tensor`. Instead, use one of the following types: "
                "`None`, `int`, `tuple[int]`, or `list[int]`."
            )

        # Infer the output shape based on dim and keepdim
        output_layout = self._infer_output_layout(x_layout, dim, keepdim)

        return output_layout

    def _infer_output_layout(self, x_layout, dim, keepdim):
        """Infer output layout for reduce operator."""
        # Case 1: dim is None, meaning reduce all dimensions.
        if dim is None:
            return self._handle_all_axis_reduce(x_layout, keepdim)

        # Case 2: Handle dim as int, tuple, or list, with keepdim True or False
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        reduce_alias, x_map = self.replace_axis_with_none(dim, x_layout, keepdim)
        output_layout = output_layout(*x_map)
        self._apply_partial(output_layout, reduce_alias)
        return output_layout

    def _handle_all_axis_reduce(self, x_layout, keepdim):
        """Handle the case where dim is None, meaning reduce all dimensions."""
        layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )

        if not keepdim:
            # Full reduction without keepdim: scalar output, empty tensor map.
            output_layout = layout()
        else:
            # keepdim: every dim survives (size 1), so none remains sharded.
            tensor_map = tuple(["None"] * len(x_layout.alias_tensor_map))
            output_layout = layout(*tensor_map)

        # Every device axis that sharded the input now holds a partial result.
        self._apply_partial(output_layout, x_layout.alias_tensor_map)
        return output_layout

    def replace_axis_with_none(self, dim, x_layout, keepdim):
        """Replace or drop dimensions depending on keepdim.

        Returns:
            tuple: (reduce_alias, new_tensor_map) where reduce_alias lists the
            device-axis aliases that sharded the reduced dimensions.
        """
        if not isinstance(dim, (tuple, list)):
            dim = [dim]
        else:
            dim = list(dim)

        rank = len(x_layout.alias_tensor_map)
        for i, axis_id in enumerate(dim):
            # Validate the type first: comparing a non-int with `<` would
            # raise an unintended TypeError before the ValueError below.
            if not isinstance(axis_id, int):
                raise ValueError(f"Invalid reduce axis index {axis_id} at position {i}.")
            if axis_id < 0:
                # Normalize negative axis indices in place.
                dim[i] = rank + axis_id
            if dim[i] >= rank or dim[i] < 0:
                raise ValueError(f"Invalid reduce axis index {axis_id} at position {i}.")

        alias_tensor_map = x_layout.alias_tensor_map
        # Collect aliases sharding the reduced dims; skip unsharded entries
        # (both Python None and the literal "None" mean "not sharded").
        reduce_alias = [alias_tensor_map[axis_id] for axis_id in dim if
                        alias_tensor_map[axis_id] is not None and alias_tensor_map[axis_id] != "None"]
        reduce_alias = self._flatten_aliases(reduce_alias)

        if keepdim:
            return self._replace_keepdim(alias_tensor_map, reduce_alias)
        return self._replace_dropdim(alias_tensor_map, reduce_alias, dim)

    def _flatten_aliases(self, reduce_alias):
        """Flatten reduce_alias into a list of atomic alias strings."""
        flat = []
        for alias in reduce_alias:
            if isinstance(alias, (tuple, list)):
                flat.extend(alias)
            else:
                flat.append(alias)
        return flat

    def _replace_keepdim(self, alias_tensor_map, reduce_alias):
        """keepdim=True: replace each reduced alias with 'None' (dim kept, size 1)."""
        new_alias_map = []
        for alias in alias_tensor_map:
            if isinstance(alias, (tuple, list)):
                # Replace only the reduced members of a composite alias.
                new_alias = tuple("None" if item in reduce_alias else item for item in alias)
                new_alias_map.append(new_alias)
            else:
                if alias in reduce_alias:
                    new_alias_map.append("None")
                else:
                    new_alias_map.append(alias)
        new_alias_map = self._compact_tensor_map(new_alias_map)
        return reduce_alias, tuple(new_alias_map)

    def _replace_dropdim(self, alias_tensor_map, reduce_alias, dim):
        """keepdim=False: drop reduced dims; strip their aliases from the rest."""
        new_alias_map = []
        for i, alias in enumerate(alias_tensor_map):
            if i in dim:
                # The reduced dimension itself disappears from the output.
                continue
            if isinstance(alias, (tuple, list)):
                new_alias = tuple(item for item in alias if item not in reduce_alias)
                if new_alias:
                    new_alias_map.append(new_alias)
            else:
                if alias in reduce_alias:
                    continue
                new_alias_map.append(alias)
        new_alias_map = self._compact_tensor_map(new_alias_map)
        return reduce_alias, tuple(new_alias_map)

    def _compact_tensor_map(self, alias_map: Sequence[StrOrTuple]) -> Tuple[StrOrTuple, ...]:
        """Compact nested entries: unwrap singleton tuples; merge all-'None' tuples."""

        def _compress(elem: StrOrTuple) -> StrOrTuple:
            if isinstance(elem, (list, tuple)):
                compressed = tuple(_compress(e) for e in elem)
                if len(compressed) == 1:
                    return compressed[0]
                if all(x == 'None' for x in compressed):
                    return 'None'
                return compressed
            return elem

        return tuple(_compress(elem) for elem in alias_map)

    def _apply_partial(self, out_layout, alias):
        """Apply all partial to given alias (string, tuple, list)."""
        # Both Python None and the string "None" denote an unsharded dim
        # (matching the filter in replace_axis_with_none); neither carries
        # a partial result, so skip them.
        if alias is None or alias == "None":
            return
        if isinstance(alias, (tuple, list)):
            for elem in alias:
                self._apply_partial(out_layout, elem)
        else:
            for ops in self.partial_type:
                out_layout.set_partial_by_dev_axis(alias, ops)

212 

class SumExtDistributedOp(ReduceExtDistributedOpBase):
    """Distributed SumExt operator: partial results combine with "sum"."""

    def __init__(self, op_name="SumExt"):
        super().__init__(op_name=op_name, partial_type=["sum"])

218 

219 

class MeanExtDistributedOp(ReduceExtDistributedOpBase):
    """Distributed MeanExt operator: partial results combine with "avg"."""

    def __init__(self, op_name="MeanExt"):
        super().__init__(op_name=op_name, partial_type=["avg"])

225 

226 

class ReduceMaxDistributedOp(ReduceExtDistributedOpBase):
    """Distributed ReduceMax operator: partial results combine with "max"."""

    def __init__(self, op_name="ReduceMax"):
        super().__init__(op_name=op_name, partial_type=["max"])

232 

class ProdExtDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed ProdExt operator (product of all elements or along a dim).

    Argument handling is compatible with torch.prod; partial results on
    sharded device axes combine with "prod".
    """

    def __init__(self, op_name="prod"):
        super().__init__(op_name=op_name, partial_type=["prod"])

241 

class AllExtDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation for the All operator.

    Reduces the input along the given dimension(s); partial results on
    sharded device axes combine with the "all" reduce type. (The previous
    docstring described a cumulative sum, which belongs to cumsum, not all.)
    """

    def __init__(self, op_name="all"):
        super().__init__(op_name, partial_type=["all"])

250 

class MaxDistributedOp(ReduceExtDistributedOpBase):
    """
    Distributed implementation for Pytorch style Max operator.

    Supports three Pytorch behaviors:
        1. torch.max(input) -> Global reduction (returns single Tensor)
        2. torch.max(input, dim, keepdim=False) -> Dimension reduction (returns (values, indices))
        3. torch.max(input, other) -> Element-wise max (returns single Tensor)
    """

    def __init__(self, op_name="max"):
        super().__init__(op_name, partial_type=["max"])

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layouts for torch.max.

        Args:
            layouts (tuple): Layouts of the tensor inputs; entries may be None
                for non-tensor positional arguments (dim, keepdim).
            extra_args (sequence): [dim] or [dim, keepdim], possibly empty.

        Returns:
            Layout or tuple: a single Layout for the global / element-wise
            forms, or (values_layout, indices_layout) for a dim reduction.

        Raises:
            ValueError: If no valid layout is given or the layout has no mesh.
            TypeError: If `dim` is passed as a Tensor.
        """
        # Filter out None layouts (corresponding to non-tensor args like dim, keepdim).
        # Renamed loop variable from the ambiguous `l` (PEP 8 / E741).
        valid_layouts = [layout for layout in layouts if layout is not None]

        if not valid_layouts:
            raise ValueError("MaxDistributedOp requires at least one input layout")

        # Case 1: Element-wise max (e.g., torch.max(a, b))
        if len(valid_layouts) > 1:
            # Element-wise max returns a single tensor, so return a single Layout object.
            # NOTE(review): this takes the first input's layout as-is — confirm
            # that mismatched shardings between the two inputs are reconciled upstream.
            return valid_layouts[0]

        # Case 2 & 3: Reduction max
        x_layout = valid_layouts[0]
        if x_layout.mesh_shape is None:
            raise ValueError("Input layouts cannot be None.")

        dim = None
        keepdim = False

        if extra_args:
            dim = extra_args[0]
            if len(extra_args) > 1:
                keepdim = extra_args[1]

        if isinstance(dim, Tensor):
            raise TypeError(
                "The `dim` argument should not be a `Tensor`. Instead, use one of the following types: "
                "`None`, `int`, `tuple[int]`, or `list[int]`."
            )

        values_layout = self._infer_output_layout(x_layout, dim, keepdim)

        if dim is None:
            # torch.max(input) -> Single Tensor
            # OpDispatcher logic:
            #   if isinstance(py_output, tuple): ...
            #   else: DTensor.from_local(py_output, output_layout.mesh, ...)
            # So here output_layout MUST be a Layout object, not a tuple.
            return values_layout

        # torch.max(input, dim) -> (values, indices)
        # OpDispatcher logic expects a tuple of layouts; indices share the
        # values' sharding, so hand back an independent copy.
        indices_layout = deepcopy(values_layout)
        return (values_layout, indices_layout)