Coverage for hyper_parallel / core / pipeline_parallel / utils.py: 67%
33 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""pipeline parallel utils"""
class BatchDimSpec:
    """
    Describes which dimension of a Tensor is treated as the batch dimension.

    Args:
        batch_dim (int): batch dimension.

    Raises:
        TypeError: If `batch_dim` is not an int.
    """
    # Single-attribute value object; __slots__ avoids a per-instance __dict__.
    __slots__ = ("batch_dim",)

    def __init__(self, batch_dim):
        # Accept only real ints (bools pass isinstance(int) too, matching original).
        if isinstance(batch_dim, int):
            self.batch_dim = batch_dim
        else:
            raise TypeError(f"batch_dim must be int, but got type {type(batch_dim)}.")

    def __repr__(self):
        # Unambiguous, constructor-like form for debugging.
        return f"BatchDimSpec({self.batch_dim})"

    def __str__(self):
        # Shorter human-readable form.
        return f"BatchDim(dim={self.batch_dim})"

    @staticmethod
    def from_tuple(batch_dims):
        """Convert a tuple of ints into a tuple of BatchDimSpec objects.

        Raises:
            TypeError: If `batch_dims` is not a tuple, or any element is not an int.
        """
        if not isinstance(batch_dims, tuple):
            raise TypeError(f"batch_dims must be tuple, but got type {type(batch_dims)}.")
        specs = []
        for dim in batch_dims:
            specs.append(BatchDimSpec(dim))
        return tuple(specs)

    @staticmethod
    def from_dict(batch_dims):
        """Convert a dict of {key: int} into a dict of {key: BatchDimSpec}.

        Raises:
            TypeError: If `batch_dims` is not a dict, or any value is not an int.
        """
        if not isinstance(batch_dims, dict):
            raise TypeError(f"batch_dims must be dict, but got type {type(batch_dims)}.")
        converted = {}
        for key, dim in batch_dims.items():
            converted[key] = BatchDimSpec(dim)
        return converted
class _RecvInfo:
    """
    Used for construct forward Receive operation and backward Send operation.

    Holds the peer's global rank (read-only after construction) and an
    optional buffer that may be attached or replaced later via the
    `buffer` property setter.
    """

    def __init__(self, global_rank, buffer=None):
        # Stored under private names; access goes through the properties below.
        self._global_rank = global_rank
        self._buffer = buffer

    @property
    def global_rank(self):
        # No setter is defined, so this attribute is effectively immutable.
        return self._global_rank

    @property
    def buffer(self):
        return self._buffer

    @buffer.setter
    def buffer(self, new_buffer):
        self._buffer = new_buffer