Coverage for hyper_parallel / platform / mindspore / pipeline_parallel / _utils.py: 0%
80 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# http://www.apache.org/licenses/LICENSE-2.0
2#
3# Unless required by applicable law or agreed to in writing, software
4# distributed under the License is distributed on an "AS IS" BASIS,
5# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6# See the License for the specific language governing permissions and
7# limitations under the License.
8# ============================================================================
9"""pipeline parallel utils"""
10import io
11import pickle
12import hyper_parallel
13from mindspore import nn, Tensor, mint, ops
14from mindspore.common import dtype as mstype
15from mindspore.communication import GlobalComm
16from mindspore.mint.distributed.distributed import _object_to_tensor, send, recv
17from hyper_parallel.core.shard.local_func import custom_shard
class _MicroBatch(nn.Cell):
    """
    Split inputs into micro_batch in pipeline parallel.

    Every input (positional or keyword) must be a ``hyper_parallel.DTensor``.
    Each one is sliced through ``custom_shard`` using the input's own device mesh
    and placements. Presumably this makes the slicing run on the local shard;
    confirm against ``custom_shard``.

    Args:
        micro_batch_num (int): The number of micro-batch.
        args_batch_dim (list, optional): Per-positional-arg batch dim spec, indexed
            like ``args``. Each entry is either ``None`` (split along dim 0) or an
            object exposing a ``batch_dim`` attribute; a ``batch_dim`` of ``-1``
            means the arg is passed through unsplit. Default ``None`` (every arg
            is split along dim 0).
        kwargs_batch_dim (dict, optional): Per-keyword-arg batch dim spec, keyed by
            kwarg name, with the same entry semantics as ``args_batch_dim``. Keys
            that are absent, or map to ``None``, are split along dim 0.
            Default ``None``.

    Inputs:
        - **args** (list) - Input args.
        - **kwargs** (dict) - Input kwargs.

    Outputs:
        - **args_after_split** (list) - One list of split args per micro-batch.
        - **kwargs_after_split** (list) - One dict of split kwargs per micro-batch.
    """

    def __init__(self, micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        super().__init__()
        self.micro_batch_num = micro_batch_num
        self.args_batch_dim = args_batch_dim
        self.kwargs_batch_dim = kwargs_batch_dim

    def construct(self, args, kwargs):
        """Construct of _MicroBatch: build per-micro-batch args lists and kwargs dicts."""
        args_after_split = []
        kwargs_after_split = []
        for micro_idx in range(self.micro_batch_num):
            micro_args = []
            micro_kwargs = {}
            for arg_idx, cur_arg in enumerate(args):
                cur_arg_batch_dim = 0
                if self.args_batch_dim and self.args_batch_dim[arg_idx] is not None:
                    cur_arg_batch_dim = self.args_batch_dim[arg_idx].batch_dim
                micro_arg = self.split_inputs_with_custom_shard(cur_arg, cur_arg_batch_dim, micro_idx)
                micro_args.append(micro_arg)
            args_after_split.append(micro_args)

            for key, cur_kwarg in kwargs.items():
                # Same fallback rules as positional args: an empty/None spec dict,
                # a missing key, or a None entry all mean "split along dim 0".
                kwarg_spec = self.kwargs_batch_dim.get(key) if self.kwargs_batch_dim else None
                cur_kwarg_batch_dim = 0 if kwarg_spec is None else kwarg_spec.batch_dim
                micro_kwarg = self.split_inputs_with_custom_shard(cur_kwarg, cur_kwarg_batch_dim, micro_idx)
                micro_kwargs[key] = micro_kwarg
            kwargs_after_split.append(micro_kwargs)
        return args_after_split, kwargs_after_split

    def split_inputs_with_custom_shard(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Slice micro-batch ``micro_idx`` out of a DTensor via ``custom_shard``.

        The wrapped ``split_inputs`` keeps the input's placements on its output.
        The two scalar arguments are passed with ``None`` placements.

        Raises:
            TypeError: If ``input_tensor`` is not a ``hyper_parallel.DTensor``.
        """
        if not isinstance(input_tensor, hyper_parallel.DTensor):
            raise TypeError(f"Input type {type(input_tensor)} is not DTensor.")
        input_layout = input_tensor.layout
        func_wrap = custom_shard(self.split_inputs,
                                 device_mesh=input_layout.mesh,
                                 out_placements=(input_layout.placements,),
                                 in_placements=(input_layout.placements, None, None)
                                 )
        return func_wrap(input_tensor, cur_arg_batch_dim, micro_idx)

    def split_inputs(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split the input along the specified batch_dim and micro_idx.

        A ``cur_arg_batch_dim`` of ``-1`` returns the input unchanged. Otherwise
        micro-batch ``micro_idx`` covers rows
        ``[size // micro_batch_num * micro_idx, size // micro_batch_num * (micro_idx + 1))``
        of that dim. When ``size`` is not divisible by ``micro_batch_num``, the
        trailing ``size % micro_batch_num`` rows are not part of any micro-batch.
        """
        if cur_arg_batch_dim == -1:
            return input_tensor
        batch_dim_shape = input_tensor.shape[cur_arg_batch_dim]
        micro_batch_begin = (batch_dim_shape // self.micro_batch_num) * micro_idx
        micro_batch_end = (batch_dim_shape // self.micro_batch_num) * (micro_idx + 1)
        # Full-extent slice on every dim except the batch dim.
        strided_slice_begin = [0] * input_tensor.ndim
        strided_slice_strides = [1] * input_tensor.ndim
        strided_slice_end = list(input_tensor.shape)
        strided_slice_begin[cur_arg_batch_dim] = micro_batch_begin
        strided_slice_end[cur_arg_batch_dim] = micro_batch_end
        micro_input = ops.strided_slice(input_tensor, strided_slice_begin, strided_slice_end, strided_slice_strides)
        return micro_input
def send_object_list(obj, dst=0, group=None):
    """
    Send a picklable Python object to rank ``dst``.

    The object is serialized with ``_object_to_tensor`` and sent as two
    point-to-point messages. The first is a one-element int32 tensor holding the
    payload length in bytes. The second is the payload tensor itself. The peer
    must call ``recv_object_list`` with a matching ``src`` and ``group``.

    Args:
        obj (Any): The Python object to send. ``recv_object_list`` keeps only
            element ``[0]`` of the deserialized value. Callers therefore presumably
            wrap the payload as ``[payload]``; verify against callers.
        dst (int, optional): Global rank to send the object to. Default: ``0``.
        group (str, optional): Communication group. Default: ``None``, which
            selects ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        TypeError: If ``group`` is not a string or ``dst`` is not an int.
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        # Implicit concatenation rather than a backslash continuation, so no
        # source indentation leaks into the message.
        raise TypeError("For 'send_object_list', the argument 'group' must be type of string, "
                        f"but got 'group' type : {type(group)}.")
    if not isinstance(dst, int):
        raise TypeError("For send_object_list, the dst must be int.")
    obj_tensor, tensor_size = _object_to_tensor(obj)
    # The receiver must size its buffer before it can receive the payload,
    # so the byte length goes out first as its own message.
    obj_size = Tensor([tensor_size], dtype=mstype.int32)
    send(obj_size, dst, group)
    send(obj_tensor, dst, group)
def recv_object_list(recv_obj, src=0, group=None):
    """
    Receive a Python object sent by ``send_object_list`` from rank ``src``.

    Two messages are received in order: a one-element int32 tensor holding the
    payload length in bytes, then an int8 payload of that length. The payload is
    unpickled, and element ``[0]`` of the result is stored in ``recv_obj``.

    Args:
        recv_obj (list): Output list. After a successful receive it contains
            exactly one item, element ``[0]`` of the deserialized object. It is
            left untouched if deserialization fails.
        src (int, optional): Global rank to receive the object from. Default: ``0``.
        group (str, optional): Communication group. Default: ``None``, which
            selects ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        TypeError: If ``group`` is not a string or ``src`` is not an int.
    """
    if group is None:
        group = GlobalComm.WORLD_COMM_GROUP
    if not isinstance(group, str):
        # Implicit concatenation rather than a backslash continuation, so no
        # source indentation leaks into the message.
        raise TypeError("For 'recv_object_list', the argument 'group' must be type of string, "
                        f"but got 'group' type : {type(group)}.")
    if not isinstance(src, int):
        raise TypeError("For recv_object_list, the src must be int.")
    # First message: byte length of the serialized payload.
    obj_size = Tensor([0], dtype=mstype.int32)
    recv(obj_size, src, group)
    size_val = obj_size.item()
    # Second message: the payload, received into a buffer of exactly that size.
    obj_tensor = mint.empty([size_val], dtype=mstype.int8)
    recv(obj_tensor, src, group)
    buf = obj_tensor.asnumpy().tobytes()[:size_val]
    # SECURITY NOTE: unpickling can execute arbitrary code. This is only safe
    # when every rank in ``group`` is trusted.
    received = pickle.Unpickler(io.BytesIO(buf)).load()
    # Mutate the caller's list only once deserialization has succeeded.
    recv_obj.clear()
    recv_obj.append(received[0])