Coverage for hyper_parallel / platform / torch / fully_shard / state.py: 56%
128 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025-2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Torch HSDP cell state"""
16from typing import List, Optional
17import torch
18from hyper_parallel.core.dtensor import DTensor
19from hyper_parallel.core.fully_shard.hsdp_state import HSDPState
20from hyper_parallel.core.fully_shard.hsdp_utils import _get_param_module_infos
21from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2
22from hyper_parallel.platform.torch.fully_shard.utils import HSDPMeshInfo, CPUOffloadPolicy
25def _to_dtype_if_needed(
26 tensor: torch.Tensor, dtype: Optional[torch.dtype]
27) -> torch.Tensor:
28 """Cast tensor to the given dtype if it differs from current dtype.
30 Args:
31 tensor: The input tensor to potentially cast.
32 dtype: Target dtype. If None or same as tensor dtype, no-op.
33 """
34 if dtype is not None and tensor.dtype != dtype:
35 return tensor.to(dtype)
36 return tensor
39class TorchHSDPStateV2(HSDPState):
40 """Torch HSDP cell state"""
    def __init__(self, cell, mesh_info, config, platform, device):
        """Initialize torch-specific HSDP state on top of the generic base.

        Args:
            cell: The module/cell whose parameters are managed by HSDP.
            mesh_info: Device-mesh description for shard/replicate groups.
            config: HSDP config; ``mp_policy`` and ``offload_policy`` are read here.
            platform: Platform abstraction forwarded to the base class.
            device: Target compute device for parameters and buffers.
        """
        super().__init__(cell, mesh_info, config, platform, device)
        self.device = device
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        # Do ReduceScatter/AllReduce for grad in post_backward.
        self.reduce_grads = True
        # Reshard parameter after backward
        self.reshard_after_backward = True
        # Requires AllReduce for grad across the replicate group when HSDP.
        self.requires_all_reduce = True
        self._use_post_forward_mesh = False
        # Reduce Op type for gradient reduction, default to AVG.
        self.reduce_op_type = torch.distributed.ReduceOp.AVG
        # NOTE(review): both calls below iterate self.hsdp_params — presumably
        # populated by super().__init__ / _init_hsdp_params before this point;
        # if the list is still empty here they are effectively no-ops. Confirm.
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()
58 def _move_states_to_device(self):
59 """move states to device"""
60 # TODO: @celia DTensor support
61 for param in self.cell.parameters():
62 if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
63 continue
64 if param.device == self.device or param.device.type == "meta":
65 continue
66 param.data = param.to(self.device)
67 for buffer in self.cell.buffers():
68 if buffer.device == self.device or buffer.device.type == "meta":
69 continue
70 buffer.data = buffer.to(self.device)
    def _init_hsdp_params(self):
        """Init HSDP parameters for the cell.

        Wraps every not-yet-initialized parameter in the cell tree in a
        ``TorchHSDPParamV2`` and records it in ``self.hsdp_params`` (and in
        ``self.sharded_hsdp_params`` when the wrapper reports it is sharded).
        """
        # All parameters inside the cell tree.
        filtered_params = []
        for _, param in self.cell.named_parameters():
            if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                # This attribute is set in HSDPParam._init_sharded_param to avoid
                # re-initialization: after the param is re-bound to the cell via
                # __setattr__, named_parameters() would visit it a second time.
                continue
            filtered_params.append(param)
        module_infos = _get_param_module_infos(filtered_params, [self.cell,])
        for param, module_info in zip(filtered_params, module_infos):
            hsdp_param = TorchHSDPParamV2(param,
                                          module_info,
                                          self.mesh_info,
                                          mp_policy=self.mp_policy,
                                          offload_policy=self.offload_policy,
                                          device=self.device,
                                          )
            self.hsdp_params.append(hsdp_param)
            if hsdp_param.is_sharded:
                # TODO: may no longer be needed; later decide sharding from the mesh.
                self.sharded_hsdp_params.append(hsdp_param)
97 def _init_mp_dtypes(self):
98 """init mp dtypes for hsdp parameters"""
99 for hsdp_param in self.hsdp_params:
100 hsdp_param.init_dtype_attrs(self.mp_policy)
101 trainable_params: list[TorchHSDPParamV2] = [
102 p for p in self.hsdp_params if p.sharded_param.requires_grad
103 ]
104 orig_dtypes = {p.orig_dtype for p in trainable_params}
105 reduce_dtypes = {p.reduce_dtype for p in trainable_params}
106 if len(trainable_params) > 0 and len(orig_dtypes) != 1:
107 raise AssertionError(
108 f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
109 )
110 self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
111 if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
112 raise AssertionError(
113 f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
114 )
115 self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None
117 def _validate_cpu_offload_params(self):
118 if not isinstance(self.offload_policy, CPUOffloadPolicy):
119 return
120 hsdp_params_not_on_cpu = [
121 hsdp_param
122 for hsdp_param in self.hsdp_params
123 if hsdp_param.sharded_param.device.type != "cpu"
124 ]
125 if hsdp_params_not_on_cpu:
126 raise RuntimeError(
127 "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
128 'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
129 "Found following parameters on non-CPU device: "
130 f"{[(hsdp_param._param_fqn, hsdp_param.sharded_param.device) for hsdp_param in hsdp_params_not_on_cpu]}\n"
131 )
    def lazy_init(self):
        """Not supported by the V2 state; raises unconditionally."""
        raise NotImplementedError("lazy_init not implemented in TorchHSDPStateV2")
136 def reshard(self,):
137 # TODO:补齐reshard接口,当前我们不考虑reshard_after_forward配置是int的情况,只考虑True/False
138 # if self.scheduler_state == FSDPSchedulerState.FORWARD:
139 # if not self.reshard_after_forward:
140 # return
141 # if self._use_post_forward_mesh:
142 # # TODO: support reshard_after_forward=(int)
143 # raise NotImplementedError(f"For reshard, need support reshard_after_forward=(int).")
144 # self._to_sharded_post_forward()
145 # self._reshard_after_forward_event = self.device_handle.Event()
146 # if self._reshard_after_forward_event is not None:
147 # self._reshard_after_forward_event.record()
148 # return
149 self.shard()
    def _apply_reduced_grad(self, hsdp_param, reduced_grad):
        """
        Apply reduced gradient to the sharded parameter.

        Reshapes ``reduced_grad`` to match the local shard, optionally
        offloads to CPU, then accumulates or assigns onto
        ``hsdp_param.sharded_param.grad``.

        Args:
            hsdp_param (TorchHSDPParamV2): The HSDP parameter wrapper.
            reduced_grad (torch.Tensor): Gradient after reduce-scatter
                and/or all-reduce.

        Returns:
            bool: True when the gradient was copied to CPU, in which case the
            caller must synchronize the compute stream before consuming it.
        """
        sharded_grad = hsdp_param.sharded_param.grad
        # DTensor shards expose their local shape separately from the global shape.
        sharded_param_local_shape = (
            hsdp_param.sharded_param.local_shape
            if isinstance(hsdp_param.sharded_param, DTensor)
            else hsdp_param.sharded_param.shape
        )
        reduced_grad = reduced_grad.view(sharded_param_local_shape)
        reduced_grad = _to_dtype_if_needed(reduced_grad, self._orig_dtype)
        # Snapshot taken before the CPU move: whether we accumulate decides
        # whether the copy below may be asynchronous.
        to_accumulate_grad = sharded_grad is not None
        need_synchronize = False
        if hsdp_param.offload_to_cpu:
            # Only a fresh (non-accumulating) grad may use the pinned-memory
            # async copy; accumulation reads the value immediately below.
            non_blocking = hsdp_param.pin_memory and not to_accumulate_grad
            reduced_grad = reduced_grad.to(
                torch.device("cpu"), non_blocking=non_blocking
            )
            need_synchronize = True
        if sharded_grad is None:
            hsdp_param.sharded_param.grad = reduced_grad
        else:
            hsdp_param.sharded_param.grad += reduced_grad
        # Release the unsharded-side gradient storage now that the reduced
        # grad is bound to the sharded param.
        if hsdp_param.unsharded_accumulated_grad_data is not None:
            hsdp_param.unsharded_accumulated_grad_data = None
        elif hsdp_param.unsharded_param.grad is not None:
            hsdp_param.unsharded_param.grad = None
        return need_synchronize
    def post_backward(self, *unused):
        """Reduce gradients after backward, then reshard parameters.

        For each HSDP param: accumulate its unsharded grad if needed, then
        (unless grad sync is disabled via ``reduce_grads``) reduce-scatter
        across the shard group, optionally all-reduce across the replicate
        group, and bind the result to the sharded param via
        ``_apply_reduced_grad``. Reshards at the end when
        ``reshard_after_backward`` is set.
        """
        for hsdp_param in self.hsdp_params:
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            # Grad sync disabled (presumably gradient accumulation — see
            # set_requires_grad_sync): keep locally accumulated grads and
            # skip all communication.
            if self.reshard_after_backward:
                self.reshard()
            for hsdp_param in self.hsdp_params:
                hsdp_param.to_accumulated_grad_if_needed()
            return
        # NOTE(review): these two lists are populated nowhere in this method —
        # they appear to be leftovers from an earlier batching design.
        hsdp_params_with_grad: List[TorchHSDPParamV2] = []
        unsharded_grads: List[torch.Tensor] = []
        for hsdp_param in self.hsdp_params:
            if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                continue
            # Frozen parameters (requires_grad=False) produce no
            # gradient — skip all reduce-scatter / all-reduce work.
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if hsdp_param.shard_world_size > 1:
                if hsdp_param.unsharded_param.grad is None:
                    # Parameter requires grad but was not used in
                    # forward — all ranks skip consistently.
                    continue
                reduced_grad, _ = hsdp_param.reduce_scatter_grad(
                    dtype=self._reduce_dtype,
                    reduce_op=self.reduce_op_type
                )
                if self.requires_all_reduce and hsdp_param.replicate_world_size > 1:
                    assert isinstance(hsdp_param.mesh_info, HSDPMeshInfo)
                    reduced_grad, _ = hsdp_param.all_reduce_grad(
                        grad=reduced_grad,
                        reduce_op=self.reduce_op_type,
                    )
                # Bind the reduced gradient to hsdp_param.sharded_param
                need_synchronize = self._apply_reduced_grad(hsdp_param, reduced_grad)
                if need_synchronize:
                    # The grad was copied to CPU; wait for the copy to finish
                    # before anything downstream reads it.
                    if self.device.type == "npu":
                        torch.npu.current_stream().synchronize()
                    elif self.device.type == "cuda":
                        torch.cuda.current_stream().synchronize()
                    else:
                        raise NotImplementedError(f"Unsupported device type {self.device.type} for synchronization after CPU offload.")
        if self.reshard_after_backward:
            self.reshard()
    def set_requires_grad_sync(self, requires_grad_sync):
        """Set requires grad sync flag to control gradient sync.

        Args:
            requires_grad_sync (bool): When False, ``post_backward`` skips
                the reduce-scatter/all-reduce path and keeps locally
                accumulated gradients instead.
        """
        self.reduce_grads = requires_grad_sync
240 def set_reduce_op_type(self, reduce_op_type: str):
241 """set reduce op type for gradient reduction."""
242 fsdp_support_reduce_op = {
243 "sum": torch.distributed.ReduceOp.SUM,
244 "avg": torch.distributed.ReduceOp.AVG,
245 }
246 if reduce_op_type not in fsdp_support_reduce_op:
247 raise ValueError(f"Unsupported reduce op type {reduce_op_type}, supported types are {list(fsdp_support_reduce_op.keys())}")
248 reduce_op: str = reduce_op_type.lower().strip()
249 self.reduce_op_type = fsdp_support_reduce_op[reduce_op]