# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Packing helpers for MindSpore fully_shard communication buffers."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Literal, Optional

import mindspore as ms

from hyper_parallel.core.dtensor.placement_types import StridedShard


@dataclass(frozen=True)
class ReduceScatterPlan:
    """Describe how local tensors map to packed communication layouts."""

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ]
    shard_dim: int
    world_size: int
    packed_shape: tuple[int, ...]
    packed_tensor_shape: tuple[int, ...]
    unpacked_shape: tuple[int, ...]
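
# Illustrative note on ReduceScatterPlan (hypothetical shapes, not from the
# original source): for an unsharded gradient of shape (4, 6) sharded on dim=1
# across world_size=2, build_rs_plan below would produce roughly
#     ReduceScatterPlan(
#         pack_kind="chunk_cat_non_dim0",
#         shard_dim=1,
#         world_size=2,
#         packed_shape=(2, 12),        # one flattened row of 12 elements per rank
#         packed_tensor_shape=(8, 3),  # dim0 * world_size, dim1 // world_size
#         unpacked_shape=(4, 6),
#     )
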

@dataclass(frozen=True)
class _SameDimStridedLayoutContext:
    target_dim: int
    shard_mesh_dim: int
    placements: tuple[Any, ...]
    orig_placements: tuple[Any, ...]


def _shape_tuple(shape) -> tuple[int, ...]:
    return tuple(int(dim) for dim in shape)


def _has_strided_shard_layout(hsdp_param: Any) -> bool:
    placements = getattr(hsdp_param, "_spmd_placements", ()) or ()
    return any(isinstance(placement, StridedShard) for placement in placements)


def _resolve_same_dim_strided_context(
    hsdp_param: Any,
) -> Optional[_SameDimStridedLayoutContext]:
    if not _has_strided_shard_layout(hsdp_param):
        return None
    if not getattr(hsdp_param, "uses_param_shard", False):
        return None
    if not getattr(hsdp_param, "_orig_param_is_dtensor", False):
        return None
    target_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", None)
    if target_dim is None:
        return None
    shard_mesh_dim = getattr(hsdp_param, "_spmd_shard_mesh_dim", None)
    placements = tuple(getattr(hsdp_param, "_spmd_placements", ()) or ())
    if shard_mesh_dim is None or shard_mesh_dim >= len(placements):
        return None
    if not isinstance(placements[shard_mesh_dim], StridedShard):
        return None
    orig_placements = getattr(hsdp_param, "_orig_dtensor_placements", None)
    if orig_placements is None:
        return None
    return _SameDimStridedLayoutContext(
        target_dim=target_dim,
        shard_mesh_dim=shard_mesh_dim,
        placements=placements,
        orig_placements=tuple(orig_placements),
    )
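
# The resolver above returns a context only when the parameter carries a
# StridedShard placement, actually uses a parameter shard, and originated from
# a DTensor; the layout checks below consume that context to decide whether a
# strided layout can be packed.
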

def _placements_match_target_dim_only(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    return all(
        placement.is_replicate() or placement.is_shard(target_dim)
        for placement in placements
    )


def _orig_layout_is_supported(
    orig_placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    if not _placements_match_target_dim_only(orig_placements, target_dim):
        return False
    return sum(
        placement.is_shard(target_dim) for placement in orig_placements
    ) == 1


def _current_strided_layout_is_supported(
    placements: tuple[Any, ...],
    target_dim: int,
) -> bool:
    if not _placements_match_target_dim_only(placements, target_dim):
        return False
    if sum(placement.is_shard() for placement in placements) != 2:
        return False

    strided_placements = [
        placement for placement in placements if isinstance(placement, StridedShard)
    ]
    if len(strided_placements) != 1:
        return False
    strided_placement = strided_placements[0]
    if strided_placement.dim != target_dim or strided_placement.split_factor <= 1:
        return False

    plain_shards = [
        placement
        for placement in placements
        if placement.is_shard(target_dim) and not isinstance(placement, StridedShard)
    ]
    return len(plain_shards) == 1


def supports_same_dim_strided_layout(hsdp_param: Any) -> bool:
    ctx = _resolve_same_dim_strided_context(hsdp_param)
    if ctx is None:
        return False
    if not _orig_layout_is_supported(ctx.orig_placements, ctx.target_dim):
        return False
    return _current_strided_layout_is_supported(ctx.placements, ctx.target_dim)
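
# supports_same_dim_strided_layout accepts only layouts in which the original
# DTensor shards the fully_shard dim on exactly one mesh dim, and the current
# placements add exactly one StridedShard (split_factor > 1) plus one plain
# Shard on that same dim; everything else is rejected, and build_rs_plan then
# raises NotImplementedError for such parameters.
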

def _resolve_unpacked_shape(
    hsdp_param: Optional[Any],
    local_tensor: ms.Tensor,
) -> tuple[int, ...]:
    if hsdp_param is not None and getattr(hsdp_param, "_orig_size", None) is not None:
        return _shape_tuple(getattr(hsdp_param, "_orig_size"))
    return _shape_tuple(local_tensor.shape)


def _get_packed_tensor_shape(
    unpacked_shape: tuple[int, ...],
    shard_dim: int,
    world_size: int,
) -> tuple[int, ...]:
    if world_size == 1 or shard_dim == 0:
        return unpacked_shape
    packed_tensor_shape = list(unpacked_shape)
    packed_tensor_shape[0] *= world_size
    packed_tensor_shape[shard_dim] //= world_size
    return tuple(packed_tensor_shape)
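
# _get_packed_tensor_shape keeps the total element count unchanged: dim 0 grows
# by world_size while the size of shard_dim shrinks by it. Hypothetical example:
# a (2, 8) shape with shard_dim=1 and world_size=4 becomes (8, 2).
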

def build_rs_plan(
    hsdp_param: Optional[Any],
    local_tensor: ms.Tensor,
    world_size: int,
    *,
    shard_dim: Optional[int] = None,
) -> ReduceScatterPlan:
    """Build the V1 reduce-scatter packing plan for a local gradient tensor."""

    if world_size <= 0:
        raise ValueError(f"world_size must be positive, but got {world_size}")

    resolved_shard_dim = getattr(getattr(hsdp_param, "hsdp_placement", None), "dim", shard_dim)
    if resolved_shard_dim is None:
        raise ValueError("build_rs_plan requires either hsdp_param or shard_dim")
    unpacked_shape = _resolve_unpacked_shape(hsdp_param, local_tensor)
    if resolved_shard_dim < 0 or resolved_shard_dim >= len(unpacked_shape):
        raise ValueError(
            f"Invalid shard dim {resolved_shard_dim} for tensor shape {tuple(unpacked_shape)}"
        )
    if world_size == 1:
        if not local_tensor.is_contiguous():
            raise NotImplementedError(
                "reduce_scatter_grad currently expects contiguous local gradients before packing."
            )
        return ReduceScatterPlan(
            pack_kind="identity_dim0",
            shard_dim=resolved_shard_dim,
            world_size=world_size,
            packed_shape=(1, math.prod(unpacked_shape)),
            packed_tensor_shape=unpacked_shape,
            unpacked_shape=unpacked_shape,
        )
    if len(local_tensor.shape) == 0:
        raise NotImplementedError("reduce_scatter_grad does not support scalar gradients.")
    if unpacked_shape[resolved_shard_dim] % world_size != 0:
        raise NotImplementedError(
            f"reduce_scatter_grad currently only supports even sharding on dim={resolved_shard_dim}."
        )
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )

    pack_kind: Literal[
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ] = "identity_dim0"
    if hsdp_param is not None and _has_strided_shard_layout(hsdp_param):
        if not supports_same_dim_strided_layout(hsdp_param):
            raise NotImplementedError(
                "reduce_scatter_grad only supports same-dim StridedShard layouts "
                "that restore a single contiguous TP-local shard on the fully_shard dimension."
            )
        if resolved_shard_dim == 0:
            pack_kind = "same_dim_strided_identity_dim0"
        else:
            pack_kind = "chunk_cat_non_dim0"
    elif resolved_shard_dim != 0:
        pack_kind = "chunk_cat_non_dim0"

    packed_tensor_shape = _get_packed_tensor_shape(
        unpacked_shape,
        resolved_shard_dim,
        world_size,
    )
    total_numel = math.prod(unpacked_shape)
    return ReduceScatterPlan(
        pack_kind=pack_kind,
        shard_dim=resolved_shard_dim,
        world_size=world_size,
        packed_shape=(world_size, total_numel // world_size),
        packed_tensor_shape=packed_tensor_shape,
        unpacked_shape=unpacked_shape,
    )
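
# Minimal usage sketch (assumed call pattern, not taken from the original
# source): a plan can also be built without an hsdp_param by passing shard_dim
# explicitly, then applied to a contiguous local gradient:
#
#     plan = build_rs_plan(None, grad, world_size=2, shard_dim=1)
#     packed = pack_for_reduce_scatter(grad, plan)  # shape == plan.packed_shape
#     # reduce-scatter `packed` along dim 0: one flattened row per rank
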

def pack_for_reduce_scatter(
    local_tensor: ms.Tensor,
    plan: ReduceScatterPlan,
) -> ms.Tensor:
    """Pack one local gradient into the row-major reduce-scatter layout."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported reduce-scatter pack kind: {plan.pack_kind}")
    if not local_tensor.is_contiguous():
        raise NotImplementedError(
            "reduce_scatter_grad currently expects contiguous local gradients before packing."
        )
    if _shape_tuple(local_tensor.shape) != plan.unpacked_shape:
        raise AssertionError(
            "pack_for_reduce_scatter expects the unsharded local tensor shape to match "
            f"plan.unpacked_shape, but got {tuple(local_tensor.shape)} and "
            f"{tuple(plan.unpacked_shape)}"
        )
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = ms.mint.chunk(local_tensor, plan.world_size, dim=plan.shard_dim)
        packed_tensor = ms.mint.cat(chunks, dim=0)
        return packed_tensor.contiguous().view(plan.packed_shape)
    return local_tensor.view(plan.packed_shape)


def unpack_from_all_gather(
    full_packed: ms.Tensor,
    plan: ReduceScatterPlan,
) -> ms.Tensor:
    """Inverse of the V1 reduce-scatter packing plan for all-gather outputs."""

    if plan.pack_kind not in (
        "identity_dim0",
        "same_dim_strided_identity_dim0",
        "chunk_cat_non_dim0",
    ):
        raise NotImplementedError(f"Unsupported all-gather unpack kind: {plan.pack_kind}")
    packed_tensor = full_packed.view(plan.packed_tensor_shape)
    if plan.pack_kind == "chunk_cat_non_dim0":
        chunks = ms.mint.chunk(packed_tensor, plan.world_size, dim=0)
        return ms.mint.cat(chunks, dim=plan.shard_dim).contiguous()
    return packed_tensor.view(plan.unpacked_shape)
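
# Round-trip sanity note: for a full (unsharded) tensor x whose shape matches
# plan.unpacked_shape, unpack_from_all_gather(pack_for_reduce_scatter(x, plan),
# plan) should reproduce x, since both directions are pure chunk/cat/view
# reshuffles of the same elements.
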

__all__ = [
    "ReduceScatterPlan",
    "build_rs_plan",
    "pack_for_reduce_scatter",
    "unpack_from_all_gather",
    "supports_same_dim_strided_layout",
]