Coverage for hyper_parallel / platform / platform.py: 73%

259 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""framework platform api""" 

16import os 

17from datetime import timedelta 

18from enum import auto, Enum 

19from typing import Optional, Any 

20 

21import numpy as np 

22# Environment variable name used to specify the AI framework platform to use 

23HYPER_PARALLEL_PLATFORM = "HYPER_PARALLEL_PLATFORM" 

24 

25# Identifier for the MindSpore framework 

26HYPER_PARALLEL_PLATFORM_MINDSPORE = "mindspore" 

27 

28# Identifier for the PyTorch framework 

29HYPER_PARALLEL_PLATFORM_TORCH = "torch" 

30 

31 

class PlatformType(Enum):
    """Supported AI framework platform kinds.

    Each member identifies one deep-learning framework backend that the
    hyper-parallel platform layer can run on top of.
    """

    MINDSPORE = auto()
    PYTORCH = auto()

39 

40 

# Global platform instance, used to cache the created platform object.
# Written by get_mindspore_platform()/get_torch_platform() and read by
# get_platform() so the backend is constructed only once per process.
# NOTE(review): no locking around this global — presumably platform
# selection happens before any worker threads start; confirm.
platform = None

43 

44 

def get_mindspore_platform():
    """Build the MindSpore platform, cache it globally, and return it."""
    global platform
    # Import lazily so environments without MindSpore never load it.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.platform import MindSporePlatform
    platform = MindSporePlatform()
    return platform

52 

53 

def get_torch_platform():
    """Build the PyTorch platform, cache it globally, and return it."""
    global platform
    # Import lazily so environments without torch never load it.
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.platform import TorchPlatform
    platform = TorchPlatform()
    return platform

61 

62 

def get_platform():
    """Obtain a framework platform instance.

    Returns the appropriate AI framework platform instance based on
    environment variables or a default priority order. The lookup
    priority is as follows:

    1. Platform specified by the ``HYPER_PARALLEL_PLATFORM`` environment
       variable (matched case-insensitively; an unrecognized value is
       ignored and falls through to auto-detection).
    2. MindSpore platform (default preferred choice).
    3. PyTorch platform (fallback option).

    Returns:
        Platform: An instance of the framework platform.

    Raises:
        ImportError: Raised when none of the supported frameworks are
            available.
    """
    # Reuse the cached singleton when a platform was already created.
    if platform is not None:
        return platform
    # os.environ.get returns either None or str, so no isinstance check
    # is needed before calling lower().
    platform_type = os.environ.get(HYPER_PARALLEL_PLATFORM)
    if platform_type is not None:
        platform_type = platform_type.lower()
        if platform_type == HYPER_PARALLEL_PLATFORM_MINDSPORE:
            return get_mindspore_platform()
        if platform_type == HYPER_PARALLEL_PLATFORM_TORCH:
            return get_torch_platform()
    # Auto-detect: prefer MindSpore, fall back to torch if it is absent.
    try:
        return get_mindspore_platform()
    except ImportError:
        return get_torch_platform()

91 

92 

# Registry of communication groups already created, keyed by the explicit
# group name (when provided) or by the hash of the member-rank tuple, so
# create_group() can hand back cached group handles.
EXISTING_COMM_GROUPS = {}


class Platform:
    """Platform api.

    Abstract interface that concrete framework backends (MindSpore /
    PyTorch platforms) must implement. Most methods are stubs that raise
    ``NotImplementedError``; the gradient-handle bookkeeping, communication
    group caching, and hook-registration helpers below carry shared
    behavior.
    """

    # Pending asynchronous gradient-reduce handle (class-level, shared).
    current_grad_handle = None
    # Optional callable invoked after the pending handle completes.
    post_grad_handle_process = None
    # Lazily created stream used to wait on gradient communication.
    grad_sync_stream = None

    # ---- process / topology queries ----

    @staticmethod
    def get_rank():
        raise NotImplementedError("Platform subclasses must implement get_rank")

    @staticmethod
    def get_global_rank(group, group_rank):
        raise NotImplementedError("Platform subclasses must implement get_global_rank")

    @staticmethod
    def get_world_size():
        raise NotImplementedError("Platform subclasses must implement get_world_size")

    @staticmethod
    def get_op_name(func):
        raise NotImplementedError("Platform subclasses must implement get_op_name")

    # ---- differentiable collective communication ----

    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        raise NotImplementedError("Platform subclasses must implement differentiable_all_gather_concat")

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        raise NotImplementedError("Platform subclasses must implement chunk")

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all")

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        raise NotImplementedError("Platform subclasses must implement tensor_type_cast")

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        raise NotImplementedError("Platform subclasses must implement differentiable_all_reduce")

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        raise NotImplementedError("Platform subclasses must implement differentiable_reduce_scatter")

    @staticmethod
    def init_parameters(module, stage_index):
        """platform ms need init parameter interface.

        Validates the arguments; backends override this to actually
        initialize parameters for the given pipeline stage.

        Raises:
            ValueError: If ``module`` is None or ``stage_index`` is negative.
        """
        if module is None:
            raise ValueError("input module must not be none.")
        if stage_index < 0:
            # stage_index == 0 (first pipeline stage) is accepted, so the
            # requirement is "non-negative" (the old message said
            # "positive", contradicting the check).
            raise ValueError("input stage_index must be non-negative.")

    # ---- cell / parameter introspection ----

    @staticmethod
    def get_cell_construct(cell):
        raise NotImplementedError("Platform subclasses must implement get_cell_construct")

    @staticmethod
    def get_cells_and_names(cell):
        raise NotImplementedError("Platform subclasses must implement get_cells_and_names")

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        raise NotImplementedError("Platform subclasses must implement search_parameter_by_name")

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        raise NotImplementedError("Platform subclasses must implement update_parameter_by_name")

    @staticmethod
    def set_layout_into_parameter(param, layout):
        raise NotImplementedError("Platform subclasses must implement set_layout_into_parameter")

    @staticmethod
    def get_param_local_shape(param):
        raise NotImplementedError("Platform subclasses must implement get_param_local_shape")

    @staticmethod
    def get_param_local_data(param):
        raise NotImplementedError("Platform subclasses must implement get_param_local_data")

    @staticmethod
    def update_param_data(param, data):
        raise NotImplementedError("Platform subclasses must implement update_param_data")

    @staticmethod
    def get_param_type_size(param):
        raise NotImplementedError("Platform subclasses must implement get_param_type_size")

    # ---- tensor construction ----

    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        raise NotImplementedError("Platform subclasses must implement new_zero_parameter")

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        raise NotImplementedError("Platform subclasses must implement new_tensor")

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        raise NotImplementedError("Platform subclasses must implement full_like")

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        raise NotImplementedError("Platform subclasses must implement set_tensor_requires_grad")

    # ---- raw collective communication ----

    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        raise NotImplementedError("Platform subclasses must implement all_gather_into_tensor")

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        raise NotImplementedError("Platform subclasses must implement all_reduce")

    @staticmethod
    def broadcast(data, src, group, async_op=False):
        raise NotImplementedError("Platform subclasses must implement broadcast")

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        raise NotImplementedError("Platform subclasses must implement isend")

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        raise NotImplementedError("Platform subclasses must implement irecv")

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        raise NotImplementedError("Platform subclasses must implement send_object_list")

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        # Fixed copy-paste bug: message previously named send_object_list.
        raise NotImplementedError("Platform subclasses must implement recv_object_list")

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        raise NotImplementedError("Platform subclasses must implement reduce_scatter_tensor")

    # ---- checkpointing (state dict level) ----

    @staticmethod
    def parameters_dict(cell):
        raise NotImplementedError("Platform subclasses must implement parameters_dict")

    @staticmethod
    def save_checkpoint(cell, file_path: str) -> None:
        raise NotImplementedError("Platform subclasses must implement save_checkpoint")

    @staticmethod
    def load_checkpoint(file_path: str) -> dict:
        raise NotImplementedError("Platform subclasses must implement load_checkpoint")

    # ---- backend primitives used by the shared logic below ----

    def _create_group(self, rank_list, group_name=None):
        raise NotImplementedError("Platform subclasses must implement _create_group")

    def new_stream(self):
        raise NotImplementedError("Platform subclasses must implement new_stream")

    def get_stream_context(self):
        raise NotImplementedError("Platform subclasses must implement get_stream_context")

    @staticmethod
    def get_tensor_transform():
        raise NotImplementedError("Platform subclasses must implement get_tensor_transform")

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        raise NotImplementedError("Platform subclasses must implement construct_strided_slice")

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        raise NotImplementedError("Platform subclasses must implement micro_batch")

    # ---- shared behavior: group caching and grad-handle bookkeeping ----

    def create_group(self, rank_list, group_name=None):
        """create comm group with rank list.

        Caches groups in EXISTING_COMM_GROUPS; an explicit group_name is
        used as the cache key, otherwise the hash of the rank tuple.
        """
        if group_name is None:
            group_key = hash(tuple(rank_list))
        else:
            group_key = group_name
        if group_key in EXISTING_COMM_GROUPS:
            return EXISTING_COMM_GROUPS[group_key]

        group = self._create_group(rank_list, group_name)
        EXISTING_COMM_GROUPS[group_key] = group
        return group

    def _process_current_handle(self):
        """wait current handle"""
        if Platform.current_grad_handle is None:
            return

        Platform.current_grad_handle.wait()
        if Platform.post_grad_handle_process is None:
            return
        # pylint: disable=E1102
        Platform.post_grad_handle_process()

    def set_grad_reduce_handle(self, handle, post_process=None):
        """wait current handle and set new handle"""
        # Lazily create the dedicated sync stream on first use.
        if Platform.grad_sync_stream is None:
            Platform.grad_sync_stream = self.new_stream()
        stream_context = self.get_stream_context()
        with stream_context(Platform.grad_sync_stream):
            self._process_current_handle()
        Platform.current_grad_handle = handle
        Platform.post_grad_handle_process = post_process

    def wait_grad_handle(self):
        """wait grad handle"""
        if Platform.current_grad_handle is None:
            return
        if Platform.grad_sync_stream is None:
            Platform.grad_sync_stream = self.new_stream()
        stream_context = self.get_stream_context()
        with stream_context(Platform.grad_sync_stream):
            self._process_current_handle()
            sync_event = Platform.grad_sync_stream.record_event()
        # Block the current stream until the sync stream's work is done.
        sync_event.wait()
        Platform.current_grad_handle = None
        Platform.post_grad_handle_process = None

    # ---- process-group lifecycle ----

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Aggregates all Python objects objs in a specified communication group into object_list.
        """
        raise NotImplementedError("Platform subclasses must implement all_gather_object")

    @staticmethod
    def init_process_group(
            backend: Optional[str] = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: Any = None,
            pg_options: Any = None,
            device_id: Any = None
    ) -> None:
        """
        Initialize the default distributed process group.

        Args:
            backend: The backend to use for distributed communication
            init_method: URL specifying how to initialize the process group
            timeout: Timeout for operations executed against the process group
            world_size: Number of processes participating in the job
            rank: Rank of the current process
            store: Key/value store for exchanging connection information
            pg_options: Process group options for backend-specific configurations
            device_id: Specific device this process will work on

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement init_process_group")

    @staticmethod
    def destroy_process_group(group=None) -> None:
        """
        Destroy a given process group.

        Args:
            group: The process group to be destroyed. If None, destroys the default group.

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement destroy_process_group")

    @staticmethod
    def get_process_group_ranks(group=None) -> list[int]:
        """
        Get rank list of the given process group.

        Args:
            group: The process group to get ranks from. If None, uses the default group.

        Returns:
            List of ranks in the specified process group.

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement get_process_group_ranks")

    @staticmethod
    def get_backend(group=None):
        """
        Get the backend of the given process group.

        Args:
            group: The process group to get backend from. If None, uses the default group.

        Returns:
            The backend name of the specified process group.

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement get_backend")

    @staticmethod
    def split_group(parent_pg: Any = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[Any] = None,
                    group_desc: Optional[str] = None,
                    ) -> Any:
        """
        Create split group relative to the parent process group.
        """
        raise NotImplementedError("Platform subclasses must implement split_group")

    # ---- autograd / stream / container helpers ----

    @staticmethod
    def no_grad():
        raise NotImplementedError("Platform subclasses must implement no_grad")

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        raise NotImplementedError("Platform subclasses must implement empty_like")

    def get_current_stream(self):
        raise NotImplementedError("Platform subclasses must implement get_current_stream")

    def new_event(self):
        raise NotImplementedError("Platform subclasses must implement new_event")

    def tree_map(self, fn, tree):
        raise NotImplementedError("Platform subclasses must implement tree_map")

    # ---- module hooks (shared: delegate to the module's own API) ----

    @staticmethod
    def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
        return module.register_forward_pre_hook(hook, prepend=prepend, with_kwargs=with_kwargs)

    @staticmethod
    def register_full_backward_hook(module, hook, prepend=False):
        return module.register_full_backward_hook(hook, prepend)

    @staticmethod
    def register_full_backward_pre_hook(module, hook, prepend=False):
        return module.register_full_backward_pre_hook(hook, prepend)

    # ---- activation checkpointing ----

    @property
    def checkpoint(self):
        raise NotImplementedError("Platform subclasses must implement checkpoint")

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        raise NotImplementedError("Platform subclasses must implement ckpt_wrapper")

    @property
    def noop_context_fn(self):
        # Fixed copy-paste bug: message previously named ckpt_wrapper.
        raise NotImplementedError("Platform subclasses must implement noop_context_fn")

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        raise NotImplementedError("Platform subclasses must implement create_selective_checkpoint_contexts")

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        raise NotImplementedError("Platform subclasses must implement async_save_on_cpu")

    # ---- tensor conversion helpers ----

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        raise NotImplementedError("Platform subclasses must implement tensor_to_numpy")

    def cast_fp_tensor(self, dtype, x):
        """
        Cast floating-point tensor to target dtype if applicable.
        """
        raise NotImplementedError("Platform subclasses must implement cast_fp_tensor")

    def apply_to_tensors(self, fn, container):
        """Recursively apply to all tensor in different kinds of container types."""
        raise NotImplementedError("Platform subclasses must implement apply_to_tensors")