Coverage for hyper_parallel / core / shard / sharding_plan.py: 100%
8 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""sharding plan"""
16from dataclasses import dataclass
17from typing import Optional, Dict, Any, List
@dataclass
class ShardingPlan:
    """
    Define a distribution scheme to partition a model across multiple computation nodes.

    This class configures the `plan` for weight sharding based on layouts, while `input_plan` and `output_plan`
    control the data distribution of the model's entries and exits. Use `return_local_tensor` to mark
    submodules that should output standard local Tensors instead of distributed ones.

    Attributes:
        plan (Dict[str, Any], optional): Mapping of parameter identifiers to
            Layout instances for weight partitioning. Defaults to None.
        input_plan (Dict[str, Any], optional): Configuration for input data
            layouts at the root or submodule level. Defaults to None.
        output_plan (Dict[str, Any], optional): Configuration for output data
            layouts at the root or submodule level. Defaults to None.
        return_local_tensor (List[str], optional): Identifiers of modules
            whose results should be converted to non-distributed Tensors. Defaults to None.

    Example:
        >>> from hyper_parallel import DeviceMesh, Layout, ShardingPlan, shard_module, Replicate, Shard
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"weight": (Replicate(), Shard(1))},
        ...     input_plan={"input": (Shard(0), Replicate()), "relu.input": (Shard(0), Replicate())},
        ...     output_plan={"output": (Shard(0), Replicate()), "relu.output": (Shard(0), Replicate())},
        ...     return_local_tensor=["relu"]
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)
    """
    # All fields default to None, meaning "no configuration supplied" for that aspect.
    plan: Optional[Dict[str, Any]] = None
    input_plan: Optional[Dict[str, Any]] = None
    output_plan: Optional[Dict[str, Any]] = None
    return_local_tensor: Optional[List[str]] = None