Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / distributed_checkpoint / reshard.py: 8%
160 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""resharding tensor"""
16import operator
17from typing import Any, Optional, Union
18from functools import reduce
19import numpy as np
def check_layout(layout: Optional[Any], name: str) -> None:
    """
    Validates that a layout contains required attributes with correct types.

    Args:
        layout (Optional[Any]): Layout object to validate. ``None`` (or any
            falsy layout) is accepted and skipped.
        name (str): Name of the layout (used in error messages).

    Raises:
        ValueError: If the layout is missing a required attribute, or its
            rank list size does not match the device count.
        TypeError: If a layout component is not a tuple or list.
    """
    if not layout:
        return

    # The to_dict() call below relies on the layout exposing these attributes.
    required_attrs = ['mesh_shape', '_tensor_map', '_rank_list']
    for attr in required_attrs:
        if not hasattr(layout, attr):
            raise ValueError(
                f"Layout {name} must contain attribute {attr}"
            )

    # Validate component types
    def check_type_is_sequence(obj: Any, obj_name: str) -> None:
        if not isinstance(obj, (tuple, list)):
            raise TypeError(
                f"Layout {name} {obj_name} must be tuple or list, "
                f"but got {type(obj).__name__}"
            )

    layout_dict = layout.to_dict()
    check_type_is_sequence(layout_dict['mesh_shape'], 'mesh_shape')
    check_type_is_sequence(layout_dict['tensor_map'], 'tensor_map')
    check_type_is_sequence(layout_dict['rank_list'], 'rank_list')

    # Device count is the product of mesh dimensions; the initial value 1
    # keeps reduce() well defined for an empty mesh_shape (previously a
    # bare TypeError), so the size check below reports the real problem.
    dev_num = reduce(operator.mul, layout_dict['mesh_shape'], 1)
    if len(layout_dict['rank_list']) != dev_num:
        raise ValueError(
            f"Layout {name} rank_list size ({len(layout_dict['rank_list'])}) "
            f"must match device count ({dev_num})"
        )
def rank_id_to_dev_id_list(mesh_shape: tuple[int, ...], rank_id: int) -> list[int]:
    """
    Converts a rank ID to a list of device IDs based on the mesh shape.

    Args:
        mesh_shape (tuple[int, ...]): Shape of the device mesh.
        rank_id (int): Global rank ID to convert.

    Returns:
        list[int]: Per-dimension device IDs corresponding to the rank.
    """
    # Peel off one mesh dimension at a time, innermost (last) first,
    # exactly like converting an integer to a mixed-radix representation.
    coords: list[int] = []
    remaining = rank_id
    for dim_size in reversed(mesh_shape):
        remaining, coord = divmod(remaining, dim_size)
        coords.append(coord)
    coords.reverse()
    return coords
def infer_intersection(
    area_a: tuple[tuple[int, int], ...],
    area_b: tuple[tuple[int, int], ...]
) -> Optional[tuple[tuple[int, int], ...]]:
    """
    Calculates the intersection of two tensor slice areas.

    Args:
        area_a (tuple[tuple[int, int], ...]): First area, one (start, end) range per axis.
        area_b (tuple[tuple[int, int], ...]): Second area, one (start, end) range per axis.

    Returns:
        Optional[tuple[tuple[int, int], ...]]: Per-axis intersection boundaries,
        or None when the areas do not overlap on some axis.

    Raises:
        TypeError: If an area is not a sequence of 2-element ranges.
        ValueError: If the two areas have a different number of axes.
    """
    def _validate_area(area: Any) -> None:
        if not isinstance(area, (tuple, list)):
            raise TypeError("Area must be a tuple of ranges")
        for rng in area:
            if not isinstance(rng, (tuple, list)) or len(rng) != 2:
                raise TypeError("Each axis range must be a 2-element tuple")

    _validate_area(area_a)
    _validate_area(area_b)

    if len(area_a) != len(area_b):
        raise ValueError(
            f"Area dimension mismatch: {len(area_a)} vs {len(area_b)}"
        )

    result: list[tuple[int, int]] = []
    for range_a, range_b in zip(area_a, area_b):
        lo = max(range_a[0], range_b[0])
        hi = min(range_a[1], range_b[1])
        # An empty (or inverted) overlap on any axis means no intersection at all.
        if lo >= hi:
            return None
        result.append((lo, hi))

    return tuple(result)
def infer_slice_area_by_rank(
    mesh_shape: tuple[int, ...],
    tensor_map: Union[list[int], tuple[int, ...]],
    rank_id: int,
    full_shape: tuple[int, ...]
) -> tuple[tuple[int, int], ...]:
    """
    Calculates the tensor slice boundaries for a specific rank.

    Args:
        mesh_shape (tuple[int, ...]): Shape of the device mesh.
        tensor_map (Union[list[int], tuple[int, ...]]): Mapping of tensor
            dimensions to device mesh dimensions. Each entry is an int or a
            sequence of ints; -1 means the axis is not split.
        rank_id (int): Rank ID to calculate the slice for.
        full_shape (tuple[int, ...]): Complete shape of the original tensor.

    Returns:
        tuple[tuple[int, int], ...]: Tuple of (start, end) boundaries for each tensor dimension.

    Raises:
        ValueError: If an axis length is not evenly divisible by its split count.
    """
    # Device count along a mapped mesh dimension; -1 means "not split" (count 1).
    def _get_dev_num_along_dim(dim: int) -> int:
        return mesh_shape[-dim - 1] if dim != -1 else 1

    dims = len(full_shape)
    dev_id_list = rank_id_to_dev_id_list(mesh_shape, rank_id)
    area: list[tuple[int, int]] = []

    for axis in range(dims):
        mapping = tensor_map[axis]
        if isinstance(mapping, int):
            mapping = (mapping,)  # Convert to tuple for consistent handling

        # Total number of slices along this tensor axis.
        split_num = 1
        for dim in mapping:
            split_num *= _get_dev_num_along_dim(dim)

        # Slice index owned by this rank along this axis: combine the device
        # coordinates of all mapped mesh dimensions (last mapping entry
        # varies fastest).
        slice_id = 0
        coef = 1
        for dim in reversed(mapping):
            if dim == -1:
                continue
            slice_id += dev_id_list[-dim - 1] * coef
            coef *= _get_dev_num_along_dim(dim)

        # Only even splits are supported; reject shapes that do not divide.
        if full_shape[axis] % split_num != 0:
            raise ValueError(
                f"Shape cannot be divided along dimension {axis} by {split_num} devices."
            )
        slice_size = full_shape[axis] // split_num
        start = slice_id * slice_size
        end = start + slice_size
        area.append((start, end))

    return tuple(area)
class ReshardHandler:
    """
    Handles tensor resharding between different distributed layouts.

    This class manages reshaping and redistributing a tensor between two
    parallel layouts: it validates the input layouts, computes which slice of
    the tensor each source rank must contribute, and assembles the final
    tensor slice for the target rank.

    Args:
        param_name (str): Name of the parameter (without pipeline stage prefix).
        full_shape (tuple[int, ...]): Complete shape of the tensor before sharding.
        from_layout (Optional[Any]): Source layout containing mesh shape, tensor map, and rank list.
        to_layout (Optional[Any]): Target layout containing mesh shape, tensor map, and rank list.
        to_rank_id (int): Target rank ID to receive the resharded tensor.

    Raises:
        ValueError: If both layouts are None, a layout has invalid attributes,
            or `to_rank_id` is not in the target rank list.
        TypeError: If layout components are not tuples/lists.
    """
    def __init__(
        self,
        param_name: str,
        full_shape: tuple[int, ...],
        from_layout: Optional[Any],
        to_layout: Optional[Any],
        to_rank_id: int
    ):
        check_layout(from_layout, 'from_layout')
        check_layout(to_layout, 'to_layout')
        if from_layout is None and to_layout is None:
            raise ValueError("`from_layout` and `to_layout` cannot both be None.")

        self.param_name = param_name
        self.full_shape = full_shape

        # Missing source layout: the full tensor lives on a single device.
        if from_layout is None:
            self.from_mesh_shape = (1,)
            self.from_tensor_map = tuple(0 for _ in full_shape)
            self.from_rank_list = [0]
        else:
            from_cfg = from_layout.to_dict()
            self.from_mesh_shape = from_cfg["mesh_shape"]
            self.from_tensor_map = from_cfg["tensor_map"]
            self.from_rank_list = from_cfg["rank_list"]

        # Missing target layout: rank 0 receives the whole tensor.
        if to_layout is None:
            self.to_mesh_shape = (1,)
            self.to_tensor_map = tuple(0 for _ in full_shape)
            self.to_rank_list = [0]
            self.to_rank_id = 0
        else:
            to_cfg = to_layout.to_dict()
            self.to_mesh_shape = to_cfg["mesh_shape"]
            self.to_tensor_map = to_cfg["tensor_map"]
            self.to_rank_list = to_cfg["rank_list"]
            self.to_rank_id = to_rank_id
            if self.to_rank_id not in self.to_rank_list:
                raise ValueError("Input to_rank_id is not in to_rank_list.")

        # Layout-local (inner) rank indices.
        self.from_dev_num = len(self.from_rank_list)
        self.inner_from_rank_list = range(self.from_dev_num)
        self.inner_to_rank_id = self.to_rank_list.index(self.to_rank_id)

        # Source ranks that hold unique (non-replicated) data slices.
        self.inner_deredundancy_from_rank_list = (
            self._infer_inner_deredundancy_rank_list_by_from_layout()
            if from_layout else [0]
        )
        # Maps global source rank -> overlap area in full-tensor coordinates.
        self.global_union_area_map: dict[int, tuple[tuple[int, int], ...]] = {}
        self.to_area = ()  # Filled in by infer_all_tensor_offset()

    def _infer_inner_deredundancy_rank_list_by_from_layout(self) -> list[int]:
        """
        Infers ranks containing non-redundant data from the source layout.

        Returns:
            list[int]: Layout-local ranks holding unique data slices.
        """
        dev_dim = len(self.from_mesh_shape)

        # Mesh dimensions that actually shard some tensor axis.
        used_dims: set = set()
        for map_dev in self.from_tensor_map:
            if isinstance(map_dev, (list, tuple)):
                for inner in map_dev:
                    used_dims.add(dev_dim - inner - 1)
            else:
                used_dims.add(dev_dim - map_dev - 1)

        # Dimensions not used for sharding replicate the data; only ranks at
        # coordinate 0 on every such dimension carry a unique slice.
        unused_dims = [dim for dim in range(dev_dim) if dim not in used_dims]
        if not unused_dims:
            return list(self.inner_from_rank_list)

        unique_ranks: list[int] = []
        for rank_id in self.inner_from_rank_list:
            dev_ids = rank_id_to_dev_id_list(self.from_mesh_shape, rank_id)
            if all(dev_ids[dim] == 0 for dim in unused_dims):
                unique_ranks.append(rank_id)
        return unique_ranks

    def infer_all_tensor_offset(self) -> dict[int, tuple[tuple[int, int], ...]]:
        """
        Calculates required tensor slices from each source rank.

        Determines which parts of the tensor need to be collected from each
        source rank to assemble the target tensor slice.

        Returns:
            dict[int, tuple[tuple[int, int], ...]]: Maps each contributing
            global source rank to the (start, end) offsets within that rank's
            local slice.
        """
        # Area of the full tensor owned by the target rank.
        self.to_area = infer_slice_area_by_rank(
            self.to_mesh_shape,
            self.to_tensor_map,
            self.inner_to_rank_id,
            self.full_shape
        )

        local_union_areas_map: dict[int, tuple[tuple[int, int], ...]] = {}
        self.global_union_area_map.clear()

        for inner_rank_id in self.inner_deredundancy_from_rank_list:
            from_area = infer_slice_area_by_rank(
                self.from_mesh_shape,
                self.from_tensor_map,
                inner_rank_id,
                self.full_shape
            )

            union_area = infer_intersection(from_area, self.to_area)
            if union_area is None:
                continue  # This source rank contributes nothing.

            source_rank = self.from_rank_list[inner_rank_id]
            # Global coordinates are kept for reassembly in get_real_tensor().
            self.global_union_area_map[source_rank] = union_area
            # Report offsets relative to the source rank's local slice.
            local_union_areas_map[source_rank] = tuple(
                (u_lo - f_lo, u_hi - f_lo)
                for (u_lo, u_hi), (f_lo, _) in zip(union_area, from_area)
            )

        return local_union_areas_map

    def get_real_tensor(self, from_tensor_map: dict[int, np.ndarray]) -> np.ndarray:
        """
        Assembles the final tensor for the target rank from collected slices.

        Args:
            from_tensor_map (dict[int, np.ndarray]): Dictionary mapping source ranks to their tensor slices.

        Returns:
            np.ndarray: Assembled tensor for the target rank.

        Raises:
            ValueError: If input slices are missing or have incorrect shapes.
        """
        if not from_tensor_map:
            raise ValueError("Input from_tensor_map cannot be empty")

        # Every required slice must be present and have the expected shape.
        for from_rank_id, from_area in self.global_union_area_map.items():
            if from_rank_id not in from_tensor_map:
                raise ValueError(
                    f"Missing slice data from rank {from_rank_id}. "
                    "Please provide all required slices from infer_all_tensor_offset."
                )
            expected_shape = tuple(end - start for start, end in from_area)
            actual_shape = from_tensor_map[from_rank_id].shape
            if expected_shape != actual_shape:
                raise ValueError(
                    f"Slice from rank {from_rank_id} has incorrect shape. "
                    f"Expected {expected_shape}, got {actual_shape}."
                )

        # Allocate the target slice and paste every contribution into place.
        to_slice_shape = [end - start for start, end in self.to_area]
        dtype = next(iter(from_tensor_map.values())).dtype
        real_tensor = np.zeros(to_slice_shape, dtype=dtype)

        for from_rank_id, from_slice in from_tensor_map.items():
            from_area = self.global_union_area_map[from_rank_id]
            # Shift global overlap coordinates into the target slice's frame.
            assign_slices = tuple(
                slice(f_lo - t_lo, f_hi - t_lo)
                for (f_lo, f_hi), (t_lo, _) in zip(from_area, self.to_area)
            )
            real_tensor[assign_slices] = from_slice

        return real_tensor