Coverage for hyper_parallel/core/fully_shard/hsdp_utils.py: 83% (71 statements)
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""HSDP optimizer shared level"""
16from dataclasses import dataclass, field
17from typing import Any, List, Optional, Sequence, Tuple
18from enum import auto, Enum
19from torch import nn
class OptimizerLevel(Enum):
    """
    Optimizer shard level:
    - SHARD_OPT:
        Sharding is applied to the optimizer state.
    - SHARD_OPT_GRAD:
        Sharding is applied to the optimizer state and gradients.
    - SHARD_OPT_GRAD_PARAM:
        Sharding is applied to the optimizer state, gradients, and weights.
    """
    SHARD_OPT = auto()
    SHARD_OPT_GRAD = auto()
    SHARD_OPT_GRAD_PARAM = auto()


class GroupInfo:
    """
    Describes a communication group: its name, the group handle, and its rank size.
    """
    def __init__(self, group_name, group, rank_size):
        self.group_name = group_name
        self.group = group
        self.rank_size = rank_size


class HSDPConfigV2:
    """HSDPConfigV2, inspired by torch fully_shard."""

    def __init__(self, mesh, reshard_after_forward, shard_placement_fn, mp_policy, offload_policy, ignored_param,
                 reduce_dtype=None, comm_async=False, comm_fusion=False, bucket_size=-1):
        """
        HSDP config init method.

        Args:
            mesh: device mesh used for parameter sharding.
            reshard_after_forward: whether to reshard parameters after forward.
            shard_placement_fn: callable that customizes per-parameter shard placement.
            mp_policy: mixed precision policy.
            offload_policy: offload policy.
            ignored_param: parameters excluded from sharding.
            reduce_dtype: set gradient reduce dtype.
            comm_async: use async communication op for grad reduction.
            comm_fusion: use communication op fusion to reduce the number of communication ops.
            bucket_size: the size of the comm fusion buffer.
        """
        self.mesh = mesh
        self.reshard_after_forward = reshard_after_forward
        self.shard_placement_fn = shard_placement_fn
        self.mp_policy = mp_policy
        self.offload_policy = offload_policy
        self.reduce_dtype = self.mp_policy.reduce_dtype if self.mp_policy else None
        # TODO: the attributes below are to be removed
        self.comm_async = False
        self.comm_fusion = False
        self.bucket_size = 9999
        self.grad_fusion = False


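# Illustrative sketch only (hypothetical helper, not part of the module): how a caller
# that already holds a device mesh might fill in HSDPConfigV2. With mp_policy left as
# None, reduce_dtype falls back to None as well.
def _example_build_config(mesh):
    """Hypothetical demo of constructing an HSDPConfigV2 under default policies."""
    return HSDPConfigV2(
        mesh=mesh,
        reshard_after_forward=True,   # free unsharded parameters after forward
        shard_placement_fn=None,      # use the default shard placement
        mp_policy=None,               # no mixed precision policy
        offload_policy=None,          # keep shards on device
        ignored_param=None,           # shard every managed parameter
    )

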
class ShardedState(Enum):
    """
    Parameter shard state.
    """
    SHARDED = auto()
    SHARDED_POST_FORWARD = auto()
    UNSHARDED = auto()


class FSDPSchedulerState(Enum):
    """
    Scheduler state:
    - PRE_FORWARD:
        the pre-forward hook has already run.
    - FORWARD:
        the post-forward hook has already run.
    - PRE_BACKWARD:
        the pre-backward hook has already run.
    - BACKWARD:
        the post-backward hook has already run.
    """
    PRE_FORWARD = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()
    BACKWARD = auto()


@dataclass
class ParamModuleInfo:
    """
    Tracks parameter ownership and supports shared weights in HSDP.

    This dataclass maintains the mapping between a parameter and its module(s),
    enabling parameter swapping during sharding/unsharding transitions. Shared
    weights are parameters referenced by multiple modules (e.g., tied embeddings).

    This class tracks all references to ensure proper parameter replacement during
    sharding/unsharding operations.

    Attributes:
        module: The module that owns this parameter.
        param_name: Attribute name of the parameter in the module (e.g., "weight").
        shared_modules: List of other modules sharing this same parameter object.
        shared_param_names: Corresponding parameter names in shared_modules (aligned by index).
    """
    module: nn.Module
    param_name: str
    shared_modules: List[nn.Module] = field(default_factory=list)
    shared_param_names: List[str] = field(default_factory=list)


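# Illustrative sketch only (hypothetical demo, not part of the module): how a tied weight
# shared between an embedding and an output projection is recorded in ParamModuleInfo,
# with the second owner appended to shared_modules / shared_param_names.
def _example_param_module_info():
    """Hypothetical demo of the shared-weight bookkeeping."""
    emb = nn.Embedding(10, 4)
    proj = nn.Linear(4, 10, bias=False)
    proj.weight = emb.weight  # tie the two modules to one parameter object
    info = ParamModuleInfo(module=emb, param_name="weight")
    info.shared_modules.append(proj)
    info.shared_param_names.append("weight")
    return info

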
@dataclass
class ExtensionsData:
    """
    Stores metadata for custom all-gather extensions.

    This enables users to implement custom pre/post all-gather transforms
    by passing metadata between the two phases. The input sizes are saved
    to properly reshape the gathered outputs back to their original dimensions.

    Attributes:
        all_gather_metadata: Custom metadata passed from pre to post all-gather.
        all_gather_input_sizes: Original tensor shapes before flattening for all-gather.
    """
    all_gather_metadata: Optional[Any] = None
    all_gather_input_sizes: Sequence[Tuple[int, ...]] = ()

    def clear(self):
        """Reset all extension data to default values."""
        self.all_gather_metadata = None
        self.all_gather_input_sizes = ()


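# Illustrative sketch only (hypothetical demo, not part of the module): a custom
# all-gather extension would stash metadata and the original shapes in the pre
# phase, consume them in the post phase, then clear the container for reuse.
def _example_extensions_roundtrip():
    """Hypothetical demo of passing metadata between the two all-gather phases."""
    data = ExtensionsData()
    # pre-all-gather: remember the original shapes so the gathered flat tensors
    # can be viewed back to their original dimensions afterwards
    data.all_gather_input_sizes = ((4, 8), (16,))
    data.all_gather_metadata = {"scale": 2.0}
    # post-all-gather: read what was recorded, then reset for the next iteration
    sizes, meta = data.all_gather_input_sizes, data.all_gather_metadata
    data.clear()
    return sizes, meta

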
def _named_parameters_with_duplicates(
    module: nn.Module, **kwargs: Any
) -> list[tuple[str, nn.Parameter]]:
    """
    This API is required as some modules overwrite `named_parameters()` but do not support
    `remove_duplicate`.
    """
    if "remove_duplicate" in kwargs:
        raise AssertionError(
            "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
        )
    kwargs["remove_duplicate"] = False
    try:
        ret = list(module.named_parameters(**kwargs))
    except AssertionError:
        kwargs.pop("remove_duplicate")
        ret = list(module.named_parameters(**kwargs))
    return ret


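# Illustrative sketch only (hypothetical demo, not part of the module): with tied weights,
# the duplicate-preserving traversal reports the shared parameter once per owning
# submodule, whereas named_parameters() with its default remove_duplicate=True collapses
# the tied entries into one.
def _example_duplicates_kept():
    """Hypothetical demo of why duplicates must be preserved for shared weights."""
    block = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    block[1].weight = block[0].weight  # tie the two weights
    deduped = [name for name, _ in block.named_parameters()]
    kept = [name for name, _ in _named_parameters_with_duplicates(block)]
    # deduped drops one of the tied entries; kept lists both "0.weight" and "1.weight"
    return deduped, kept

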
def _get_param_module_infos(
    params: list[nn.Parameter], modules: tuple[nn.Module, ...]
) -> list['ParamModuleInfo']:
    """
    Shared parameter: lin1.weight = lin2.weight
    Shared module: mlp.lin1 = mlp.lin2
    We do not remove duplicates when traversing both modules and parameters to
    find shared modules' parameters and shared parameters within a module.
    """
    params_set = set(params)
    param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {}
    for module in modules:
        for _, submodule in module.named_modules(remove_duplicate=False):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param in params_set:
                    if param not in param_to_module_info:
                        param_to_module_info[param] = ParamModuleInfo(
                            submodule, param_name
                        )
                    else:
                        param_to_module_info[param].shared_modules.append(submodule)
                        param_to_module_info[param].shared_param_names.append(
                            param_name
                        )
    if len(param_to_module_info) != len(params):
        raise AssertionError(f"Some parameters are not in the module tree of {modules}")
    return [param_to_module_info[param] for param in params]


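# Illustrative sketch only (hypothetical demo, not part of the module): the helper
# discovers that two submodules reference the same parameter object and records the
# second owner in shared_modules rather than creating a new ParamModuleInfo.
def _example_shared_param_discovery():
    """Hypothetical demo of shared-parameter discovery via _get_param_module_infos."""
    mlp = nn.Sequential(nn.Linear(4, 4, bias=False), nn.Linear(4, 4, bias=False))
    mlp[1].weight = mlp[0].weight  # tie the weights across the two layers
    infos = _get_param_module_infos([mlp[0].weight], (mlp,))
    info = infos[0]
    # `module` points at the first owner; the second layer appears as a shared owner
    return info.module is mlp[0], info.shared_modules == [mlp[1]]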