Coverage for hyper_parallel / core / shard / sharding_plan.py: 100%

8 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""sharding plan""" 

16from dataclasses import dataclass 

17from typing import Optional, Dict, Any, List 

18 

19 

@dataclass
class ShardingPlan:
    """
    Define a distribution scheme to partition a model across multiple computation nodes.

    This class configures the `plan` for weight sharding based on layouts, while `input_plan` and
    `output_plan` control the data distribution of the model's entries and exits. Use
    `return_local_tensor` to mark submodules that should output standard local Tensors instead of
    distributed ones.

    Attributes:
        plan (Dict[str, Any], optional): Mapping of parameter identifiers to
            Layout instances for weight partitioning. Defaults to None.
        input_plan (Dict[str, Any], optional): Configuration for input data
            layouts at the root or submodule level. Defaults to None.
        output_plan (Dict[str, Any], optional): Configuration for output data
            layouts at the root or submodule level. Defaults to None.
        return_local_tensor (List[str], optional): Identifiers of modules
            whose results should be converted to non-distributed Tensors. Defaults to None.

    Example:
        >>> # NOTE(review): Replicate and Shard added to the import below — the original example
        >>> # used them without importing; confirm they are exported from `hyper_parallel`.
        >>> from hyper_parallel import DeviceMesh, Layout, Replicate, Shard, ShardingPlan, shard_module
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> sharding_plan = ShardingPlan(
        ...     plan={"weight": (Replicate(), Shard(1))},
        ...     input_plan={"input": (Shard(0), Replicate()), "relu.input": (Shard(0), Replicate())},
        ...     output_plan={"output": (Shard(0), Replicate()), "relu.output": (Shard(0), Replicate())},
        ...     return_local_tensor=["relu"]
        ... )
        >>> model = shard_module(model, mesh, sharding_plan)
    """
    # Parameter identifier -> layout spec describing how each weight is partitioned.
    plan: Optional[Dict[str, Any]] = None
    # Input identifier (root or "submodule.input") -> layout spec for incoming data.
    input_plan: Optional[Dict[str, Any]] = None
    # Output identifier (root or "submodule.output") -> layout spec for outgoing data.
    output_plan: Optional[Dict[str, Any]] = None
    # Names of submodules whose outputs should be converted to local (non-distributed) tensors.
    return_local_tensor: Optional[List[str]] = None