Coverage for hyper_parallel / core / checkpoint / reshard.py: 95%
159 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""resharding tensor"""
16import operator
17from typing import Any, Optional, Union
18from functools import reduce
19import numpy as np
def check_layout(layout: Optional[Any], name: str) -> None:
    """
    Validates that a layout contains required attributes with correct types.

    Args:
        layout (Optional[Any]): Layout object to validate. A falsy layout is
            accepted and skipped.
        name (str): Name of the layout (for error messages).

    Raises:
        ValueError: If layout missing required attributes or has size mismatches
        TypeError: If layout components are not tuples/lists
    """
    if not layout:
        return

    # Every layout object must expose these attributes.
    for needed_attr in ('mesh_shape', '_tensor_map', '_rank_list'):
        if not hasattr(layout, needed_attr):
            raise ValueError(
                f"Layout {name} must contain attribute {needed_attr}"
            )

    def _require_sequence(value: Any, value_name: str) -> None:
        # Layout components must be tuple/list containers.
        if not isinstance(value, (tuple, list)):
            raise TypeError(
                f"Layout {name} {value_name} must be tuple or list, "
                f"but got {type(value).__name__}"
            )

    layout_info = layout.to_dict()
    for key in ('mesh_shape', 'tensor_map', 'rank_list'):
        _require_sequence(layout_info[key], key)

    # The rank list must enumerate exactly one rank per mesh device.
    dev_num = reduce(operator.mul, layout_info['mesh_shape'])
    rank_count = len(layout_info['rank_list'])
    if rank_count != dev_num:
        raise ValueError(
            f"Layout {name} rank_list size ({rank_count}) "
            f"must match device count ({dev_num})"
        )
def rank_id_to_dev_id_list(mesh_shape: tuple[int, ...], rank_id: int) -> list[int]:
    """
    Converts a rank ID to a list of device IDs based on the mesh shape.

    Args:
        mesh_shape (tuple[int, ...]): Shape of the device mesh.
        rank_id (int): Global rank ID to convert.

    Returns:
        list[int]: Device coordinate along each mesh dimension (row-major).
    """
    coords: list[int] = []
    remaining = rank_id
    # Peel coordinates off from the innermost (last) mesh dimension outward,
    # mixed-radix style, then restore the original dimension order.
    for extent in reversed(mesh_shape):
        remaining, coord = divmod(remaining, extent)
        coords.append(coord)
    coords.reverse()
    return coords
def infer_intersection(
    area_a: tuple[tuple[int, int], ...],
    area_b: tuple[tuple[int, int], ...]
) -> Optional[tuple[tuple[int, int], ...]]:
    """
    Calculates the intersection of two tensor slice areas.

    Args:
        area_a (tuple[tuple[int, int], ...]): First area to intersect.
        area_b (tuple[tuple[int, int], ...]): Second area to intersect.

    Returns:
        Optional[tuple[tuple[int, int], ...]]: Tuple of intersection
        boundaries, or None if any dimension has no overlap.

    Raises:
        TypeError: If either area is not a sequence of 2-element ranges.
        ValueError: If the two areas have a different number of dimensions.
    """
    def _validate(area: Any) -> None:
        if not isinstance(area, (tuple, list)):
            raise TypeError("Area must be a tuple of ranges")
        for rng in area:
            if not isinstance(rng, (tuple, list)) or len(rng) != 2:
                raise TypeError("Each axis range must be a 2-element tuple")

    _validate(area_a)
    _validate(area_b)

    if len(area_a) != len(area_b):
        raise ValueError(
            f"Area dimension mismatch: {len(area_a)} vs {len(area_b)}"
        )

    overlap: list[tuple[int, int]] = []
    for (a_lo, a_hi), (b_lo, b_hi) in zip(area_a, area_b):
        lo = max(a_lo, b_lo)
        hi = min(a_hi, b_hi)
        # Empty (or zero-width) overlap on any axis means no intersection.
        if lo >= hi:
            return None
        overlap.append((lo, hi))

    return tuple(overlap)
def infer_slice_area_by_rank(
    mesh_shape: tuple[int, ...],
    tensor_map: Union[list[int], tuple[int, ...]],
    rank_id: int,
    full_shape: tuple[int, ...]
) -> tuple[tuple[int, int], ...]:
    """
    Calculates the tensor slice boundaries for a specific rank.

    Args:
        mesh_shape (tuple[int, ...]): Shape of the device mesh.
        tensor_map (Union[list[int], tuple[int, ...]]): Mapping of tensor
            dimensions to device dimensions (-1 means not sharded).
        rank_id (int): Rank ID to calculate slice for.
        full_shape (tuple[int, ...]): Complete shape of the original tensor.

    Returns:
        tuple[tuple[int, int], ...]: (start, end) boundaries per tensor dim.

    Raises:
        ValueError: If an axis length is not divisible by its split count.
    """
    def _dim_extent(dev_dim: int) -> int:
        # -1 marks an unsharded mapping entry: a single "device" on that axis.
        return 1 if dev_dim == -1 else mesh_shape[-dev_dim - 1]

    # Decompose the rank into per-mesh-dimension device coordinates
    # (innermost dimension varies fastest).
    dev_id_list: list[int] = []
    remaining = rank_id
    for extent in reversed(mesh_shape):
        remaining, coord = divmod(remaining, extent)
        dev_id_list.append(coord)
    dev_id_list.reverse()

    boundaries: list[tuple[int, int]] = []
    for axis, axis_len in enumerate(full_shape):
        mapping = tensor_map[axis]
        # Normalize scalar entries to a tuple for uniform handling.
        dims = (mapping,) if isinstance(mapping, int) else mapping

        # Total number of slices along this tensor axis.
        split_num = 1
        for dev_dim in dims:
            split_num *= _dim_extent(dev_dim)

        # Index of this rank's slice along the axis (row-major over dims).
        slice_id = 0
        stride = 1
        for dev_dim in reversed(dims):
            if dev_dim == -1:
                continue
            slice_id += dev_id_list[-dev_dim - 1] * stride
            stride *= _dim_extent(dev_dim)

        if axis_len % split_num != 0:
            raise ValueError(f"Shape can not divided along dimension {axis} by {split_num} dev.")
        piece = axis_len // split_num
        boundaries.append((slice_id * piece, slice_id * piece + piece))

    return tuple(boundaries)
class ReshardHandler:
    """
    Handles tensor resharding between different distributed layouts.

    This class manages the process of reshaping and redistributing tensors between
    different parallel layouts. It calculates necessary tensor slices, validates
    input layouts, and assembles the final tensor for the target rank.

    Typical usage: construct the handler, call ``infer_all_tensor_offset`` to
    learn which slice of each source rank's local tensor is needed, gather those
    slices externally, then call ``get_real_tensor`` to assemble the result.

    Args:
        param_name (str): Name of the parameter (without pipeline stage prefix).
        full_shape (tuple[int, ...]): Complete shape of the tensor before sharding.
        from_layout (Optional[Any]): Source layout containing mesh shape, tensor map, and rank list.
        to_layout (Optional[Any]): Target layout containing mesh shape, tensor map, and rank list.
        to_rank_id (int): Target rank ID to receive the resharded tensor.

    Raises:
        ValueError: If both layouts are None or layouts contain invalid attributes
        TypeError: If layout components are not tuples/lists
    """
    def __init__(
        self,
        param_name: str,
        full_shape: tuple[int, ...],
        from_layout: Optional[Any],
        to_layout: Optional[Any],
        to_rank_id: int
    ):
        # Validate input layouts
        check_layout(from_layout, 'from_layout')
        check_layout(to_layout, 'to_layout')

        if from_layout is None and to_layout is None:
            raise ValueError("`from_layout` and `to_layout` cannot both be None.")

        # Initialize basic attributes
        self.param_name = param_name
        self.full_shape = full_shape

        # Process source layout configuration.
        # A missing source layout degenerates to a single-device mesh (1,)
        # holding the whole tensor on rank 0 (split factor 1 per axis).
        if from_layout is None:
            self.from_mesh_shape = (1,)
            self.from_tensor_map = tuple(0 for _ in full_shape)
            self.from_rank_list = [0]
        else:
            from_layout_dict = from_layout.to_dict()
            self.from_mesh_shape = from_layout_dict["mesh_shape"]
            self.from_tensor_map = from_layout_dict["tensor_map"]
            self.from_rank_list = from_layout_dict["rank_list"]

        # Process target layout configuration (same single-device fallback).
        if to_layout is None:
            self.to_mesh_shape = (1,)
            self.to_tensor_map = tuple(0 for _ in full_shape)
            self.to_rank_list = [0]
            self.to_rank_id = 0
        else:
            to_layout_dict = to_layout.to_dict()
            self.to_mesh_shape = to_layout_dict["mesh_shape"]
            self.to_tensor_map = to_layout_dict["tensor_map"]
            self.to_rank_list = to_layout_dict["rank_list"]
            self.to_rank_id = to_rank_id
            if self.to_rank_id not in self.to_rank_list:
                raise ValueError("Input to_rank_id is not in to_rank_list.")

        # Calculate device counts and internal rank mappings.
        # "Inner" ranks are positions within the rank list (0..dev_num-1),
        # as opposed to the global rank IDs stored in the rank lists.
        self.from_dev_num = len(self.from_rank_list)
        self.inner_from_rank_list = range(self.from_dev_num)
        self.inner_to_rank_id = self.to_rank_list.index(self.to_rank_id)

        # Compute redundancy information: which source ranks hold unique data.
        self.inner_deredundancy_from_rank_list = (
            self._infer_inner_deredundancy_rank_list_by_from_layout()
            if from_layout else [0]
        )
        # Populated by infer_all_tensor_offset: global source rank -> the
        # absolute (start, end) area it contributes to the target slice.
        self.global_union_area_map: dict[int, tuple[tuple[int, int], ...]] = {}

    def _infer_inner_deredundancy_rank_list_by_from_layout(self) -> list[int]:
        """
        Infers ranks containing non-redundant data from the source layout.

        Returns:
            list[int]: Inner (position-based) ranks with unique data slices.
        """
        inner_deredundancy_rank_list: list[int] = []
        dev_dim = len(self.from_mesh_shape)

        # Collect relevant device dimensions from tensor map.
        # Tensor-map entries index mesh dims from the end, so each entry is
        # converted to a forward index via (dev_dim - entry - 1).
        from_dev_map = set()
        for map_dev in self.from_tensor_map:
            if isinstance(map_dev, (list, tuple)):
                for map_dev_inner in map_dev:
                    from_dev_map.add(dev_dim - map_dev_inner - 1)
            else:
                from_dev_map.add(dev_dim - map_dev - 1)

        # Mesh dims that shard no tensor axis replicate the data; only ranks
        # sitting at coordinate 0 along every such dim carry a unique copy.
        unused_dims = [dim for dim in range(dev_dim) if dim not in from_dev_map]
        if not unused_dims:
            return list(self.inner_from_rank_list)
        for rank_id in self.inner_from_rank_list:
            dev_id_list = rank_id_to_dev_id_list(self.from_mesh_shape, rank_id)
            # check redundant
            found_redundant = False
            for dim in unused_dims:
                if dev_id_list[dim] > 0:
                    found_redundant = True
                    break

            # save not redundant rank
            if not found_redundant:
                inner_deredundancy_rank_list.append(rank_id)

        return inner_deredundancy_rank_list

    def infer_all_tensor_offset(self) -> dict[int, tuple[tuple[int, int], ...]]:
        """
        Calculates required tensor slices from each source rank.

        Determines which parts of the tensor need to be collected from each source
        rank to assemble the target tensor slice. Also records the absolute
        overlap areas in ``self.global_union_area_map`` and the target area in
        ``self.to_area``, both of which ``get_real_tensor`` relies on.

        Returns:
            dict[int, tuple[tuple[int, int], ...]]: Mapping of global source
            rank -> (start, end) offsets *relative to that rank's local slice*.
        """
        # Calculate target area for current rank
        self.to_area = infer_slice_area_by_rank(
            self.to_mesh_shape,
            self.to_tensor_map,
            self.inner_to_rank_id,
            self.full_shape
        )

        # Calculate required slices from each source rank
        local_union_areas_map: dict[int, tuple[tuple[int, int], ...]] = {}
        self.global_union_area_map.clear()

        for inner_rank_id in self.inner_deredundancy_from_rank_list:
            # Get source area for this rank
            from_area = infer_slice_area_by_rank(
                self.from_mesh_shape,
                self.from_tensor_map,
                inner_rank_id,
                self.full_shape
            )

            # Find overlapping area between source and target
            union_area = infer_intersection(from_area, self.to_area)
            if union_area is not None:
                source_rank = self.from_rank_list[inner_rank_id]
                self.global_union_area_map[source_rank] = union_area

                # Calculate relative offsets within source slice (absolute
                # coordinates shifted by the source slice's own start).
                local_union_areas_map[source_rank] = tuple(
                    (union_range[0] - from_range[0], union_range[1] - from_range[0])
                    for union_range, from_range in zip(union_area, from_area)
                )

        return local_union_areas_map

    def get_real_tensor(self, from_tensor_map: dict[int, np.ndarray]) -> np.ndarray:
        """
        Assembles the final tensor for the target rank from collected slices.

        NOTE(review): relies on ``self.to_area`` and ``self.global_union_area_map``,
        which are only populated by ``infer_all_tensor_offset`` — that method
        must be called first.

        Args:
            from_tensor_map (dict[int, np.ndarray]): Dictionary mapping source ranks to their tensor slices.

        Returns:
            np.ndarray: Assembled tensor for the target rank.

        Raises:
            ValueError: If input slices are missing or have incorrect shapes
        """
        if not from_tensor_map:
            raise ValueError("Input from_tensor_map cannot be empty")

        # Validate input slices
        for from_rank_id, from_area in self.global_union_area_map.items():
            if from_rank_id not in from_tensor_map:
                raise ValueError(
                    f"Missing slice data from rank {from_rank_id}. "
                    "Please provide all required slices from infer_all_tensor_offset."
                )

            # Validate slice shape matches expected size
            expected_shape = tuple(end - start for start, end in from_area)
            actual_shape = from_tensor_map[from_rank_id].shape
            if expected_shape != actual_shape:
                raise ValueError(
                    f"Slice from rank {from_rank_id} has incorrect shape. "
                    f"Expected {expected_shape}, got {actual_shape}."
                )

        # Create target tensor and assign slices; dtype is taken from the
        # first provided slice.
        to_slice_shape = [end - start for start, end in self.to_area]
        dtype = next(iter(from_tensor_map.values())).dtype
        real_tensor = np.zeros(to_slice_shape, dtype=dtype)

        # NOTE(review): iterates the caller's dict — a key absent from
        # global_union_area_map would raise KeyError here; callers are
        # expected to pass exactly the ranks returned by infer_all_tensor_offset.
        for from_rank_id, from_slice in from_tensor_map.items():
            from_area = self.global_union_area_map[from_rank_id]

            # Calculate assignment indices in target tensor (absolute
            # coordinates shifted into the target slice's frame).
            assign_slices = tuple(
                slice(from_axis[0] - to_axis[0], from_axis[1] - to_axis[0])
                for from_axis, to_axis in zip(from_area, self.to_area)
            )

            real_tensor[assign_slices] = from_slice

        return real_tensor