Coverage for hyper_parallel / core / pipeline_parallel / utils.py: 67%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""pipeline parallel utils""" 

16 

17 

class BatchDimSpec:
    """
    Specify the batch dimension of a Tensor.

    Only the ``batch_dim`` attribute is stored (via ``__slots__``).

    Args:
        batch_dim (int): batch dimension. ``bool`` is rejected even though it
            is a subclass of ``int``.

    Raises:
        TypeError: If ``batch_dim`` is not an ``int``, or is a ``bool``.
    """
    __slots__ = ("batch_dim",)

    def __init__(self, batch_dim):
        # ``bool`` subclasses ``int``, so ``isinstance(True, int)`` is True.
        # Reject it explicitly so a flag passed by mistake is not silently
        # treated as dimension 0 or 1.
        if isinstance(batch_dim, bool) or not isinstance(batch_dim, int):
            raise TypeError(f"batch_dim must be int, but got type {type(batch_dim)}.")
        self.batch_dim = batch_dim

    def __repr__(self):
        return f"BatchDimSpec({self.batch_dim})"

    def __str__(self):
        return f"BatchDim(dim={self.batch_dim})"

    @staticmethod
    def from_tuple(batch_dims):
        """
        Build one BatchDimSpec per entry of a tuple.

        Args:
            batch_dims (tuple): batch dimension for each positional item.

        Returns:
            tuple[BatchDimSpec, ...]: specs in the same order as ``batch_dims``.

        Raises:
            TypeError: If ``batch_dims`` is not a tuple, or any entry is not a
                valid ``int`` (see ``BatchDimSpec.__init__``).
        """
        if not isinstance(batch_dims, tuple):
            raise TypeError(f"batch_dims must be tuple, but got type {type(batch_dims)}.")
        return tuple(BatchDimSpec(dim) for dim in batch_dims)

    @staticmethod
    def from_dict(batch_dims):
        """
        Build a BatchDimSpec for each value of a dict, keeping the keys.

        Args:
            batch_dims (dict): mapping from key to batch dimension.

        Returns:
            dict: same keys as ``batch_dims``, each value wrapped in a
            ``BatchDimSpec``.

        Raises:
            TypeError: If ``batch_dims`` is not a dict, or any value is not a
                valid ``int`` (see ``BatchDimSpec.__init__``).
        """
        if not isinstance(batch_dims, dict):
            raise TypeError(f"batch_dims must be dict, but got type {type(batch_dims)}.")
        return {k: BatchDimSpec(v) for k, v in batch_dims.items()}

49 

50 

51class _RecvInfo: 

52 """ 

53 Used for construct forward Receive operation and backward Send operation. 

54 """ 

55 

56 def __init__(self, global_rank, buffer=None): 

57 self._global_rank = global_rank 

58 self._buffer = buffer 

59 

60 @property 

61 def global_rank(self): 

62 return self._global_rank 

63 

64 @property 

65 def buffer(self): 

66 return self._buffer 

67 

68 @buffer.setter 

69 def buffer(self, val): 

70 self._buffer = val