Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/platform/torch/fully_shard/pack_utils.py: 79% of 132 statements (coverage.py v7.13.1, created at 2026-05-11 07:26 +0800)

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Packing helpers for fully_shard communication buffers."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Literal, Optional

import torch

from hyper_parallel.core.dtensor.placement_types import StridedShard


@dataclass(frozen=True)
class ReduceScatterPlan:
    """Describe how local tensors map to packed communication layouts."""

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ]
    shard_dim: int
    world_size: int
    packed_shape: torch.Size
    packed_tensor_shape: torch.Size
    unpacked_shape: torch.Size


@dataclass(frozen=True)
class _SameDimStridedLayoutContext:
    target_dim: int
    shard_mesh_dim: int
    placements: tuple[Any, ...]
    orig_placements: tuple[Any, ...]


def _has_strided_shard_layout(hsdp_param: Any) -> bool:
    placements = getattr(hsdp_param, "_spmd_placements", ()) or ()
    return any(isinstance(placement, StridedShard) for placement in placements)


def _resolve_same_dim_strided_context(
    hsdp_param: Any,
) -> Optional[_SameDimStridedLayoutContext]:
    if not _has_strided_shard_layout(hsdp_param):
        return None
    if not getattr(hsdp_param, "uses_param_shard", False):
        return None
    if not getattr(hsdp_param, "_orig_param_is_dtensor", False):
        return None
    target_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", None)
    if target_dim is None:
        return None
    shard_mesh_dim = getattr(hsdp_param, "_spmd_shard_mesh_dim", None)
    placements = tuple(getattr(hsdp_param, "_spmd_placements", ()) or ())
    if shard_mesh_dim is None or shard_mesh_dim >= len(placements):
        return None
    if not isinstance(placements[shard_mesh_dim], StridedShard):
        return None
    orig_placements = getattr(hsdp_param, "_orig_dtensor_placements", None)
    if orig_placements is None:
        return None
    return _SameDimStridedLayoutContext(
        target_dim=target_dim,
        shard_mesh_dim=shard_mesh_dim,
        placements=placements,
        orig_placements=tuple(orig_placements),
    )


def _placements_match_target_dim_only(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    return all(
        placement.is_replicate() or placement.is_shard(target_dim)
        for placement in placements
    )


def _orig_layout_is_supported(
    orig_placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    if not _placements_match_target_dim_only(orig_placements, target_dim):
        return False
    return sum(
        placement.is_shard(target_dim) for placement in orig_placements
    ) == 1


def _current_strided_layout_is_supported(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    if not _placements_match_target_dim_only(placements, target_dim):
        return False
    if sum(placement.is_shard() for placement in placements) != 2:
        return False

    strided_placements = [
        placement for placement in placements if isinstance(placement, StridedShard)
    ]
    if len(strided_placements) != 1:
        return False
    strided_placement = strided_placements[0]
    if strided_placement.dim != target_dim or strided_placement.split_factor <= 1:
        return False

    plain_shards = [
        placement
        for placement in placements
        if placement.is_shard(target_dim) and not isinstance(placement, StridedShard)
    ]
    return len(plain_shards) == 1


def supports_same_dim_strided_layout(hsdp_param: Any) -> bool:
    ctx = _resolve_same_dim_strided_context(hsdp_param)
    if ctx is None:
        return False
    if not _orig_layout_is_supported(ctx.orig_placements, ctx.target_dim):
        return False
    return _current_strided_layout_is_supported(ctx.placements, ctx.target_dim)
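# Summary of the layout accepted above (descriptive note, not an exhaustive
# spec): the original DTensor layout must shard exactly one mesh dim on the
# fully_shard target dim with everything else replicated, and the current
# layout must pair exactly one plain Shard(target_dim) with exactly one
# StridedShard on that same dim whose split_factor is greater than 1. When a
# StridedShard is present but this check fails, build_rs_plan raises
# NotImplementedError rather than guessing a packing order.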

def _resolve_unpacked_shape(
    hsdp_param: Optional[Any],
    local_tensor: torch.Tensor,
) -> torch.Size:
    if hsdp_param is not None and getattr(hsdp_param, "_orig_size", None) is not None:
        return torch.Size(getattr(hsdp_param, "_orig_size"))
    return torch.Size(local_tensor.size())


def _get_packed_tensor_shape(
    unpacked_shape: torch.Size,
    shard_dim: int,
    world_size: int,
) -> torch.Size:
    if world_size == 1 or shard_dim == 0:
        return unpacked_shape
    packed_tensor_shape = list(unpacked_shape)
    packed_tensor_shape[0] *= world_size
    packed_tensor_shape[shard_dim] //= world_size
    return torch.Size(packed_tensor_shape)
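# Worked example (illustrative numbers only): for unpacked_shape (8, 6) with
# shard_dim=1 and world_size=2, the per-rank chunks are stacked along dim 0,
# so packed_tensor_shape becomes (16, 3). For shard_dim=0 or world_size=1 the
# shape is returned unchanged.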

def build_rs_plan(
    hsdp_param: Optional[Any],
    local_tensor: torch.Tensor,
    world_size: int,
    *,
    shard_dim: Optional[int] = None,
) -> ReduceScatterPlan:
    """Build the V1 reduce-scatter packing plan for a local gradient tensor."""

    if world_size <= 0:
        raise ValueError(f"world_size must be positive, but got {world_size}")

    resolved_shard_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", shard_dim)
    if resolved_shard_dim is None:
        raise ValueError("build_rs_plan requires either hsdp_param or shard_dim")
    unpacked_shape = _resolve_unpacked_shape(hsdp_param, local_tensor)
    if resolved_shard_dim < 0 or resolved_shard_dim >= len(unpacked_shape):
        raise ValueError(
            f"Invalid shard dim {resolved_shard_dim} for tensor shape {tuple(unpacked_shape)}"
        )
    if world_size == 1:
        if not local_tensor.is_contiguous():
            raise NotImplementedError(
                "reduce_scatter_grad currently expects contiguous local gradients before packing."
            )
        return ReduceScatterPlan(
            pack_kind="identity_dim0",
            shard_dim=resolved_shard_dim,
            world_size=world_size,
            packed_shape=torch.Size((1, math.prod(unpacked_shape))),
            packed_tensor_shape=unpacked_shape,
            unpacked_shape=unpacked_shape,
        )
    if local_tensor.dim() == 0:
        raise NotImplementedError("reduce_scatter_grad does not support scalar gradients.")
    if unpacked_shape[resolved_shard_dim] % world_size != 0:
        raise NotImplementedError(
            f"reduce_scatter_grad currently only supports even sharding on dim={resolved_shard_dim}."
        )
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ] = "identity_dim0"
    if hsdp_param is not None and _has_strided_shard_layout(hsdp_param):
        if not supports_same_dim_strided_layout(hsdp_param):
            raise NotImplementedError(
                "reduce_scatter_grad only supports same-dim StridedShard layouts "
                "that restore a single contiguous TP-local shard on the fully_shard dimension."
            )
        if resolved_shard_dim == 0:
            pack_kind = "same_dim_strided_identity_dim0"
        else:
            pack_kind = "chunk_cat_non_dim0"
    elif resolved_shard_dim != 0:
        pack_kind = "chunk_cat_non_dim0"

    packed_tensor_shape = _get_packed_tensor_shape(
        unpacked_shape,
        resolved_shard_dim,
        world_size,
    )
    total_numel = math.prod(unpacked_shape)

    return ReduceScatterPlan(
        pack_kind=pack_kind,
        shard_dim=resolved_shard_dim,
        world_size=world_size,
        packed_shape=torch.Size((world_size, total_numel // world_size)),
        packed_tensor_shape=packed_tensor_shape,
        unpacked_shape=unpacked_shape,
    )
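# Pack-kind selection at a glance (derived from the branches above): a dim-0
# shard packs as the identity ("identity_dim0", or "same_dim_strided_identity_dim0"
# when a supported same-dim StridedShard layout is present), since dim-0 chunks
# of a contiguous tensor are already laid out rank by rank; any other shard dim
# needs the chunk-and-concatenate rearrangement ("chunk_cat_non_dim0"). For
# example, an (8, 6) gradient sharded on dim 1 across world_size=2 yields
# packed_shape (2, 24) and packed_tensor_shape (16, 3).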

def pack_for_reduce_scatter(
    local_tensor: torch.Tensor,
    plan: ReduceScatterPlan,
) -> torch.Tensor:
    """Pack one local gradient into the row-major reduce-scatter layout."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported reduce-scatter pack kind: {plan.pack_kind}")
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )
    if local_tensor.size() != plan.unpacked_shape:
        raise AssertionError(
            "pack_for_reduce_scatter expects the unsharded local tensor shape to match "
            f"plan.unpacked_shape, but got {tuple(local_tensor.size())} and "
            f"{tuple(plan.unpacked_shape)}"
        )
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = torch.chunk(local_tensor, plan.world_size, dim=plan.shard_dim)
        packed_tensor = torch.cat(chunks, dim=0)
        return packed_tensor.contiguous().view(plan.packed_shape)
    return local_tensor.view(plan.packed_shape)


def unpack_from_all_gather(
    full_packed: torch.Tensor,
    plan: ReduceScatterPlan,
) -> torch.Tensor:
    """Inverse of the V1 reduce-scatter packing plan for all-gather outputs."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported all-gather unpack kind: {plan.pack_kind}")
    packed_tensor = full_packed.view(plan.packed_tensor_shape)
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = torch.chunk(packed_tensor, plan.world_size, dim=0)
        return torch.cat(chunks, dim=plan.shard_dim).contiguous()
    return packed_tensor.view(plan.unpacked_shape)

__all__ = [
    "ReduceScatterPlan",
    "build_rs_plan",
    "pack_for_reduce_scatter",
    "unpack_from_all_gather",
    "supports_same_dim_strided_layout",
]
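# Hypothetical round-trip sketch (illustration only, not part of the module's
# API): build a non-dim0 plan for a made-up local gradient, pack it, and check
# that unpack_from_all_gather inverts the layout on a single rank. hsdp_param
# is None here, so shard_dim must be given explicitly; in real use the packed
# buffer would go through reduce-scatter / all-gather between these two calls.
if __name__ == "__main__":
    grad = torch.arange(48, dtype=torch.float32).reshape(8, 6)
    plan = build_rs_plan(None, grad, world_size=2, shard_dim=1)
    assert plan.pack_kind == "chunk_cat_non_dim0"
    packed = pack_for_reduce_scatter(grad, plan)     # shape (2, 24)
    restored = unpack_from_all_gather(packed, plan)  # shape (8, 6)
    assert torch.equal(restored, grad)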