Coverage for hyper_parallel / core / shard / api.py: 77%
305 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""shard"""
16import inspect
17from typing import Union, Callable, Dict, List
18from functools import wraps
19from hyper_parallel.core.layout import Layout, DeviceMesh
20from hyper_parallel.core.dtensor import DTensor
21from hyper_parallel.core.placement_types import Placement
22from hyper_parallel.core.shard.sharding_plan import ShardingPlan
23from hyper_parallel.platform import get_platform
# Resolve the active backend (e.g. MindSpore) once at import time and
# re-export its core types so the rest of this module stays backend-agnostic.
platform = get_platform()
Parameter = platform.Parameter  # backend Parameter class
Tensor = platform.Tensor  # backend Tensor class
Module = platform.Module  # backend Module/Cell base class
31def _has_kwargs(func):
32 """_has_kwargs"""
33 sig = inspect.signature(func)
34 return any(
35 param.default != inspect.Parameter.empty
36 for param in sig.parameters.values()
37 )
40def _get_param_name(func):
41 """_get_param_name"""
42 sig = inspect.signature(func)
43 return list(sig.parameters.keys())
def _convert_sharding_plan(sharding_plan: Dict, device_mesh: DeviceMesh) -> Dict:
    """
    Convert sharding_plan values to Layout objects.

    This function recursively traverses the sharding_plan and converts
    placement tuples (e.g., (Shard(0), Replicate())) to Layout objects.

    Args:
        sharding_plan: The original sharding plan with tuple specifications
        device_mesh: The DeviceMesh to use for conversion

    Returns:
        Dict: Converted sharding plan with Layout objects
    """

    def _is_placement_tuple(value) -> bool:
        """Check if value is a placement specification tuple.

        A placement tuple contains Placement instances (Shard, Replicate) or
        alias strings ("dp", "None"). It should NOT be a tuple of placement tuples.

        Examples of placement tuples:
            (Shard(0), Replicate(), Shard(1)) -> True
            ("dp", "tp", "None") -> True
            (("dp", "tp"), "None") -> True (multi-axis sharding)

        Examples of NON-placement tuples:
            ((Shard(0),), (Shard(1),)) -> False (tuple of placement tuples)
        """
        if not isinstance(value, tuple) or len(value) == 0:
            return False

        for item in value:
            # Placement instance is valid
            if isinstance(item, Placement):
                continue
            # String (alias name) is valid
            if isinstance(item, str):
                continue
            # Nested tuple needs special handling
            if isinstance(item, tuple):
                # Nested tuple of strings is valid (multi-axis sharding)
                if len(item) > 0 and all(isinstance(x, str) for x in item):
                    continue
                # Nested tuple containing Placement means this is a tuple of placement tuples
                if len(item) > 0 and any(isinstance(x, Placement) for x in item):
                    return False
                # Empty tuple or other cases - not valid
                return False
            # Any other type is not valid in a placement tuple
            return False

        return True

    def _to_layout(value):
        """Convert a single sharding specification to Layout."""
        # Layout.from_device_mesh binds the mesh; calling the result with the
        # placement spec produces the concrete Layout object.
        layout = Layout.from_device_mesh(device_mesh)
        result = layout(value)
        return result

    def _convert_value(value, wrap_single_as_list: bool = False):
        """Recursively convert value based on its structure."""
        if value is None:
            return None

        # Case 1: It's a placement tuple - convert to Layout
        if _is_placement_tuple(value):
            layout = _to_layout(value)
            # Wrap single layout in list if required (for input/output in forward)
            return [layout] if wrap_single_as_list else layout

        # Case 2: It's a dict - recursively process each value
        if isinstance(value, dict):
            converted_dict = {}
            for k, v in value.items():
                converted_dict[k] = _convert_value(v, wrap_single_as_list=False)
            return converted_dict

        # Case 3: It's a list - recursively process each element
        if isinstance(value, list):
            return [_convert_value(v, wrap_single_as_list=False) for v in value]

        # Case 4: It's a tuple but not a placement tuple - treat as list
        if isinstance(value, tuple):
            return [_convert_value(v, wrap_single_as_list=False) for v in value]

        # Case 5: Other types (e.g., primitives) - return as is
        return value

    def _convert_forward_plan(forward_plan):
        """Convert forward plan with special handling for input/output."""
        if forward_plan is None:
            return None

        converted = {}
        for key, value in forward_plan.items():
            if key.endswith("input") or key.endswith("output"):
                # input/output need special handling:
                # - dict format: convert each value, keep as dict
                # - list/tuple format: convert each element, keep as list
                # - single placement tuple: convert and wrap in list
                if value is None:
                    converted[key] = None
                elif isinstance(value, dict):
                    # Dict format for kwargs: {"x": placements, "activation": placements}
                    converted[key] = {k: _convert_value(v) for k, v in value.items()}
                elif isinstance(value, (list, tuple)):
                    # Check if it's a single placement tuple or a list/tuple of placement tuples
                    if _is_placement_tuple(value):
                        # Single placement tuple - wrap in list
                        converted[key] = [_to_layout(value)]
                    else:
                        # List/tuple of placements for multiple positional args
                        converted[key] = [_convert_value(v) for v in value]
                else:
                    converted[key] = _convert_value(value, wrap_single_as_list=True)
            else:
                # Other keys in forward plan
                converted[key] = _convert_value(value)
        return converted

    # Main conversion logic: "forward" sub-plans and top-level input/output
    # keys get the list-wrapping treatment; everything else (e.g. "parameter")
    # uses plain recursive conversion.
    converted_plan = {}

    for key, value in sharding_plan.items():
        if key == "forward":
            converted_plan[key] = _convert_forward_plan(value)
        elif key.endswith("input") or key.endswith("output"):
            # Top-level input/output (for callable sharding)
            if value is None:
                converted_plan[key] = None
            elif isinstance(value, dict):
                converted_plan[key] = {k: _convert_value(v) for k, v in value.items()}
            elif isinstance(value, (list, tuple)):
                if _is_placement_tuple(value):
                    converted_plan[key] = [_to_layout(value)]
                else:
                    converted_plan[key] = [_convert_value(v) for v in value]
            else:
                converted_plan[key] = _convert_value(value, wrap_single_as_list=True)
        else:
            # parameter and other keys - use standard recursive conversion
            converted_plan[key] = _convert_value(value)

    return converted_plan
193def _parallel_in(func, args, kwargs, layouts):
194 """_parallel_in"""
195 if not isinstance(layouts, (list, dict, tuple)):
196 raise ValueError(f"The in_layout must be a list, tuple or dict, but got {type(layouts)}.")
198 params_name = _get_param_name(func)
199 processed_args = list(args)
200 processed_kwargs = dict(kwargs)
202 def _get_layout(index, is_list):
203 """_get_layout"""
204 if is_list:
205 return layouts[index]
206 param_name = params_name[index]
207 return layouts[param_name]
209 is_list = isinstance(layouts, (list, tuple))
210 for i, arg in enumerate(args):
211 if not isinstance(arg, DTensor):
212 continue
214 to_layout = _get_layout(i, is_list)
215 processed_args[i] = arg.redistribute(to_layout.mesh, to_layout.placements)
216 for k, v in kwargs.items():
217 if not isinstance(v, DTensor) or layouts.get(k) is None:
218 processed_kwargs[k] = v
219 continue
220 to_layout = layouts[k]
221 processed_kwargs[k] = v.redistribute(to_layout.mesh, to_layout.placements)
223 return tuple(processed_args), processed_kwargs
226def _parallel_out(outputs, layouts):
227 """_parallel_out"""
228 if not isinstance(layouts, (list, tuple)):
229 raise ValueError(f"The out_layout must be a list or tuple, but got {type(layouts)}.")
230 if isinstance(outputs, (tuple, list)):
231 if len(outputs) != len(layouts):
232 raise ValueError(f"The size of outputs and out_layout must be equal, but got {len(outputs)} and "
233 f"{len(layouts)}")
234 new_outputs = []
235 for i, arg in enumerate(outputs):
236 if not isinstance(arg, DTensor) or arg is None:
237 new_outputs.append(arg)
238 continue
239 to_layout = layouts[i]
240 new_outputs.append(arg.redistribute(to_layout.mesh, to_layout.placements))
241 return tuple(new_outputs)
242 if len(layouts) != 1:
243 raise ValueError(f"The size of outputs and out_layout must be equal, but got 1 and "
244 f"{len(layouts)}")
246 return outputs.redistribute(layouts[0].mesh, layouts[0].placements) if isinstance(outputs, DTensor) else outputs
249def _forward_pre_hook(cell, args):
250 """_forward_pre_hook"""
251 if cell.in_layout is None:
252 return args
253 processed_args, _ = _parallel_in(platform.get_cell_construct(cell), args, {}, cell.in_layout)
254 return processed_args
257def _forward_pre_with_kwargs_hook(cell, args, kwargs):
258 """_forward_pre_with_kwargs_hook"""
259 if cell.in_layout is None:
260 return args, kwargs
261 return _parallel_in(platform.get_cell_construct(cell), args, kwargs, cell.in_layout)
264def _forward_hook(cell, inputs, outputs): # pylint: disable=unused-argument
265 """_forward_hook"""
266 if cell.out_layout is None:
267 return outputs
268 return _parallel_out(outputs, cell.out_layout)
def _forward_with_kwargs_hook(cell, inputs, kwargs, outputs):  # pylint: disable=unused-argument
    """Kwargs-aware forward hook; kwargs play no role for outputs, so delegate."""
    return _forward_hook(cell, inputs, outputs)
def _register_hook(model: Module, sharding_plan: Dict):
    """Attach forward hooks enforcing the input/output layouts in *sharding_plan*.

    Keys look like "sub.cell.input" or "sub.cell.output" (no dot means the
    root cell itself, whose registered name is ""); values are the layouts
    to apply.

    Raises:
        ValueError: If a key does not end with 'input'/'output', or names a
            sub-cell that does not exist in *model*.
    """

    def _register_cell_hook(cell, has_inputs_layout, has_outputs_layout):
        """Register the plain or kwargs-aware hook variants on *cell*."""
        has_kwargs = _has_kwargs(platform.get_cell_construct(cell))
        pre_hook = _forward_pre_with_kwargs_hook if has_kwargs else _forward_pre_hook
        hook = _forward_with_kwargs_hook if has_kwargs else _forward_hook
        if has_inputs_layout:
            cell.register_forward_pre_hook(pre_hook, with_kwargs=has_kwargs)

        if has_outputs_layout:
            cell.register_forward_hook(hook, with_kwargs=has_kwargs)

    def _set_layouts(cell, layouts, set_inputs_layout, set_outputs_layout):
        """Store the layouts on the cell so the hooks can pick them up."""
        if set_inputs_layout:
            cell.in_layout = layouts

        if set_outputs_layout:
            cell.out_layout = layouts

    cell_dict = dict(platform.get_cells_and_names(model))

    valid_suffix = ["input", "output"]
    for key, value in sharding_plan.items():
        if value is None:
            continue
        # Split "a.b.input" into the cell path ("a.b") and direction ("input").
        has_dot = '.' in key
        split_key = key.rsplit('.', 1)
        prefix = split_key[0] if has_dot else ""
        suffix = split_key[1] if has_dot else key
        if suffix not in valid_suffix:
            raise ValueError(f"In python shard_module, sharding_plan's forward key must end with input or output, "
                             f"but got type {suffix}")

        # Fail with a clear message instead of a bare KeyError on a typo.
        if prefix not in cell_dict:
            raise ValueError(f"{key} is configured with a layout, but no sub-cell named '{prefix}' was found.")

        set_inputs_layout = suffix == "input"
        set_outputs_layout = not set_inputs_layout
        register_cell = cell_dict[prefix]

        _set_layouts(register_cell, value, set_inputs_layout, set_outputs_layout)
        _register_cell_hook(register_cell, set_inputs_layout, set_outputs_layout)
def _register_local_tensor_hook(cell: Module, return_local_tensor_list: List[str]):
    """Register forward hooks that unwrap DTensor outputs into local tensors.

    Args:
        cell: The root module whose sub-cells are searched by name.
        return_local_tensor_list: Names of sub-cells whose outputs should be
            converted from DTensor to local tensors after the forward pass.

    Raises:
        ValueError: If a configured name matches no sub-cell of *cell*.
    """

    def hook_func(cell, inputs, outputs):  # pylint: disable=unused-argument
        def _recursive_to_local(out):
            # Preserve the container type (tuple stays tuple, list stays list)
            # while converting nested DTensors.
            if isinstance(out, (tuple, list)):
                converted = [_recursive_to_local(item) for item in out]
                return tuple(converted) if isinstance(out, tuple) else converted
            if isinstance(out, DTensor):
                return out.to_local()
            return out

        return _recursive_to_local(outputs)

    cell_dict = dict(platform.get_cells_and_names(cell))

    for cell_name in return_local_tensor_list:
        # Fail with a clear message instead of a bare KeyError on a typo.
        if cell_name not in cell_dict:
            raise ValueError(f"{cell_name} is configured in return_local_tensor, "
                             f"but no sub-cell with that name was found.")
        cell_dict[cell_name].register_forward_hook(hook_func)
347def _shard_callable(func: Callable, sharding_plan: Dict):
348 """_shard_callable"""
349 forward_sharding_plan = sharding_plan.get("forward")
350 if forward_sharding_plan is None:
351 return func
353 @wraps(func)
354 def _shard_wrapper(*args, **kwargs):
355 """_shard_wrapper"""
356 input_layout = forward_sharding_plan.get("input")
357 output_layout = forward_sharding_plan.get("output")
358 if input_layout is not None:
359 args, kwargs = _parallel_in(func, args, kwargs, input_layout)
360 outputs = func(*args, **kwargs)
361 if output_layout is not None:
362 outputs = _parallel_out(outputs, output_layout)
363 return outputs
365 return _shard_wrapper
def shard_module(model: Union[Module, Callable], device_mesh: DeviceMesh, sharding_plan: ShardingPlan):
    """
    Defining the input, output and parameters layouts of this cell or Callable.

    Note:
        - It is valid only in pynative mode.
        - When the world size is 1, the function returns ``None`` without
          applying any sharding.

    .. warning::
        The method is currently not supported in Graph mode.

    Args:
        model (Module or Callable): The model to be sharded.
        device_mesh (DeviceMesh): The device mesh for sharding.
        sharding_plan (ShardingPlan): Define the layout for the specified parameters, inputs or outputs.
            The sharding specification can be:

            - tuple of strings for alias format, e.g., ("dp", "None")
            - tuple of Placements, e.g., (Shard(0), Replicate())

    Returns:
        Module or Callable: The sharded model.

    Raises:
        TypeError: If `sharding_plan` is not a ShardingPlan, or its input/output plans are not dicts.
        ValueError: If a parameter configured in the plan cannot be found, or a converted layout
            has the wrong type.

    Examples:
        >>> # Usage with device_mesh and alias format
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"mlp.weight": ("None", "tp")},
        ...     input_plan={"input": ("dp", "None")},
        ...     output_plan={"output": ("dp", "tp")}
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)

        >>> # Usage with device_mesh and Placement format
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"mlp.weight": (Replicate(), Shard(1))},
        ...     input_plan={"input": (Shard(0), Replicate())},
        ...     output_plan={"output": (Shard(0), Shard(1))}
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)
    """
    # NOTE(review): a single-device world returns None, not the original model,
    # despite the documented return type — confirm callers handle this.
    if platform.get_world_size() == 1:
        return None

    if not isinstance(sharding_plan, ShardingPlan):
        raise TypeError(f"The 'sharding_plan' must be an instance of ShardingPlan, "
                        f"but got {type(sharding_plan)}. Direct dict input is not supported.")

    # Normalize the ShardingPlan fields into the internal dict shape that
    # _convert_sharding_plan expects: {"parameter": ..., "forward": ...}.
    normalized_plan = {}
    return_local_tensor_list = None

    if sharding_plan.plan:
        normalized_plan["parameter"] = sharding_plan.plan

    forward_part = {}

    if sharding_plan.input_plan:
        if not isinstance(sharding_plan.input_plan, dict):
            raise TypeError(f"input_plan must be a dict, but got {type(sharding_plan.input_plan)}")
        forward_part.update(sharding_plan.input_plan)

    if sharding_plan.output_plan:
        if not isinstance(sharding_plan.output_plan, dict):
            raise TypeError(f"output_plan must be a dict, but got {type(sharding_plan.output_plan)}")
        forward_part.update(sharding_plan.output_plan)

    if forward_part:
        normalized_plan["forward"] = forward_part

    if sharding_plan.return_local_tensor:
        return_local_tensor_list = sharding_plan.return_local_tensor

    # Convert sharding_plan to Layout objects
    converted_plan = _convert_sharding_plan(normalized_plan, device_mesh)

    # Plain callables only support the forward (input/output) part of the plan.
    if not isinstance(model, Module):
        return _shard_callable(model, converted_plan)

    param_sharding_plan = converted_plan.get("parameter")
    forward_sharding_plan = converted_plan.get("forward")

    if param_sharding_plan is not None:
        # Attach each configured layout to the matching parameter instance.
        for param_name, layout in param_sharding_plan.items():
            if not isinstance(layout, Layout):
                raise ValueError(f"In python shard_module, the type of setting in parameter_plan must be Layout, "
                                 f"but got type {type(layout)}")
            result = platform.search_parameter_by_name(model, param_name)
            if not result:
                raise ValueError(f"{param_name} is configured with a layout, but no instance was found.")
            _, _, param = result
            layout.placement_to_tensor_map(param.dim())
            param = platform.set_layout_into_parameter(param, layout)
            platform.update_parameter_by_name(model, result, param)

    if forward_sharding_plan is not None:
        _register_hook(model, forward_sharding_plan)

    if return_local_tensor_list is not None:
        _register_local_tensor_hook(model, return_local_tensor_list)

    return model
def parallelize_value_and_grad(fn, weights, sens=None):
    """
    A wrapper function to generate the function to calculate forward output and gradient for the parallel scenario.

    Args:
        fn (Union[Cell, Function]): Function to do grad operation.
        weights (Union[ParameterTuple, Parameter, list[Parameter]]):
            The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
        sens (Union[list(float), tuple(float)], optional): The sensitivity for grad operation. Default: "None".
            - If the fn only have one output, the sens must be None, and it will be attached automatically.
            - If the fn have multiple outputs:
              1) If the sens is None, only handle the first sensitivity, and set the remaining sensitivity to 0.
              2) If the sens is not None, the lengths of sens and outputs of fn must be equal.

    Returns:
        Function, the derivative function used to compute the gradient of a given function.
        For example, as for `out1, out2 = fn(*args)` , gradient function will return outputs like
        `((out1, out2), gradient)` .

    Raises:
        TypeError: If type of Args does not belong to required ones.

    Supported Platforms:
        ``Ascend``
    """
    from mindspore import ops  # pylint: disable=import-outside-toplevel
    # sens_param=True: we supply the sensitivity tensors explicitly below.
    grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)

    # use CellWrapper to solve two problems:
    # 1. avoid running the forward fn or cell twice
    # 2. if the input of parallize_value_and_grad is cell and it is directly used as the input for grad,
    #    the operations before and after its __call__ function will not enter the auto-diff process.
    class CellWrapper(Module):
        # Thin pass-through cell; construct/forward both delegate to the
        # wrapped network so either backend entry point works.
        def __init__(self, net):
            super().__init__(auto_prefix=False)
            self.network = net

        def construct(self, *args, **kwargs):
            return self.network(*args, **kwargs)

        def forward(self, *args, **kwargs):
            return self.network(*args, **kwargs)

    fn = CellWrapper(fn)
    fn.set_grad()  # avoid running the forward fn or cell twice

    def wrapper(*args, **kwargs):
        # Run the forward pass once, build sensitivity tensors matching the
        # output structure, then run the grad pass with those sensitivities.
        loss_value = fn(*args, **kwargs)
        p_sens = None

        if isinstance(loss_value, (list, tuple)):
            # There are multiple outputs, requiring multiple sens
            p_sens = []

            if sens is None:
                # if sens is None, only handle the first sens, and set the remaining sens to 0
                loss_0 = loss_value[0]
                if isinstance(loss_0, DTensor):
                    # Scale by 1/repeat_num so replicated shards do not
                    # over-accumulate the gradient.
                    repeat_num = loss_0.layout.repeat_num()
                    sens_0 = ops.fill(ops.DType()(loss_0), loss_0.local_shape, 1.0 / repeat_num)
                else:
                    sens_0 = ops.fill(ops.DType()(loss_0), loss_0.shape, 1.0)
                p_sens.append(sens_0)

                for i in range(1, len(loss_value)):
                    loss_i = loss_value[i]
                    if isinstance(loss_i, DTensor):
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 0.0)
                    else:
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 0.0)
                    p_sens.append(sens_i)

            else:
                # sens is not None
                if not isinstance(sens, list) and not isinstance(sens, tuple):
                    raise TypeError("if the loss is list or tuple, the sens must be None or list or tuple")

                all_float = all(isinstance(item, float) for item in sens)
                if not all_float:
                    raise TypeError("if sens is not None, it should be list of float or tuple of float")

                if len(sens) != len(loss_value):
                    raise TypeError(f"the len of loss is {len(loss_value)}, but the len of sens is {len(sens)}")

                # NOTE(review): the provided sens float values are validated
                # above but never used as fill values — every output gets
                # 1.0 (or 1/repeat_num for DTensors). Confirm whether
                # sens[i] was meant to be the fill value here.
                for _, loss_i in enumerate(loss_value):
                    if isinstance(loss_i, DTensor):
                        repeat_num = loss_i.layout.repeat_num()
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.local_shape, 1.0 / repeat_num)
                    else:
                        sens_i = ops.fill(ops.DType()(loss_i), loss_i.shape, 1.0)
                    p_sens.append(sens_i)

        else:
            # loss is tensor
            if sens is not None:
                raise TypeError(f"the fn only have one output, the sens must be None, but it is {sens}")
            if isinstance(loss_value, DTensor):
                repeat_num = loss_value.layout.repeat_num()
                p_sens = ops.fill(ops.DType()(loss_value), loss_value.local_shape, 1.0 / repeat_num)

            else:
                p_sens = ops.fill(ops.DType()(loss_value), loss_value.shape, 1.0)

        grads = grad_fn(fn, weights)(*args, **kwargs, sens=p_sens)
        return loss_value, grads

    return wrapper