Coverage for hyper_parallel / platform / torch / fully_shard / state.py: 56%

128 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025-2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch HSDP cell state""" 

16from typing import List, Optional 

17import torch 

18from hyper_parallel.core.dtensor import DTensor 

19from hyper_parallel.core.fully_shard.hsdp_state import HSDPState 

20from hyper_parallel.core.fully_shard.hsdp_utils import _get_param_module_infos 

21from hyper_parallel.platform.torch.fully_shard.param import TorchHSDPParamV2 

22from hyper_parallel.platform.torch.fully_shard.utils import HSDPMeshInfo, CPUOffloadPolicy 

23 

24 

25def _to_dtype_if_needed( 

26 tensor: torch.Tensor, dtype: Optional[torch.dtype] 

27) -> torch.Tensor: 

28 """Cast tensor to the given dtype if it differs from current dtype. 

29 

30 Args: 

31 tensor: The input tensor to potentially cast. 

32 dtype: Target dtype. If None or same as tensor dtype, no-op. 

33 """ 

34 if dtype is not None and tensor.dtype != dtype: 

35 return tensor.to(dtype) 

36 return tensor 

37 

38 

class TorchHSDPStateV2(HSDPState):
    """Torch HSDP cell state.

    Per-cell state object for the torch fully-shard (HSDP) implementation:
    owns the cell's ``TorchHSDPParamV2`` wrappers, mixed-precision dtypes,
    CPU-offload validation, and the post-backward gradient
    reduce-scatter / all-reduce pipeline.
    """

    def __init__(self, cell, mesh_info, config, platform, device):
        """Initialize HSDP state for ``cell``.

        Args:
            cell: The module (cell) whose parameters this state manages.
            mesh_info: Device-mesh description (shard / replicate groups).
            config: HSDP config carrying ``mp_policy`` and ``offload_policy``.
            platform: Platform abstraction object (forwarded to the base class).
            device: Target compute device for parameters and buffers.
        """
        super().__init__(cell, mesh_info, config, platform, device)
        # Do ReduceScatter/AllReduce for grad
        self.device = device
        self.mp_policy = config.mp_policy
        self.offload_policy = config.offload_policy
        self.reduce_grads = True
        # Reshard parameter after backward
        self.reshard_after_backward = True
        # Requires AllReduce for grad When HSDP
        self.requires_all_reduce = True
        self._use_post_forward_mesh = False
        # Reduce Op type for gradient reduction, default to AVG.
        self.reduce_op_type = torch.distributed.ReduceOp.AVG
        self._validate_cpu_offload_params()
        self._init_mp_dtypes()

    def _move_states_to_device(self):
        """Move the cell's parameters and buffers onto ``self.device``.

        Parameters already initialized by HSDP, already on the target device,
        or still on the meta device are skipped.
        """
        # TODO: @celia DTensor support
        for param in self.cell.parameters():
            # Skip params already wrapped/initialized by HSDP (flag is set in
            # HSDPParam._init_sharded_param) to avoid double moves.
            if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                continue
            if param.device == self.device or param.device.type == "meta":
                continue
            param.data = param.to(self.device)
        for buffer in self.cell.buffers():
            if buffer.device == self.device or buffer.device.type == "meta":
                continue
            buffer.data = buffer.to(self.device)

    def _init_hsdp_params(self):
        """Wrap every not-yet-initialized parameter of the cell in a ``TorchHSDPParamV2``."""
        # All parameters inside the cell tree.
        filtered_params = []
        for _, param in self.cell.named_parameters():
            if hasattr(param, "_hsdp_param_initialized") and param._hsdp_param_initialized:
                # The flag is set in HSDPParam._init_sharded_param to prevent
                # re-initialization: after re-binding the param to the cell via
                # _setattr_, named_parameters() would visit it again.
                continue
            filtered_params.append(param)

        module_infos = _get_param_module_infos(filtered_params, [self.cell,])
        for param, module_info in zip(filtered_params, module_infos):
            hsdp_param = TorchHSDPParamV2(param,
                                          module_info,
                                          self.mesh_info,
                                          mp_policy=self.mp_policy,
                                          offload_policy=self.offload_policy,
                                          device=self.device,
                                          )
            self.hsdp_params.append(hsdp_param)
            if hsdp_param.is_sharded:
                # TODO: may no longer be needed; later decide sharding from the mesh.
                self.sharded_hsdp_params.append(hsdp_param)

    def _init_mp_dtypes(self):
        """Initialize mixed-precision dtype attributes for all HSDP parameters.

        Validates that all trainable parameters share one original dtype and
        one reduce dtype, then caches them as ``self._orig_dtype`` /
        ``self._reduce_dtype`` (``None`` when there are no trainable params).

        Raises:
            AssertionError: If trainable params mix original or reduce dtypes.
        """
        for hsdp_param in self.hsdp_params:
            hsdp_param.init_dtype_attrs(self.mp_policy)
        trainable_params: list[TorchHSDPParamV2] = [
            p for p in self.hsdp_params if p.sharded_param.requires_grad
        ]
        orig_dtypes = {p.orig_dtype for p in trainable_params}
        reduce_dtypes = {p.reduce_dtype for p in trainable_params}
        if len(trainable_params) > 0 and len(orig_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform original parameter dtype but got {orig_dtypes}"
            )
        self._orig_dtype = next(iter(orig_dtypes)) if trainable_params else None
        if len(trainable_params) > 0 and len(reduce_dtypes) != 1:
            raise AssertionError(
                f"hsdp expects uniform reduce dtype but got {reduce_dtypes}"
            )
        self._reduce_dtype = next(iter(reduce_dtypes)) if trainable_params else None

    def _validate_cpu_offload_params(self):
        """Ensure all sharded params live on CPU when CPU offloading is enabled.

        No-op unless ``self.offload_policy`` is a ``CPUOffloadPolicy``.

        Raises:
            RuntimeError: If any sharded parameter is on a non-CPU device.
        """
        if not isinstance(self.offload_policy, CPUOffloadPolicy):
            return
        hsdp_params_not_on_cpu = [
            hsdp_param
            for hsdp_param in self.hsdp_params
            if hsdp_param.sharded_param.device.type != "cpu"
        ]
        if hsdp_params_not_on_cpu:
            raise RuntimeError(
                "HSDP parameters should be materialized on CPU when enabling CPU offloading. "
                'For example, load a CPU state dict or call module.to_empty(device="cpu"). '
                "Found following parameters on non-CPU device: "
                f"{[(hsdp_param._param_fqn, hsdp_param.sharded_param.device) for hsdp_param in hsdp_params_not_on_cpu]}\n"
            )

    def lazy_init(self):
        raise NotImplementedError("lazy_init not implemented in TorchHSDPStateV2")

    def reshard(self,):
        """Reshard the cell's parameters (currently: unconditionally shard).

        TODO: complete the reshard interface. For now reshard_after_forward
        being an int is not handled — only True/False.
        """
        # if self.scheduler_state == FSDPSchedulerState.FORWARD:
        #     if not self.reshard_after_forward:
        #         return
        #     if self._use_post_forward_mesh:
        #         # TODO: support reshard_after_forward=(int)
        #         raise NotImplementedError(f"For reshard, need support reshard_after_forward=(int).")
        #     self._to_sharded_post_forward()
        #     self._reshard_after_forward_event = self.device_handle.Event()
        #     if self._reshard_after_forward_event is not None:
        #         self._reshard_after_forward_event.record()
        #     return
        self.shard()

    def _apply_reduced_grad(self, hsdp_param, reduced_grad):
        """
        Apply reduced gradient to the sharded parameter.

        Reshapes ``reduced_grad`` to match the local shard, optionally
        offloads to CPU, then accumulates or assigns onto
        ``hsdp_param.sharded_param.grad``.

        Args:
            hsdp_param (TorchHSDPParamV2): The HSDP parameter wrapper.
            reduced_grad (torch.Tensor): Gradient after reduce-scatter
                and/or all-reduce.

        Returns:
            bool: True when a non-blocking device→CPU copy was issued and the
            caller must synchronize the stream before using the gradient.
        """
        sharded_grad = hsdp_param.sharded_param.grad
        sharded_param_local_shape = (
            hsdp_param.sharded_param.local_shape
            if isinstance(hsdp_param.sharded_param, DTensor)
            else hsdp_param.sharded_param.shape
        )
        reduced_grad = reduced_grad.view(sharded_param_local_shape)
        reduced_grad = _to_dtype_if_needed(reduced_grad, self._orig_dtype)
        to_accumulate_grad = sharded_grad is not None
        need_synchronize = False
        if hsdp_param.offload_to_cpu:
            # Only copy non-blocking when pinned memory is available and we
            # are not about to accumulate into an existing grad.
            non_blocking = hsdp_param.pin_memory and not to_accumulate_grad
            reduced_grad = reduced_grad.to(
                torch.device("cpu"), non_blocking=non_blocking
            )
            need_synchronize = True
        if sharded_grad is None:
            hsdp_param.sharded_param.grad = reduced_grad
        else:
            hsdp_param.sharded_param.grad += reduced_grad
        # Release unsharded gradient storage now that the sharded grad holds it.
        if hsdp_param.unsharded_accumulated_grad_data is not None:
            hsdp_param.unsharded_accumulated_grad_data = None
        elif hsdp_param.unsharded_param.grad is not None:
            hsdp_param.unsharded_param.grad = None
        return need_synchronize

    def post_backward(self, *unused):
        """Post-backward hook: reduce gradients and optionally reshard.

        When ``self.reduce_grads`` is False (grad-sync disabled, e.g. during
        gradient accumulation) gradients are only accumulated locally.
        Otherwise each trainable parameter's unsharded grad goes through
        reduce-scatter (shard group) and, for HSDP, all-reduce (replicate
        group) before being bound to the sharded parameter.

        Raises:
            NotImplementedError: If CPU offload requires a stream synchronize
                on an unsupported device type.
        """
        for hsdp_param in self.hsdp_params:
            hsdp_param.accumulate_unsharded_grad_if_needed()
        if not self.reduce_grads:
            if self.reshard_after_backward:
                self.reshard()
            for hsdp_param in self.hsdp_params:
                hsdp_param.to_accumulated_grad_if_needed()
            return
        hsdp_params_with_grad: List[TorchHSDPParamV2] = []
        unsharded_grads: List[torch.Tensor] = []
        for hsdp_param in self.hsdp_params:
            if not hasattr(hsdp_param, "_unsharded_param") or hsdp_param.unsharded_param is None:
                continue
            # Frozen parameters (requires_grad=False) produce no
            # gradient — skip all reduce-scatter / all-reduce work.
            if not hsdp_param.sharded_param.requires_grad:
                continue
            if hsdp_param.shard_world_size > 1:
                if hsdp_param.unsharded_param.grad is None:
                    # Parameter requires grad but was not used in
                    # forward — all ranks skip consistently.
                    continue
                reduced_grad, _ = hsdp_param.reduce_scatter_grad(
                    dtype=self._reduce_dtype,
                    reduce_op=self.reduce_op_type
                )
                if self.requires_all_reduce and hsdp_param.replicate_world_size > 1:
                    assert isinstance(hsdp_param.mesh_info, HSDPMeshInfo)
                    reduced_grad, _ = hsdp_param.all_reduce_grad(
                        grad=reduced_grad,
                        reduce_op=self.reduce_op_type,
                    )
                # Bind the reduced gradient to hsdp_param.sharded_param
                need_synchronize = self._apply_reduced_grad(hsdp_param, reduced_grad)
                if need_synchronize:
                    if self.device.type == "npu":
                        torch.npu.current_stream().synchronize()
                    elif self.device.type == "cuda":
                        torch.cuda.current_stream().synchronize()
                    else:
                        raise NotImplementedError(f"Unsupported device type {self.device.type} for synchronization after CPU offload.")

        if self.reshard_after_backward:
            self.reshard()

    def set_requires_grad_sync(self, requires_grad_sync):
        """set requires grad sync flag to control gradient sync."""
        self.reduce_grads = requires_grad_sync

    def set_reduce_op_type(self, reduce_op_type: str):
        """set reduce op type for gradient reduction.

        Args:
            reduce_op_type: One of "sum" / "avg" (case-insensitive,
                surrounding whitespace ignored).

        Raises:
            ValueError: If the normalized name is not a supported reduce op.
        """
        fsdp_support_reduce_op = {
            "sum": torch.distributed.ReduceOp.SUM,
            "avg": torch.distributed.ReduceOp.AVG,
        }
        # BUGFIX: normalize BEFORE validating. The previous code checked the
        # raw string against the dict and only lowered/stripped afterwards,
        # so valid inputs like "SUM" or " avg " were incorrectly rejected.
        reduce_op: str = reduce_op_type.lower().strip()
        if reduce_op not in fsdp_support_reduce_op:
            raise ValueError(f"Unsupported reduce op type {reduce_op_type}, supported types are {list(fsdp_support_reduce_op.keys())}")
        self.reduce_op_type = fsdp_support_reduce_op[reduce_op]