Coverage for hyper_parallel/platform/torch/pipeline_parallel/_utils.py: 71% (49 statements)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pipeline parallel utils"""
import hyper_parallel
from torch import nn


class _MicroBatch(nn.Module):
    """
    Split inputs into micro-batches in pipeline parallel.

    Args:
        micro_batch_num (int): The number of micro-batches.
        args_batch_dim (list, optional): Specify the batch dim of the args.
            Default ``None``.
        kwargs_batch_dim (dict, optional): Specify the batch dim of the kwargs.
            Default ``None``.

    Inputs:
        - **args** (list) - Input args.
        - **kwargs** (dict) - Input kwargs.

    Outputs:
        - **args_after_split** (list) - Input args after being split into micro-batches.
        - **kwargs_after_split** (list) - Input kwargs after being split into micro-batches.
    """

    def __init__(self, micro_batch_num, args_batch_dim=None, kwargs_batch_dim=None):
        super().__init__()
        self.micro_batch_num = micro_batch_num
        self.args_batch_dim = args_batch_dim
        self.kwargs_batch_dim = kwargs_batch_dim

    def forward(self, args, kwargs):
        """Forward of _MicroBatch: split every arg and kwarg into micro_batch_num slices."""
        args_after_split = []
        kwargs_after_split = []
        for micro_idx in range(self.micro_batch_num):
            micro_args = []
            micro_kwargs = {}
            # Slice each positional arg along its configured batch dim (default 0).
            for arg_idx, cur_arg in enumerate(args):
                cur_arg_batch_dim = 0
                if self.args_batch_dim and self.args_batch_dim[arg_idx] is not None:
                    cur_arg_batch_dim = self.args_batch_dim[arg_idx].batch_dim
                if isinstance(cur_arg, hyper_parallel.DTensor):
                    micro_arg = self.split_inputs_with_custom_shard(cur_arg, cur_arg_batch_dim, micro_idx)
                else:
                    micro_arg = self.split_inputs(cur_arg, cur_arg_batch_dim, micro_idx)
                micro_args.append(micro_arg)
            args_after_split.append(micro_args)

            # Slice each keyword arg along its configured batch dim (default 0).
            for key, cur_kwarg in kwargs.items():
                cur_kwarg_batch_dim = 0
                if self.kwargs_batch_dim is not None:
                    cur_kwarg_batch_dim = self.kwargs_batch_dim[key].batch_dim
                if isinstance(cur_kwarg, hyper_parallel.DTensor):
                    micro_kwarg = self.split_inputs_with_custom_shard(cur_kwarg, cur_kwarg_batch_dim, micro_idx)
                else:
                    micro_kwarg = self.split_inputs(cur_kwarg, cur_kwarg_batch_dim, micro_idx)
                micro_kwargs[key] = micro_kwarg
            kwargs_after_split.append(micro_kwargs)
        return args_after_split, kwargs_after_split

    def split_inputs_with_custom_shard(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """Split a DTensor input while preserving its device mesh and placements."""
        input_layout = input_tensor.layout
        func_wrap = hyper_parallel.custom_shard(self.split_inputs,
                                                device_mesh=input_layout.mesh,
                                                out_placements=(input_layout.placements,),
                                                in_placements=(input_layout.placements, None, None))
        return func_wrap(input_tensor, cur_arg_batch_dim, micro_idx)

    def split_inputs(self, input_tensor, cur_arg_batch_dim, micro_idx):
        """
        Split the input along the specified batch_dim and return the micro_idx-th slice.
        """
        # A batch dim of -1 means the input is not split across micro-batches.
        if cur_arg_batch_dim == -1:
            return input_tensor
        batch_dim_shape = input_tensor.shape[cur_arg_batch_dim]
        if batch_dim_shape % self.micro_batch_num != 0:
            raise ValueError(f"Batch dimension size {batch_dim_shape} is not divisible by "
                             f"micro_batch_num {self.micro_batch_num}")
        micro_batch_size = batch_dim_shape // self.micro_batch_num

        # Calculate start and end idx
        start = micro_batch_size * micro_idx
        end = micro_batch_size * (micro_idx + 1)

        # Create slicing tuple; index with a tuple so each entry applies to its own dim
        slices = [slice(None)] * input_tensor.ndim
        slices[cur_arg_batch_dim] = slice(start, end)
        return input_tensor[tuple(slices)]
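

# ----------------------------------------------------------------------------
# Illustrative usage sketch (added for clarity; not part of the original
# module). It shows how _MicroBatch slices plain torch.Tensor inputs along the
# batch dimension; DTensor inputs would instead go through
# split_inputs_with_custom_shard. Running it assumes hyper_parallel is
# importable, since this module imports it at the top.
if __name__ == "__main__":
    import torch

    # Split a batch of 8 samples into 4 micro-batches of 2 samples each.
    splitter = _MicroBatch(micro_batch_num=4)
    args_split, kwargs_split = splitter([torch.randn(8, 16)], {"mask": torch.ones(8, 16)})
    assert len(args_split) == 4 and len(kwargs_split) == 4
    assert args_split[0][0].shape == (2, 16)            # 8 // 4 = 2 rows per micro-batch
    assert kwargs_split[0]["mask"].shape == (2, 16)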