Coverage for hyper_parallel / collectives / cc.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Distributed process group API""" 

16from datetime import timedelta 

17from typing import Optional, Any 

18 

19from hyper_parallel import get_platform 

20 

# Resolved once, at import time, by calling get_platform(); every function in
# this module forwards its arguments to this object.
platform = get_platform()

22 

23 

def init_process_group(
    backend: Optional[str] = None,
    *,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Any = None,
    pg_options: Any = None,
    device_id: Any = None
) -> None:
    """
    Initialize the global process group; this starts a distributed job.

    Every argument is passed by keyword, unmodified, to the
    ``init_process_group`` method of the module-level ``platform`` object.
    Nothing is returned, whatever the platform call itself returns.

    Args:
        backend: Backend used for distributed communication.
        init_method: Method used to initialize the process group.
        timeout: Timeout for operations executed against the process group.
        world_size: Number of processes participating in the job
            (defaults to -1; how -1 is interpreted is up to the platform).
        rank: Rank of the current process
            (defaults to -1; how -1 is interpreted is up to the platform).
        store: Key/value store for exchanging connection information.
        pg_options: Process group options for backend-specific configuration.
        device_id: Specific device this process will work on.
    """
    # Collect the arguments first so the forwarding call stays a single line.
    forwarded = {
        "backend": backend,
        "init_method": init_method,
        "timeout": timeout,
        "world_size": world_size,
        "rank": rank,
        "store": store,
        "pg_options": pg_options,
        "device_id": device_id,
    }
    platform.init_process_group(**forwarded)

50 

51 

def destroy_process_group(group=None) -> None:
    """
    Destroy the given process group.

    The request is delegated to ``platform.destroy_process_group``; the
    platform's return value, if any, is discarded.

    Args:
        group: The process group to destroy. If None, the default group is
            destroyed.

    Raises:
        NotImplementedError: May propagate from the platform object when the
            active platform does not provide this operation (this wrapper
            itself never raises it) -- confirm against the platform classes.
    """
    destroy = platform.destroy_process_group
    destroy(group=group)

63 

64 

def get_process_group_ranks(group=None) -> list[int]:
    """
    Return the rank list of the given process group.

    The lookup is delegated to ``platform.get_process_group_ranks`` and its
    result is returned as-is.

    Args:
        group: The process group to query. If None, the default group is used.

    Returns:
        List of ranks in the specified process group.

    Raises:
        NotImplementedError: May propagate from the platform object when the
            active platform does not provide this operation (this wrapper
            itself never raises it) -- confirm against the platform classes.
    """
    ranks = platform.get_process_group_ranks(group=group)
    return ranks

79 

80 

def get_backend(group=None):
    """
    Return the backend of the given process group.

    The lookup is delegated to ``platform.get_backend`` and its result is
    returned as-is.

    Args:
        group: The process group to query. If None, the default group is used.

    Returns:
        The backend name of the specified process group.

    Raises:
        NotImplementedError: May propagate from the platform object when the
            active platform does not provide this operation (this wrapper
            itself never raises it) -- confirm against the platform classes.
    """
    backend_of_group = platform.get_backend(group=group)
    return backend_of_group

94 

95 

def split_group(parent_pg: Any = None,
                split_ranks: Optional[list] = None,
                timeout: Optional[timedelta] = None,
                pg_options: Optional[Any] = None,
                group_desc: Optional[str] = None,
                ) -> Any:
    """
    Create a split group relative to the parent process group.

    All arguments are passed by keyword, unmodified, to
    ``platform.split_group``, whose result is returned as-is.

    Args:
        parent_pg: Parent process group to split from.
            NOTE(review): None presumably selects the default group -- confirm.
        split_ranks: Ranks describing the split; the expected layout is
            defined by the platform implementation -- TODO confirm.
        timeout: Timeout forwarded to the platform for the new group.
        pg_options: Backend-specific process group options.
        group_desc: Description string forwarded for the new group.

    Returns:
        Whatever ``platform.split_group`` returns.
    """
    forwarded = {
        "parent_pg": parent_pg,
        "split_ranks": split_ranks,
        "timeout": timeout,
        "pg_options": pg_options,
        "group_desc": group_desc,
    }
    return platform.split_group(**forwarded)