Coverage for hyper_parallel/core/random.py: 77%

185 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright (c) Meta Platforms, Inc. and affiliates 

2"""RNG state management for distributed tensor operations. 

3 

4Provides utilities for tracking and synchronizing random number generator states 

5across multiple devices in distributed training scenarios. 

6""" 

7import contextlib 

8from logging import getLogger 

9import typing 

10from typing import Optional 

11import functools 

12import operator 

13 

14from hyper_parallel.core.placement_types import Shard 

15from hyper_parallel.platform import get_platform 

16 

# Resolve the active backend once at import time; all RNG operations in this
# module go through this platform abstraction rather than a specific framework.
platform = get_platform()
DTensorBase = platform.DTensorBase  # base class for distributed tensors (re-exported by the platform)
Tensor = platform.tensor  # platform tensor constructor — presumably torch.tensor-like; confirm per backend

logger = getLogger(__name__)

# Public API of this module; everything else is an implementation detail.
__all__ = [
    "is_rng_supported_mesh",
    "OffsetBasedRNGTracker",
]

# Process-global tracker instance; stays None until RNG tracking is set up.
_rng_tracker: Optional["_RNGStateTracker"] = None

29 

30 

def is_rng_supported_mesh() -> bool:
    """Check whether the current platform supports DTensor random operations.

    DTensor random operations are only supported on devices whose handle
    exposes RNG state management (currently CUDA and CUDA-like devices).
    Users should call this function before using DTensor random APIs to
    verify compatibility.

    Returns:
        bool: ``True`` if the platform's device handle supports DTensor
        random operations, ``False`` otherwise.
    """
    device_handle = platform.get_device_handle()
    # A handle without `set_rng_state` cannot participate in offset-based RNG
    # synchronization, so treat it as unsupported.
    return bool(device_handle and hasattr(device_handle, "set_rng_state"))

46 

47 

class _PhiloxState:
    """
    View over the packed bits of a philox generator state.

    The 16-byte uint8 state tensor encodes ``(seed: uint64, offset: uint64)``;
    this wrapper exposes both fields as Python ints. The tensor is kept on CPU
    because applying a state back to a generator requires a CPU tensor.
    """

    def __init__(self, state: Tensor):
        # Keep a CPU copy up front; generators reject device tensors on restore.
        self._state = state.to("cpu")

    @property
    def state(self):
        """The raw 16-byte uint8 state tensor (always on CPU)."""
        return self._state

    @property
    def seed(self) -> int:
        # First 8 bytes, reinterpreted as a single uint64 value.
        seed_bits = self._state[:8].view(dtype=platform.tensor_dtype.uint64)
        return int(seed_bits.item())

    @seed.setter
    def seed(self, seed: int) -> None:
        # Pack the value as uint64 and overwrite the first 8 bytes.
        # NOTE(review): packed on the default device — confirm whether an
        # explicit device argument is needed here.
        packed = Tensor([seed], dtype=platform.tensor_dtype.uint64)
        self._state[:8] = packed.view(platform.tensor_dtype.uint8)

    @property
    def offset(self) -> int:
        # Last 8 bytes; read back through an int64 view (signed reinterpretation
        # of the stored uint64).
        offset_bits = self._state[8:].view(dtype=platform.tensor_dtype.int64)
        return int(offset_bits.item())

    @offset.setter
    def offset(self, offset: int) -> None:
        # Pack the value as uint64 and overwrite the last 8 bytes.
        # NOTE(review): packed on the default device — confirm whether an
        # explicit device argument is needed here.
        packed = Tensor([offset], dtype=platform.tensor_dtype.uint64)
        self._state[8:] = packed.view(platform.tensor_dtype.uint8)

85 

86class _RNGStateTracker: 

87 """ 

88 Tracks and manages RNG states for DTensor random operations. 

89 

90 Maintains a mapping from operation tags to RNG state tensors (ByteTensor), 

91 providing standardized interfaces for state access and modification. 

92 

93 The core method `_distribute_region` establishes the proper RNG context 

94 when DTensor executes random operators across distributed devices. 

95 """ 

96 

97 def __init__(self, device): 

98 self._device = device 

99 self._device_handle = platform.get_device_handle() 

100 if not self._device_handle: 

101 raise RuntimeError( 

102 f"{self.__class__.__name__} instantiation requires the presence of " 

103 ) 

104 self._use_distribute_region = True 

105 

106 @property 

107 def distribute_region_enabled(self) -> bool: 

108 return self._use_distribute_region 

109 

110 @distribute_region_enabled.setter 

111 def distribute_region_enabled(self, value) -> None: 

112 self._use_distribute_region = value 

113 

114 def _distribute_region( 

115 self, device_mesh, placements, global_shape, generator = None 

116 ): 

117 pass 

118 

119 def _manual_seed(self, parallel_seed: int) -> None: 

120 pass 

121 

122 

class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
    should be shared and synchronized among all ranks to respect the semantics of DTensor
    random operators: all ranks share one philox (seed, offset) state, and each rank
    offsets into the random stream according to which shard of the DTensor it holds.
    """

    def __init__(
        self,
        run_state_sync: bool = True,
    ):
        # Bind the tracker to this rank's device (rank id modulo device count).
        super().__init__(_resolve_device())
        rng_state = self._get_device_state()
        if run_state_sync:
            # synchronize RNG state using rank 0's current one
            platform.broadcast(rng_state, 0)
            my_rng_state = self._get_device_state()
            # Only warn when this rank's state actually differed from rank 0's;
            # identical states mean the user already seeded every rank consistently.
            if not all(my_rng_state == rng_state):
                logger.warning(
                    "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
                    "This behavior is deprecated. "
                    "Please call `manual_seed()` on every rank that participates in SPMD DTensor Operations with "
                    "the same seed. If using Pipeline Parallelism, each pipelining state would use a different seed, "
                    "but all ranks belonging to one pipeline stage would use the same seed."
                )
        self._set_device_state(rng_state)

    def _get_device_state(self):
        # Read the current generator state and move it onto the tracker's
        # device (callers feed it to collectives such as the broadcast above).
        rng_state = self._device_handle.get_rng_state().to(self._device)
        return rng_state

    def _set_device_state(self, state: Tensor):
        # It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state`
        # to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug
        # for now, we just convert back to cpu here to make sure it always works.
        self._device_handle.set_rng_state(state.to("cpu"))

    @contextlib.contextmanager
    def _distribute_region(
        self, device_mesh, placements, global_shape, generator = None
    ):
        """Run the enclosed random op under a shared philox state.

        Before yielding, this rank's offset is advanced so its shard draws a
        disjoint slice of the random stream; after the op, the offset is set
        to a value identical on all ranks.
        """
        # regular (non-LocalTensor) mode
        if generator is not None:
            # A user-supplied generator takes precedence: snapshot its state so
            # the offset arithmetic below operates on it rather than on the
            # default device generator.
            state = _PhiloxState(generator.get_state())
        else:
            state = _PhiloxState(self._get_device_state())

        if self.distribute_region_enabled:
            old_offset = state.offset
            # Advance this rank's offset so each shard consumes a disjoint
            # slice of the philox stream.
            self._set_pre_op_offset(state, device_mesh, placements, global_shape)
            # Fork so the surrounding (non-DTensor) RNG stream is untouched by
            # whatever the region executes.
            with fork_rng(
                devices=[self._device], device_type=platform.device_type()
            ):
                self._device_handle.set_rng_state(state.state)
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(state, device_mesh, old_offset)

        else:
            yield

        if generator is not None:
            # Propagate the state advancement back to the user's RNG so it is
            # visible to any future usage of that RNG (dtensor or non-dtensor),
            # and so later seed changes by the user are always respected.
            generator.set_state(state.state)
        else:
            self._set_device_state(state.state)

    def _set_pre_op_offset(self, state: _PhiloxState, device_mesh, placements, global_shape) -> None:
        """Set the starting random number generator (RNG) offset for the local shard
        on the current process before operation execution. The offset value begins from
        the current accumulated position and increments by the local shard size until
        covering the total elements of the global distributed tensor. Multiple processes
        holding replicas of the same shard will share identical starting offset values.

        Args:
            state (`_PhiloxState`): The generator state to modify
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
                Each element should be a Placement object (Shard, Replicate, Partial, etc.).
            global_shape: input global shape

        Returns:
            None

        .. warning::
            The current implementation does not consider memory layout contiguity.

        Example:
            take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Another concept to introduce besides rank coordinate is shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas and
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) have 1 replica each.
            That being said, the local shard on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have rank coordinate and shard coordinate, we can calculate on each rank
            what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
            of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord
            (x, y, z) is simply (z, x) by taking (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 holds the shard of coord (0, 0);
            rank 1 and rank 3 holds the shard of coord (0, 1);
            rank 4 and rank 6 holds the shard of coord (1, 0);
            rank 5 and rank 7 holds the shard of coord (1, 1);

            The last value to calculate before obtaining the starting offset is the shard linear index.
            The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
        """
        mesh = device_mesh
        mesh_coordinate = mesh.get_coordinate()

        # Compute shard index and total number of shards on each tensor dim
        shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info(
            mesh_coordinate, device_mesh, placements
        )

        # compute shard linear index
        shard_linear_idx = self._calc_shard_linear_idx(
            shard_idx_by_dim, total_num_shards_by_dim
        )

        # compute starting offset using the first shard's size
        local_size_on_rank_0 = _calc_first_shard_size(device_mesh, placements, global_shape)

        local_size = functools.reduce(operator.mul, local_size_on_rank_0, 1)

        # get current RNG offset
        current_offset = state.offset

        # Round up to a multiple of 4 — presumably required by philox offset
        # alignment (it generates 4 values per counter step); confirm.
        offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
        state.offset = current_offset + offset_incr

    def _set_post_op_offset(
        self, state: _PhiloxState, device_mesh, old_offset: int
    ) -> None:
        """Sets the RNG to a synchronized state after running the local random op.
        Restores the random number generator to a globally consistent state following
        local shard execution. Each process must advance its offset by the total element
        count of the distributed tensor, measured from the offset value recorded before
        the operation began.

        Args:
            state (`_PhiloxState`): The generator state to modify.
            device_mesh (DeviceMesh): The device mesh describing the device topology.

        Returns:
            None
        """
        # NOTE(review): `device_mesh.mesh_shape` is used here as the DTensor's
        # global shape (it is bound to `dtensor_shape`), while `_calc_shard_info`
        # indexes the same attribute by mesh dims. Confirm which meaning is
        # intended — if `mesh_shape` is the device-mesh shape, the post-op
        # offset advances by the rank count, not the tensor numel.
        dtensor_shape = device_mesh.mesh_shape

        numel = functools.reduce(operator.mul, dtensor_shape, 1)
        # Keep the same multiple-of-4 rounding used in `_set_pre_op_offset`.
        numel = (numel + 3) // 4 * 4
        state.offset = old_offset + numel

    def _calc_shard_linear_idx(
        self, shard_coord: list[int], shard_size: list[int]
    ) -> int:
        # Thin wrapper kept as a method so subclasses can override the mapping;
        # the actual computation lives in the module-level helper.
        return _calc_shard_linear_idx(shard_coord, shard_size)

298 

299 

300def _calc_first_shard_size(device_mesh, placements, global_shape) -> list[int]: 

301 """Calculate the size of the first shard on rank 0. 

302 

303 Args: 

304 device_mesh: The device mesh describing the device topology. 

305 placements: Sequence of Placement objects (Shard, Replicate, etc.). 

306 global_shape: input global shape 

307 

308 Returns: 

309 list[int]: Shape of rank 0's local shard. 

310 """ 

311 local_size_on_rank_0 = list(global_shape) 

312 for idx, placement in enumerate(placements): 

313 if isinstance(placement, Shard): 

314 mesh_dim_size = device_mesh.size(idx) 

315 shard_dim = placement.dim 

316 local_size_on_rank_0[shard_dim], _ = local_shard_size_and_offset( 

317 device_mesh.mesh_shape[shard_dim], 

318 mesh_dim_size, 

319 0, 

320 ) 

321 return local_size_on_rank_0 

322 

323 

324def _calc_shard_info( 

325 mesh_coordinate, device_mesh, placements 

326): 

327 """Calculate shard information for a specific rank.""" 

328 mesh_size = device_mesh.mesh_shape 

329 # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP 

330 # case. Replace the custom logic with dim_map once we support it. 

331 dim_map = [-1] * device_mesh.ndim 

332 for i, placement in enumerate(placements): 

333 if isinstance(placement, Shard): 

334 shard_dim = placement.dim 

335 if dim_map[shard_dim] == -1: 

336 dim_map[shard_dim] = [i] 

337 else: 

338 mesh_dim_list = dim_map[shard_dim] 

339 assert isinstance(mesh_dim_list, list) 

340 mesh_dim_list.append(i) 

341 

342 # Compute shard coordinate: 

343 # The coordinate on each tensor dim is a tuple (idx, range) 

344 # If a DTensor is partitioned on its dim i into n shards, and the current rank 

345 # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i 

346 assert mesh_coordinate is not None 

347 shard_idx_by_dim = [] 

348 total_num_shards_by_dim = [] # total number of shards on each tensor dim 

349 for mesh_dim in dim_map: 

350 shard_idx = 0 

351 total_num_shards = 1 

352 # the tensor dim is sharded on more than 1 mesh dim 

353 if isinstance(mesh_dim, list): 

354 rank_coord = [mesh_coordinate[d] for d in mesh_dim] 

355 num_shards = [mesh_size[d] for d in mesh_dim] 

356 # compute the shard idx and total number of shards 

357 for idx, size in zip(rank_coord, num_shards): 

358 shard_idx = shard_idx * size + idx 

359 total_num_shards *= size 

360 

361 shard_idx_by_dim.append(shard_idx) 

362 total_num_shards_by_dim.append(total_num_shards) 

363 return shard_idx_by_dim, total_num_shards_by_dim 

364 

365 

366def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: 

367 # compute shard linear index 

368 shard_linear_idx = 0 

369 shard_coord_stride = 1 

370 for idx, size in zip(reversed(shard_coord), reversed(shard_size)): 

371 shard_linear_idx += idx * shard_coord_stride 

372 shard_coord_stride *= size 

373 

374 return shard_linear_idx 

375 

376 

def _resolve_device():
    """Return the platform device for this rank.

    Ranks are assigned devices round-robin: rank id modulo the number of
    available devices on this host.
    """
    device_handle = platform.get_device_handle()
    device_idx = platform.get_rank() % platform.device_count(device_handle)
    return platform.device(device_idx)

385 

def local_shard_size_and_offset(
    curr_local_size: int,
    num_chunks: int,
    rank,
):
    """
    Given the size of the current local tensor (which may already be sharded on some dimensions),
    computes the new local shard size and offset given the desired number of chunks
    (num_chunks is generally equal to the size of the current sharding dim).

    Note: new local shard offset is relative to the current sharded tensor, not the global tensor.
    See `_utils.compute_local_shape_and_global_offset` for computing global offset.

    Returns (new local shard size, offset)

    """
    # Evenly divisible: every rank gets an identical chunk.
    if curr_local_size % num_chunks == 0:
        chunk = curr_local_size // num_chunks
        return chunk, chunk * rank

    # Uneven sharding: all leading ranks get the ceiling-sized chunk; the
    # trailing rank(s) get whatever remains (possibly nothing).
    chunk = (curr_local_size + num_chunks - 1) // num_chunks
    start = chunk * rank

    if start > curr_local_size:
        # This rank falls entirely past the end of the dim: empty shard.
        return 0, typing.cast(int, curr_local_size)
    end = min(start + chunk, curr_local_size)
    return end - start, start

419 

420 

# Module-level once-only flag for fork_rng's many-devices warning.
_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    device_type="npu",
):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `npu`. As for supported device,
            see details in :ref:`accelerator<accelerators>`

    Raises:
        RuntimeError: if the platform has no device handle for *device_type*.
    """
    device_mod = platform.get_device_handle()
    if device_mod is None:
        raise RuntimeError(
            f"{platform} has no module of `{device_type}`, you should register "
            f"a device handle for `{device_type}` before calling fork_rng()."
        )
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = platform.device_count(device_mod)
        if num_devices > 1 and not _fork_rng_warned_already:
            # Warn once per process: forking every device's RNG state is slow.
            # (Previously the flag was set without emitting any warning, even
            # though the docstring promises one.)
            logger.warning(
                "%s device module reports %d available devices; fork_rng will "
                "fork the RNG state of each one, which can be slow. Pass "
                "`devices` explicitly to fork a subset and silence this warning.",
                device_type,
                num_devices,
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    # Snapshot CPU and per-device states so they can be restored on exit.
    cpu_rng_state = platform.get_rng_state()
    device_rng_states = [platform.get_rng_state(device, device_mod) for device in devices]

    try:
        yield
    finally:
        # Restore even if the body raised.
        platform.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            platform.set_rng_state(device_rng_state, device, device_mod)