Coverage for hyper_parallel / platform / platform.py: 73%
259 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""framework platform api"""
16import os
17from datetime import timedelta
18from enum import auto, Enum
19from typing import Optional, Any
21import numpy as np
# Name of the environment variable that selects the AI framework platform.
HYPER_PARALLEL_PLATFORM = "HYPER_PARALLEL_PLATFORM"

# Value of HYPER_PARALLEL_PLATFORM that selects the MindSpore framework.
HYPER_PARALLEL_PLATFORM_MINDSPORE = "mindspore"

# Value of HYPER_PARALLEL_PLATFORM that selects the PyTorch framework.
HYPER_PARALLEL_PLATFORM_TORCH = "torch"
class PlatformType(Enum):
    """Supported deep-learning framework platform types.

    Identifies which AI framework backend a platform instance targets.
    """

    MINDSPORE = 1
    PYTORCH = 2
# Global platform instance used to cache the created platform object.
# Populated lazily by get_mindspore_platform()/get_torch_platform() and
# read back by get_platform().
platform = None
def get_mindspore_platform():
    """Instantiate the MindSpore platform and cache it in the module global."""
    global platform
    # Imported lazily so MindSpore is only required when actually selected.
    # pylint: disable=C0415
    from hyper_parallel.platform.mindspore.platform import MindSporePlatform
    platform = MindSporePlatform()
    return platform
def get_torch_platform():
    """Instantiate the PyTorch platform and cache it in the module global."""
    global platform
    # Imported lazily so PyTorch is only required when actually selected.
    # pylint: disable=C0415
    from hyper_parallel.platform.torch.platform import TorchPlatform
    platform = TorchPlatform()
    return platform
def get_platform():
    """Obtain a framework platform instance.

    Returns the appropriate AI framework platform instance based on environment
    variables or a default priority order. The lookup priority is as follows:

    1. Platform specified by the HYPER_PARALLEL_PLATFORM environment variable
    2. MindSpore platform (default preferred choice)
    3. PyTorch platform (fallback option)

    Returns:
        Platform: An instance of the framework platform

    Raises:
        ImportError: Raised when none of the supported frameworks are available
    """
    if platform is not None:
        # Already created (and cached) by a previous call.
        return platform
    platform_type = os.environ.get(HYPER_PARALLEL_PLATFORM)
    # os.environ values are always str, so a None test is sufficient
    # (the previous isinstance(platform_type, str) check was redundant).
    if platform_type is not None:
        platform_type = platform_type.lower()
        if platform_type == HYPER_PARALLEL_PLATFORM_MINDSPORE:
            return get_mindspore_platform()
        if platform_type == HYPER_PARALLEL_PLATFORM_TORCH:
            return get_torch_platform()
        # Unrecognized values fall through to auto-detection below.
    try:
        return get_mindspore_platform()
    except ImportError:
        return get_torch_platform()
# Cache of already-created communication groups, keyed by explicit group name
# or by the member rank list; consulted by Platform.create_group().
EXISTING_COMM_GROUPS = {}
class Platform:
    """Framework-agnostic platform API.

    Concrete subclasses (e.g. the MindSpore and PyTorch platforms) implement
    the abstract methods, which all raise NotImplementedError here.  Gradient
    synchronization state is stored on the class itself so that it is shared
    across platform instances.
    """

    # Pending asynchronous gradient-reduce handle; consumed by wait_grad_handle().
    current_grad_handle = None
    # Optional callback invoked after current_grad_handle completes.
    post_grad_handle_process = None
    # Side stream used to overlap gradient synchronization with compute.
    grad_sync_stream = None

    # ------------------------------------------------------------------
    # Topology queries (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def get_rank():
        raise NotImplementedError("Platform subclasses must implement get_rank")

    @staticmethod
    def get_global_rank(group, group_rank):
        raise NotImplementedError("Platform subclasses must implement get_global_rank")

    @staticmethod
    def get_world_size():
        raise NotImplementedError("Platform subclasses must implement get_world_size")

    @staticmethod
    def get_op_name(func):
        raise NotImplementedError("Platform subclasses must implement get_op_name")

    # ------------------------------------------------------------------
    # Differentiable collective operations (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def differentiable_all_gather_concat(data, group, concat_size, concat_dim):
        raise NotImplementedError("Platform subclasses must implement differentiable_all_gather_concat")

    @staticmethod
    def chunk(data, split_dim, split_size, index):
        raise NotImplementedError("Platform subclasses must implement chunk")

    @staticmethod
    def differentiable_all_to_all(input_data, output_shape, group):
        raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all")

    @staticmethod
    def tensor_type_cast(input_data, cast_type):
        raise NotImplementedError("Platform subclasses must implement tensor_type_cast")

    @staticmethod
    def differentiable_all_reduce(data, op, group):
        raise NotImplementedError("Platform subclasses must implement differentiable_all_reduce")

    @staticmethod
    def differentiable_reduce_scatter(data, dev_num, axis, op, group):
        raise NotImplementedError("Platform subclasses must implement differentiable_reduce_scatter")

    @staticmethod
    def init_parameters(module, stage_index):
        """Validate arguments for parameter initialization.

        The MindSpore platform needs a real parameter-init hook; the base
        implementation only validates its inputs.

        Raises:
            ValueError: If ``module`` is None or ``stage_index`` is negative.
        """
        if module is None:
            raise ValueError("input module must not be none.")
        if stage_index < 0:
            # stage 0 is a valid first pipeline stage, so only reject
            # genuinely negative indices (the old message said "positive",
            # which contradicted the check).
            raise ValueError("input stage_index must be non-negative.")

    # ------------------------------------------------------------------
    # Cell / parameter introspection (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def get_cell_construct(cell):
        raise NotImplementedError("Platform subclasses must implement get_cell_construct")

    @staticmethod
    def get_cells_and_names(cell):
        raise NotImplementedError("Platform subclasses must implement get_cells_and_names")

    @staticmethod
    def search_parameter_by_name(cell, param_name: str):
        raise NotImplementedError("Platform subclasses must implement search_parameter_by_name")

    @staticmethod
    def update_parameter_by_name(cell, result: tuple, new_param) -> bool:
        raise NotImplementedError("Platform subclasses must implement update_parameter_by_name")

    @staticmethod
    def set_layout_into_parameter(param, layout):
        raise NotImplementedError("Platform subclasses must implement set_layout_into_parameter")

    @staticmethod
    def get_param_local_shape(param):
        raise NotImplementedError("Platform subclasses must implement get_param_local_shape")

    @staticmethod
    def get_param_local_data(param):
        raise NotImplementedError("Platform subclasses must implement get_param_local_data")

    @staticmethod
    def update_param_data(param, data):
        raise NotImplementedError("Platform subclasses must implement update_param_data")

    @staticmethod
    def get_param_type_size(param):
        raise NotImplementedError("Platform subclasses must implement get_param_type_size")

    # ------------------------------------------------------------------
    # Tensor construction (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def new_zero_parameter(param_shape, param_type, requires_grad, device):
        raise NotImplementedError("Platform subclasses must implement new_zero_parameter")

    @staticmethod
    def new_tensor(tensor_shape, tensor_type, device):
        raise NotImplementedError("Platform subclasses must implement new_tensor")

    @staticmethod
    def full_like(tensor, fill_value, dtype=None):
        raise NotImplementedError("Platform subclasses must implement full_like")

    @staticmethod
    def set_tensor_requires_grad(input_tensor):
        raise NotImplementedError("Platform subclasses must implement set_tensor_requires_grad")

    # ------------------------------------------------------------------
    # Raw communication primitives (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def all_gather_into_tensor(data, group_info, async_op=False):
        raise NotImplementedError("Platform subclasses must implement all_gather_into_tensor")

    @staticmethod
    def all_reduce(data, group_info, async_op=False):
        raise NotImplementedError("Platform subclasses must implement all_reduce")

    @staticmethod
    def broadcast(data, src, group, async_op=False):
        raise NotImplementedError("Platform subclasses must implement broadcast")

    @staticmethod
    def isend(tensor, dst=None, group=None, tag=0):
        raise NotImplementedError("Platform subclasses must implement isend")

    @staticmethod
    def irecv(tensor, src=None, group=None, tag=0):
        raise NotImplementedError("Platform subclasses must implement irecv")

    @staticmethod
    def send_object_list(obj_list, dst=None, group=None):
        raise NotImplementedError("Platform subclasses must implement send_object_list")

    @staticmethod
    def recv_object_list(obj_list, src=None, group=None):
        # Fixed: the message previously named send_object_list.
        raise NotImplementedError("Platform subclasses must implement recv_object_list")

    @staticmethod
    def reduce_scatter_tensor(data, group_info, async_op=False):
        raise NotImplementedError("Platform subclasses must implement reduce_scatter_tensor")

    # ------------------------------------------------------------------
    # Checkpointing (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def parameters_dict(cell):
        raise NotImplementedError("Platform subclasses must implement parameters_dict")

    @staticmethod
    def save_checkpoint(cell, file_path: str) -> None:
        raise NotImplementedError("Platform subclasses must implement save_checkpoint")

    @staticmethod
    def load_checkpoint(file_path: str) -> dict:
        raise NotImplementedError("Platform subclasses must implement load_checkpoint")

    # ------------------------------------------------------------------
    # Streams and group creation (abstract hooks used by the methods below).
    # ------------------------------------------------------------------
    def _create_group(self, rank_list, group_name=None):
        raise NotImplementedError("Platform subclasses must implement _create_group")

    def new_stream(self):
        raise NotImplementedError("Platform subclasses must implement new_stream")

    def get_stream_context(self):
        raise NotImplementedError("Platform subclasses must implement get_stream_context")

    @staticmethod
    def get_tensor_transform():
        raise NotImplementedError("Platform subclasses must implement get_tensor_transform")

    @staticmethod
    def construct_strided_slice(x, begin, end, stride):
        raise NotImplementedError("Platform subclasses must implement construct_strided_slice")

    @staticmethod
    def micro_batch(micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        raise NotImplementedError("Platform subclasses must implement micro_batch")

    def create_group(self, rank_list, group_name=None):
        """Create (or return a cached) communication group for ``rank_list``.

        Groups are cached in the module-level EXISTING_COMM_GROUPS dict,
        keyed by ``group_name`` when given, otherwise by the rank tuple.
        """
        if group_name is None:
            # Key on the rank tuple itself rather than hash(tuple(...)):
            # raw hashes can collide and hand back a group for a
            # different rank list.
            group_key = tuple(rank_list)
        else:
            group_key = group_name
        if group_key in EXISTING_COMM_GROUPS:
            return EXISTING_COMM_GROUPS[group_key]

        group = self._create_group(rank_list, group_name)
        EXISTING_COMM_GROUPS[group_key] = group
        return group

    # ------------------------------------------------------------------
    # Gradient-synchronization handle management.
    # ------------------------------------------------------------------
    def _process_current_handle(self):
        """Wait for the pending gradient handle and run its post-process hook."""
        if Platform.current_grad_handle is None:
            return

        Platform.current_grad_handle.wait()
        if Platform.post_grad_handle_process is None:
            return
        # pylint: disable=E1102
        Platform.post_grad_handle_process()

    def set_grad_reduce_handle(self, handle, post_process=None):
        """Wait for the current handle on the sync stream, then install a new one."""
        if Platform.grad_sync_stream is None:
            Platform.grad_sync_stream = self.new_stream()
        stream_context = self.get_stream_context()
        with stream_context(Platform.grad_sync_stream):
            self._process_current_handle()
        Platform.current_grad_handle = handle
        Platform.post_grad_handle_process = post_process

    def wait_grad_handle(self):
        """Drain the pending gradient handle and synchronize with the sync stream."""
        if Platform.current_grad_handle is None:
            return
        if Platform.grad_sync_stream is None:
            Platform.grad_sync_stream = self.new_stream()
        stream_context = self.get_stream_context()
        with stream_context(Platform.grad_sync_stream):
            self._process_current_handle()
        # Block the current stream until the sync stream's work is done.
        sync_event = Platform.grad_sync_stream.record_event()
        sync_event.wait()
        Platform.current_grad_handle = None
        Platform.post_grad_handle_process = None

    @staticmethod
    def all_gather_object(object_list, obj, group=None) -> None:
        """
        Aggregates all Python objects objs in a specified communication group into object_list.
        """
        raise NotImplementedError("Platform subclasses must implement all_gather_object")

    # ------------------------------------------------------------------
    # Process-group lifecycle (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def init_process_group(
            backend: Optional[str] = None,
            *,
            init_method: Optional[str] = None,
            timeout: Optional[timedelta] = None,
            world_size: int = -1,
            rank: int = -1,
            store: Any = None,
            pg_options: Any = None,
            device_id: Any = None
    ) -> None:
        """
        Initialize the default distributed process group.

        Args:
            backend: The backend to use for distributed communication
            init_method: URL specifying how to initialize the process group
            timeout: Timeout for operations executed against the process group
            world_size: Number of processes participating in the job
            rank: Rank of the current process
            store: Key/value store for exchanging connection information
            pg_options: Process group options for backend-specific configurations
            device_id: Specific device this process will work on

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement init_process_group")

    @staticmethod
    def destroy_process_group(group=None) -> None:
        """
        Destroy a given process group.

        Args:
            group: The process group to be destroyed. If None, destroys the default group.

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement destroy_process_group")

    @staticmethod
    def get_process_group_ranks(group=None) -> list[int]:
        """
        Get rank list of the given process group.

        Args:
            group: The process group to get ranks from. If None, uses the default group.

        Returns:
            List of ranks in the specified process group.

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement get_process_group_ranks")

    @staticmethod
    def get_backend(group=None):
        """
        Get the backend of the given process group.

        Args:
            group: The process group to get backend from. If None, uses the default group.

        Returns:
            The backend name of the specified process group.

        Raises:
            NotImplementedError: This method must be implemented by subclasses
        """
        raise NotImplementedError("Platform subclasses must implement get_backend")

    @staticmethod
    def split_group(parent_pg: Any = None,
                    split_ranks: Optional[list] = None,
                    timeout: Optional[timedelta] = None,
                    pg_options: Optional[Any] = None,
                    group_desc: Optional[str] = None,
                    ) -> Any:
        """
        Create split group relative to the parent process group.
        """
        raise NotImplementedError("Platform subclasses must implement split_group")

    # ------------------------------------------------------------------
    # Autograd / stream helpers (abstract).
    # ------------------------------------------------------------------
    @staticmethod
    def no_grad():
        raise NotImplementedError("Platform subclasses must implement no_grad")

    @staticmethod
    def empty_like(tensor, *, dtype=None, device=None, pin_memory=False):
        raise NotImplementedError("Platform subclasses must implement empty_like")

    def get_current_stream(self):
        raise NotImplementedError("Platform subclasses must implement get_current_stream")

    def new_event(self):
        raise NotImplementedError("Platform subclasses must implement new_event")

    def tree_map(self, fn, tree):
        raise NotImplementedError("Platform subclasses must implement tree_map")

    # ------------------------------------------------------------------
    # Hook registration: defaults delegate straight to the module object.
    # ------------------------------------------------------------------
    @staticmethod
    def register_forward_pre_hook(module, hook, prepend=False, with_kwargs=False):
        """Register a forward pre-hook by delegating to the module itself."""
        return module.register_forward_pre_hook(hook, prepend=prepend, with_kwargs=with_kwargs)

    @staticmethod
    def register_full_backward_hook(module, hook, prepend=False):
        """Register a full backward hook by delegating to the module itself."""
        return module.register_full_backward_hook(hook, prepend)

    @staticmethod
    def register_full_backward_pre_hook(module, hook, prepend=False):
        """Register a full backward pre-hook by delegating to the module itself."""
        return module.register_full_backward_pre_hook(hook, prepend)

    # ------------------------------------------------------------------
    # Activation checkpointing (abstract).
    # ------------------------------------------------------------------
    @property
    def checkpoint(self):
        raise NotImplementedError("Platform subclasses must implement checkpoint")

    @staticmethod
    def ckpt_wrapper(module, checkpoint_fn=None, **checkpoint_fn_kwargs):
        raise NotImplementedError("Platform subclasses must implement ckpt_wrapper")

    @property
    def noop_context_fn(self):
        # Fixed: the message previously named ckpt_wrapper.
        raise NotImplementedError("Platform subclasses must implement noop_context_fn")

    @staticmethod
    def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
        raise NotImplementedError("Platform subclasses must implement create_selective_checkpoint_contexts")

    @staticmethod
    def async_save_on_cpu(policy_fn=None):
        raise NotImplementedError("Platform subclasses must implement async_save_on_cpu")

    @staticmethod
    def tensor_to_numpy(tensor) -> np.ndarray:
        raise NotImplementedError("Platform subclasses must implement tensor_to_numpy")

    def cast_fp_tensor(self, dtype, x):
        """
        Cast floating-point tensor to target dtype if applicable.
        """
        raise NotImplementedError("Platform subclasses must implement cast_fp_tensor")

    def apply_to_tensors(self, fn, container):
        """Recursively apply to all tensor in different kinds of container types."""
        raise NotImplementedError("Platform subclasses must implement apply_to_tensors")