Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/api.py: 8%
310 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""shard"""
16import inspect
17from typing import Union, Callable, Dict, List
18from functools import wraps
19from hyper_parallel.core.dtensor.layout import Layout, DeviceMesh
20from hyper_parallel.core.dtensor.dtensor import DTensor, _is_alias_placements
21from hyper_parallel.core.dtensor.placement_types import Placement
22from hyper_parallel.core.shard.sharding_plan import ShardingPlan
23from hyper_parallel.platform import get_platform
25platform = get_platform()
26Parameter = platform.Parameter
27Tensor = platform.Tensor
28Module = platform.Module
31def _has_kwargs(func):
32 """_has_kwargs"""
33 sig = inspect.signature(func)
34 return any(
35 param.default != inspect.Parameter.empty
36 for param in sig.parameters.values()
37 )
40def _get_param_name(func):
41 """_get_param_name"""
42 sig = inspect.signature(func)
43 return list(sig.parameters.keys())
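# Illustrative sketch with a hypothetical function (not from this module), showing
# what the two helpers above report under standard inspect semantics:
# >>> def forward(x, mask=None):
# ...     return x
# >>> _has_kwargs(forward)      # "mask" carries a default value
# True
# >>> _get_param_name(forward)  # positional order used later by _parallel_in
# ['x', 'mask']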
46def _convert_sharding_plan(sharding_plan: Dict, device_mesh: DeviceMesh) -> Dict:
47 """
48 Convert sharding_plan values to Layout objects.
50 This function recursively traverses the sharding_plan and converts
51 placement tuples (e.g., (Shard(0), Replicate())) to Layout objects.
53 Args:
54 sharding_plan: The original sharding plan with tuple specifications
55 device_mesh: The DeviceMesh to use for conversion
57 Returns:
58 Dict: Converted sharding plan with Layout objects
59 """
61 def _is_placement_tuple(value):
62 """Check if value is a placement specification tuple.
64 A placement tuple contains Placement instances (Shard, Replicate) or
65 alias strings ("dp", "None"). It should NOT be a tuple of placement tuples.
67 Examples of placement tuples:
68 (Shard(0), Replicate(), Shard(1)) -> True
69 ("dp", "tp", "None") -> True
70 (("dp", "tp"), "None") -> True (multi-axis sharding)
72 Examples of NON-placement tuples:
73 ((Shard(0),), (Shard(1),)) -> False (tuple of placement tuples)
74 """
75 if not isinstance(value, tuple) or len(value) == 0:
76 return False
78 for item in value:
79 # Placement instance is valid
80 if isinstance(item, Placement):
81 continue
82 # String (alias name) is valid
83 if isinstance(item, str):
84 continue
85 # Nested tuple needs special handling
86 if isinstance(item, tuple):
87 # Nested tuple of strings is valid (multi-axis sharding)
88 if len(item) > 0 and all(isinstance(x, str) for x in item):
89 continue
90 # Nested tuple containing Placement means this is a tuple of placement tuples
91 if len(item) > 0 and any(isinstance(x, Placement) for x in item):
92 return False
93 # Empty tuple or other cases - not valid
94 return False
95 # Any other type is not valid in a placement tuple
96 return False
98 return True
100 def _to_layout(value):
101 """Convert a single sharding specification to Layout."""
102 layout = Layout.from_device_mesh(device_mesh)
103 if _is_alias_placements(value):
104 result = layout(*value)
105 else:
106 result = layout(value)
107 return result
109 def _convert_value(value, wrap_single_as_list=False):
110 """Recursively convert value based on its structure."""
111 if value is None:
112 return None
114 # Case 1: It's a placement tuple - convert to Layout
115 if _is_placement_tuple(value):
116 layout = _to_layout(value)
117 # Wrap single layout in list if required (for input/output in forward)
118 return [layout] if wrap_single_as_list else layout
120 # Case 2: It's a dict - recursively process each value
121 if isinstance(value, dict):
122 converted_dict = {}
123 for k, v in value.items():
124 converted_dict[k] = _convert_value(v, wrap_single_as_list=False)
125 return converted_dict
127 # Case 3: It's a list - recursively process each element
128 if isinstance(value, list):
129 return [_convert_value(v, wrap_single_as_list=False) for v in value]
131 # Case 4: It's a tuple but not a placement tuple - treat as list
132 if isinstance(value, tuple):
133 return [_convert_value(v, wrap_single_as_list=False) for v in value]
135 # Case 5: Other types (e.g., primitives) - return as is
136 return value
138 def _convert_forward_plan(forward_plan):
139 """Convert forward plan with special handling for input/output."""
140 if forward_plan is None:
141 return None
143 converted = {}
144 for key, value in forward_plan.items():
145 if key.endswith("input") or key.endswith("output"):
146 # input/output need special handling:
147 # - dict format: convert each value, keep as dict
148 # - list/tuple format: convert each element, keep as list
149 # - single placement tuple: convert and wrap in list
150 if value is None:
151 converted[key] = None
152 elif isinstance(value, dict):
153 # Dict format for kwargs: {"x": placements, "activation": placements}
154 converted[key] = {k: _convert_value(v) for k, v in value.items()}
155 elif isinstance(value, (list, tuple)):
156 # Check if it's a single placement tuple or a list/tuple of placement tuples
157 if _is_placement_tuple(value):
158 # Single placement tuple - wrap in list
159 converted[key] = [_to_layout(value)]
160 else:
161 # List/tuple of placements for multiple positional args
162 converted[key] = [_convert_value(v) for v in value]
163 else:
164 converted[key] = _convert_value(value, wrap_single_as_list=True)
165 else:
166 # Other keys in forward plan
167 converted[key] = _convert_value(value)
168 return converted
170 # Main conversion logic
171 converted_plan = {}
173 for key, value in sharding_plan.items():
174 if key == "forward":
175 converted_plan[key] = _convert_forward_plan(value)
176 elif key.endswith("input") or key.endswith("output"):
177 # Top-level input/output (for callable sharding)
178 if value is None:
179 converted_plan[key] = None
180 elif isinstance(value, dict):
181 converted_plan[key] = {k: _convert_value(v) for k, v in value.items()}
182 elif isinstance(value, (list, tuple)):
183 if _is_placement_tuple(value):
184 converted_plan[key] = [_to_layout(value)]
185 else:
186 converted_plan[key] = [_convert_value(v) for v in value]
187 else:
188 converted_plan[key] = _convert_value(value, wrap_single_as_list=True)
189 else:
190 # parameter and other keys - use standard recursive conversion
191 converted_plan[key] = _convert_value(value)
193 return converted_plan
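# Illustrative sketch (hypothetical plan and mesh) of what the conversion produces:
# >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
# >>> plan = {
# ...     "parameter": {"mlp.weight": ("None", "tp")},
# ...     "forward": {"input": ("dp", "None"), "mlp.output": ("dp", "tp")},
# ... }
# >>> converted = _convert_sharding_plan(plan, mesh)
# >>> converted["parameter"]["mlp.weight"]   # a single Layout object
# >>> converted["forward"]["input"]          # a one-element list: [Layout]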
196def _parallel_in(func, args, kwargs, layouts):
197 """_parallel_in"""
198 if not isinstance(layouts, (list, dict, tuple)):
199 raise ValueError(f"The in_layout must be a list, tuple or dict, but got {type(layouts)}.")
201 params_name = _get_param_name(func)
202 processed_args = list(args)
203 processed_kwargs = dict(kwargs)
205 def _get_layout(index, is_list):
206 """_get_layout"""
207 if is_list:
208 return layouts[index]
209 param_name = params_name[index]
210 return layouts[param_name]
212 is_list = isinstance(layouts, (list, tuple))
213 for i, arg in enumerate(args):
214 if not isinstance(arg, DTensor):
215 continue
217 to_layout = _get_layout(i, is_list)
218 processed_args[i] = arg.redistribute(to_layout.mesh, to_layout.alias_placements)
219 for k, v in kwargs.items():
220 if not isinstance(v, DTensor) or layouts.get(k) is None:
221 processed_kwargs[k] = v
222 continue
223 to_layout = layouts[k]
224 processed_kwargs[k] = v.redistribute(to_layout.mesh, to_layout.alias_placements)
226 return tuple(processed_args), processed_kwargs
229def _parallel_out(outputs, layouts):
230 """_parallel_out"""
231 if not isinstance(layouts, (list, tuple)):
232 raise ValueError(f"The out_layout must be a list or tuple, but got {type(layouts)}.")
233 if isinstance(outputs, (tuple, list)):
234 if len(outputs) != len(layouts):
235 raise ValueError(f"The size of outputs and out_layout must be equal, but got {len(outputs)} and "
236 f"{len(layouts)}")
237 new_outputs = []
238 for i, arg in enumerate(outputs):
239 if not isinstance(arg, DTensor) or arg is None:
240 new_outputs.append(arg)
241 continue
242 to_layout = layouts[i]
243 new_outputs.append(arg.redistribute(to_layout.mesh, to_layout.alias_placements))
244 return tuple(new_outputs)
245 if len(layouts) != 1:
246 raise ValueError(f"The size of outputs and out_layout must be equal, but got 1 and "
247 f"{len(layouts)}")
249 if isinstance(outputs, DTensor):
250 return outputs.redistribute(
251 layouts[0].mesh, layouts[0].alias_placements)
252 return outputs
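# Illustrative sketch (hypothetical DTensors and Layout objects) of how the two
# helpers above redistribute a callable's inputs and outputs:
# >>> def matmul(x, w):
# ...     return x @ w
# >>> args, kwargs = _parallel_in(matmul, (x_dtensor, w_dtensor), {},
# ...                             [x_layout, w_layout])   # one layout per positional arg
# >>> out = matmul(*args, **kwargs)
# >>> out = _parallel_out(out, [out_layout])              # single output, single layout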
255def _forward_pre_hook(cell, args):
256 """_forward_pre_hook"""
257 if cell.in_layout is None:
258 return args
259 processed_args, _ = _parallel_in(platform.get_cell_construct(cell), args, {}, cell.in_layout)
260 return processed_args
263def _forward_pre_with_kwargs_hook(cell, args, kwargs):
264 """_forward_pre_with_kwargs_hook"""
265 if cell.in_layout is None:
266 return args, kwargs
267 return _parallel_in(platform.get_cell_construct(cell), args, kwargs, cell.in_layout)
270def _forward_hook(cell, inputs, outputs): # pylint: disable=unused-argument
271 """_forward_hook"""
272 if cell.out_layout is None:
273 return outputs
274 return _parallel_out(outputs, cell.out_layout)
277def _forward_with_kwargs_hook(cell, inputs, kwargs, outputs): # pylint: disable=unused-argument
278 """_forward_with_kwargs_hook"""
279 return _forward_hook(cell, inputs, outputs)
282def _register_hook(model: Module, sharding_plan: Dict):
283 """_register_hook"""
285 def _register_cell_hook(model, has_inputs_layout, has_outputs_layout):
286 """_register_cell_hook"""
287 has_kwargs = _has_kwargs(platform.get_cell_construct(model))
288 pre_hook = _forward_pre_with_kwargs_hook if has_kwargs else _forward_pre_hook
289 hook = _forward_with_kwargs_hook if has_kwargs else _forward_hook
290 if has_inputs_layout:
291 model.register_forward_pre_hook(pre_hook, with_kwargs=has_kwargs)
293 if has_outputs_layout:
294 model.register_forward_hook(hook, with_kwargs=has_kwargs)
296 def _set_layouts(model, layouts, set_inputs_layout, set_outputs_layout):
297 """_set_layouts"""
298 if set_inputs_layout:
299 model.in_layout = layouts
301 if set_outputs_layout:
302 model.out_layout = layouts
304 cell_dict = {}
305 for name, cell in platform.get_cells_and_names(model):
306 cell_dict[name] = cell
308 valid_suffix = ["input", "output"]
309 for key, value in sharding_plan.items():
310 if value is None:
311 continue
312 has_dot = '.' in key
313 split_key = key.rsplit('.', 1)
314 prefix = split_key[0] if has_dot else ""
315 suffix = split_key[1] if has_dot else key
316 if suffix not in valid_suffix:
317 raise ValueError(f"In python shard_module, sharding_plan's forward key must end with input or output, "
318 f"but got type {suffix}")
320 set_inputs_layout = suffix == "input"
321 set_outputs_layout = not set_inputs_layout
322 register_cell = cell_dict[prefix]
324 _set_layouts(register_cell, value, set_inputs_layout, set_outputs_layout)
325 _register_cell_hook(register_cell, set_inputs_layout, set_outputs_layout)
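# Illustrative sketch (hypothetical Layout lists) of the key format _register_hook
# expects: "<sub-cell name>.input" / "<sub-cell name>.output", or a bare
# "input"/"output" for the cell registered under the empty name (typically the root):
# >>> forward_plan = {
# ...     "input": [in_layout],        # pre-hook on the root module's inputs
# ...     "mlp.output": [out_layout],  # post-hook on the sub-cell named "mlp"
# ... }
# >>> _register_hook(model, forward_plan)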
328def _register_local_tensor_hook(cell: Module, return_local_tensor_list: List[str]):
329 """_register_local_tensor_hook"""
331 def hook_func(cell, inputs, outputs): # pylint: disable=unused-argument
332 def _recursive_to_local(out):
333 if isinstance(out, (tuple, list)):
334 new_out = []
335 for item in out:
336 new_out.append(_recursive_to_local(item))
337 return tuple(new_out) if isinstance(out, tuple) else new_out
338 if isinstance(out, DTensor):
339 return out.to_local()
340 return out
342 return _recursive_to_local(outputs)
344 cell_dict = {}
345 for name, sub_cell in platform.get_cells_and_names(cell):
346 cell_dict[name] = sub_cell
348 for cell_name in return_local_tensor_list:
349 register_cell = cell_dict[cell_name]
350 register_cell.register_forward_hook(hook_func)
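# Illustrative sketch (hypothetical cell names): each listed sub-cell gets a forward
# hook that unwraps DTensor outputs into local tensors via to_local():
# >>> _register_local_tensor_hook(model, ["mlp", "attention.proj"])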
353def _shard_callable(func: Callable, sharding_plan: Dict):
354 """_shard_callable"""
355 forward_sharding_plan = sharding_plan.get("forward")
356 if forward_sharding_plan is None:
357 return func
359 @wraps(func)
360 def _shard_wrapper(*args, **kwargs):
361 """_shard_wrapper"""
362 input_layout = forward_sharding_plan.get("input")
363 output_layout = forward_sharding_plan.get("output")
364 if input_layout is not None:
365 args, kwargs = _parallel_in(func, args, kwargs, input_layout)
366 outputs = func(*args, **kwargs)
367 if output_layout is not None:
368 outputs = _parallel_out(outputs, output_layout)
369 return outputs
371 return _shard_wrapper
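# Illustrative sketch (hypothetical callable and Layout objects): wrapping a bare
# function with a converted forward plan so inputs/outputs are redistributed:
# >>> plan = {"forward": {"input": [in_layout], "output": [out_layout]}}
# >>> sharded_fn = _shard_callable(my_fn, plan)
# >>> result = sharded_fn(x_dtensor)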
374def shard_module(model: Union[Module, Callable], device_mesh: DeviceMesh, sharding_plan: ShardingPlan):
375 """
376 Define the input, output and parameter layouts of the given Module or Callable.
378 Note:
379 - It is valid only in PyNative mode.
381 .. warning::
382 The method is currently not supported in Graph mode.
384 Args:
385 model (Module or Callable): The model to be sharded.
386 device_mesh (DeviceMesh): The device mesh for sharding.
387 sharding_plan (ShardingPlan): Define the layout for the specified parameters, inputs or outputs.
388 The sharding specification can be:
389 - tuple of strings for alias format, e.g., ("dp", "None")
390 - tuple of Placements, e.g., (Shard(0), Replicate())
392 Returns:
393 Module or Callable: The sharded model.
395 Examples:
396 >>> # Usage with device_mesh and alias format
397 >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
398 >>> sharding_plan = ShardingPlan(
399 ... plan={"mlp.weight": ("None", "tp")},
400 ... input_plan={"input": ("dp", "None")},
401 ... output_plan={"output": ("dp", "tp")}
402 ... )
403 >>> model = shard_module(model, mesh, sharding_plan)
405 >>> # Usage with device_mesh and Placement format
406 >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
407 >>> sharding_plan = ShardingPlan(
408 ... plan={"mlp.weight": (Replicate(), Shard(1))},
409 ... input_plan={"input": (Shard(0), Replicate())},
410 ... output_plan={"output": (Shard(0), Shard(1))}
411 ... )
412 >>> model = shard_module(model, mesh, sharding_plan)
413 """
414 if platform.get_world_size() == 1:
415 return None
417 if not isinstance(sharding_plan, ShardingPlan):
418 raise TypeError(f"The 'sharding_plan' must be an instance of ShardingPlan, "
419 f"but got {type(sharding_plan)}. Direct dict input is not supported.")
421 normalized_plan = {}
422 return_local_tensor_list = None
424 if sharding_plan.plan:
425 normalized_plan["parameter"] = sharding_plan.plan
427 forward_part = {}
429 if sharding_plan.input_plan:
430 if not isinstance(sharding_plan.input_plan, dict):
431 raise TypeError(f"input_plan must be a dict, but got {type(sharding_plan.input_plan)}")
432 forward_part.update(sharding_plan.input_plan)
434 if sharding_plan.output_plan:
435 if not isinstance(sharding_plan.output_plan, dict):
436 raise TypeError(f"output_plan must be a dict, but got {type(sharding_plan.output_plan)}")
437 forward_part.update(sharding_plan.output_plan)
439 if forward_part:
440 normalized_plan["forward"] = forward_part
442 if sharding_plan.return_local_tensor:
443 return_local_tensor_list = sharding_plan.return_local_tensor
445 # Convert sharding_plan to Layout objects
446 converted_plan = _convert_sharding_plan(normalized_plan, device_mesh)
448 if not isinstance(model, Module):
449 return _shard_callable(model, converted_plan)
451 param_sharding_plan = converted_plan.get("parameter")
452 forward_sharding_plan = converted_plan.get("forward")
454 if param_sharding_plan is not None:
455 for param_name, layout in param_sharding_plan.items():
456 if not isinstance(layout, Layout):
457 raise ValueError(f"In python shard_module, the type of setting in parameter_plan must be Layout, "
458 f"but got type {type(layout)}")
459 result = platform.search_parameter_by_name(model, param_name)
460 if not result:
461 raise ValueError(f"{param_name} is configured with a layout, but no instance was found.")
462 _, _, param = result
463 if layout.tensor_map is None:
464 layout.placement_to_tensor_map(param.dim())
465 param = platform.set_layout_into_parameter(param, layout)
466 platform.update_parameter_by_name(model, result, param)
468 if forward_sharding_plan is not None:
469 _register_hook(model, forward_sharding_plan)
471 if return_local_tensor_list is not None:
472 _register_local_tensor_hook(model, return_local_tensor_list)
474 return model
477def parallelize_value_and_grad(fn, weights, sens=None):
478 """
479 A wrapper that generates a function to calculate the forward output and gradient for the parallel scenario.
481 Args:
482 fn (Union[Cell, Function]): Function to do grad operation.
483 weights (Union[ParameterTuple, Parameter, list[Parameter]]):
484 The parameters of the training network whose gradients need to
485 be calculated. `weights` can be obtained via `weights = net.trainable_params()`.
486 sens (Union[list(float), tuple(float)], optional): The sensitivity for the grad operation. Default: None.
487 - If fn has only one output, sens must be None; a sensitivity is attached automatically.
488 - If fn has multiple outputs:
489 1) If sens is None, only the first output receives a non-zero sensitivity; the remaining sensitivities are set to 0.
490 2) If sens is not None, the lengths of sens and the outputs of fn must be equal.
492 Returns:
493 Function, the derivative function used to compute the gradient of a given function.
494 For example, for `out1, out2 = fn(*args)`, the gradient function returns outputs like
495 `((out1, out2), gradient)`.
497 Raises:
498 TypeError: If type of Args does not belong to required ones.
500 Supported Platforms:
501 ``Ascend``
502 """
503 from mindspore import ops # pylint: disable=import-outside-toplevel
504 grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)
506 # use CellWrapper to solve two problems:
507 # 1. avoid running the forward fn or cell twice
508 # 2. if the input of parallelize_value_and_grad is a cell and it is used directly as the input for grad,
509 # the operations before and after its __call__ function will not enter the auto-diff process.
510 class CellWrapper(Module):
511 def __init__(self, net):
512 super().__init__(auto_prefix=False)
513 self.network = net
515 def construct(self, *args, **kwargs):
516 return self.network(*args, **kwargs)
518 def forward(self, *args, **kwargs):
519 return self.network(*args, **kwargs)
521 fn = CellWrapper(fn)
522 fn.set_grad() # avoid running the forward fn or cell twice
524 def wrapper(*args, **kwargs):
525 loss_value = fn(*args, **kwargs)
526 p_sens = None
528 if isinstance(loss_value, (list, tuple)):
529 # There are multiple outputs, requiring multiple sens
530 p_sens = []
532 if sens is None:
533 # if sens is None, only handle the first sens, and set the remaining sens to 0
534 loss_0 = loss_value[0]
535 if isinstance(loss_0, DTensor):
536 repeat_num = loss_0.layout.repeat_num()
537 sens_0 = ops.fill(ops.DType()(loss_0), loss_0.local_shape, 1.0 / repeat_num)
538 else:
539 sens_0 = ops.fill(ops.DType()(loss_0), loss_0.shape, 1.0)
540 p_sens.append(sens_0)
542 for i in range(1, len(loss_value)):
543 loss_i = loss_value[i]
544 if isinstance(loss_i, DTensor):
545 sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 0.0)
546 else:
547 sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 0.0)
548 p_sens.append(sens_i)
550 else:
551 # sens is not None
552 if not isinstance(sens, list) and not isinstance(sens, tuple):
553 raise TypeError("if the loss is list or tuple, the sens must be None or list or tuple")
555 all_float = all(isinstance(item, float) for item in sens)
556 if not all_float:
557 raise TypeError("if sens is not None, it should be list of float or tuple of float")
559 if len(sens) != len(loss_value):
560 raise TypeError(f"the len of loss is {len(loss_value)}, but the len of sens is {len(sens)}")
562 for _, loss_i in enumerate(loss_value):
563 if isinstance(loss_i, DTensor):
564 repeat_num = loss_i.layout.repeat_num()
565 sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 1.0 / repeat_num)
566 else:
567 sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 1.0)
568 p_sens.append(sens_i)
570 else:
571 # loss is tensor
572 if sens is not None:
573 raise TypeError(f"the fn only have one output, the sens must be None, but it is {sens}")
574 if isinstance(loss_value, DTensor):
575 repeat_num = loss_value.layout.repeat_num()
576 p_sens = ops.fill(ops.DType()(loss_value), loss_value.local_shape, 1.0 / repeat_num)
578 else:
579 p_sens = ops.fill(ops.DType()(loss_value), loss_value.shape, 1.0)
581 grads = grad_fn(fn, weights)(*args, **kwargs, sens=p_sens)
582 return loss_value, grads
584 return wrapper
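# Illustrative usage sketch (hypothetical net, inputs and labels):
# >>> weights = net.trainable_params()
# >>> value_and_grad_fn = parallelize_value_and_grad(net, weights)
# >>> loss, grads = value_and_grad_fn(inputs, labels)   # forward value plus grads w.r.t. weights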