Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/random.py: 28%

187 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

1# Copyright (c) Meta Platforms, Inc. and affiliates 

2"""RNG state management for distributed tensor operations. 

3 

4Provides utilities for tracking and synchronizing random number generator states 

5across multiple devices in distributed training scenarios. 

6""" 

7 

# Public API of this module; underscore-prefixed helpers below are internal.
__all__ = [
    "is_rng_supported_mesh",
    "OffsetBasedRNGTracker",
]

12 

13import contextlib 

14from logging import getLogger 

15import typing 

16from typing import Optional 

17import functools 

18import operator 

19 

20from hyper_parallel.core.dtensor.placement_types import Shard 

21from hyper_parallel.platform import get_platform 

22 

# Resolve the active backend once at import time; all RNG/device access in this
# module goes through this platform abstraction.
platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.tensor

logger = getLogger(__name__)

# Process-wide RNG tracker slot. It is only declared here; presumably DTensor
# machinery elsewhere installs an OffsetBasedRNGTracker into it — not visible
# in this file.
_rng_tracker: Optional["_RNGStateTracker"] = None

30 

31 

def is_rng_supported_mesh() -> bool:
    """Check if the device mesh supports DTensor random operations.

    Currently, DTensor random operations are only supported on CUDA and
    CUDA-like devices. Users should call this function before using DTensor
    random APIs to verify compatibility.

    Returns:
        bool: ``True`` if the device mesh supports DTensor random operations,
            ``False`` otherwise.
    """
    # Support is determined by the device handle exposing `set_rng_state`.
    handle = platform.get_device_handle()
    return bool(handle and hasattr(handle, "set_rng_state"))

47 

48 

class _PhiloxState:
    """
    Convenience accessor for interpreting the packed bits of (seed: uint64, offset: uint64) in the philox state,
    which for some reason is actually exposed as a size-16 uint8 tensor.

    The state is always moved to .cpu since it is necessary for it to be on CPU before applying it back to a generator.
    """

    def __init__(self, state: Tensor):
        # Keep a CPU copy: generators require a CPU tensor when the state is
        # applied back to them.
        self._state = state.to("cpu")

    @property
    def state(self):
        # The raw 16-byte uint8 state tensor.
        return self._state

    @property
    def seed(self) -> int:
        # Bytes [0:8] hold the seed, reinterpreted as uint64.
        return int(self._state[:8].view(dtype=platform.tensor_dtype.uint64).item())

    @seed.setter
    def seed(self, seed: int) -> None:
        packed = Tensor([seed], dtype=platform.tensor_dtype.uint64).view(
            platform.tensor_dtype.uint8
        )  # device?
        self._state[:8] = packed

    @property
    def offset(self) -> int:
        # Bytes [8:16] hold the offset; read back as int64 (matches how the
        # original code interprets it).
        return int(self._state[8:].view(dtype=platform.tensor_dtype.int64).item())

    @offset.setter
    def offset(self, offset: int) -> None:
        packed = Tensor([offset], dtype=platform.tensor_dtype.uint64).view(
            platform.tensor_dtype.uint8
        )  # device
        self._state[8:] = packed

85 

86 

87class _RNGStateTracker: 

88 """ 

89 Tracks and manages RNG states for DTensor random operations. 

90 

91 Maintains a mapping from operation tags to RNG state tensors (ByteTensor), 

92 providing standardized interfaces for state access and modification. 

93 

94 The core method `_distribute_region` establishes the proper RNG context 

95 when DTensor executes random operators across distributed devices. 

96 """ 

97 

98 def __init__(self, device): 

99 self._device = device 

100 self._device_handle = platform.get_device_handle() 

101 if not self._device_handle: 

102 raise RuntimeError( 

103 f"{self.__class__.__name__} instantiation requires the presence of " 

104 ) 

105 self._use_distribute_region = True 

106 

107 @property 

108 def distribute_region_enabled(self) -> bool: 

109 return self._use_distribute_region 

110 

111 @distribute_region_enabled.setter 

112 def distribute_region_enabled(self, value) -> None: 

113 self._use_distribute_region = value 

114 

115 def _distribute_region( 

116 self, device_mesh, placements, global_shape, generator = None 

117 ): 

118 pass 

119 

120 def _manual_seed(self, parallel_seed: int) -> None: 

121 pass 

122 

123 

class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
    should be shared and synchronized among all ranks to respect the semantics of DTensor
    random operators.

    The policy is offset-based: each random op advances the philox offset so that
    each shard consumes a disjoint slice of the random stream, while replicas of
    the same shard start from identical offsets (see ``_set_pre_op_offset``).
    """

    def __init__(
        self,
        run_state_sync: bool = True,
    ):
        super().__init__(_resolve_device())
        rng_state = self._get_device_state()
        if run_state_sync:
            # synchronize RNG state using rank 0's current one
            platform.broadcast(rng_state, 0)
            my_rng_state = self._get_device_state()
            # Only warn when this rank's state actually differed from rank 0's
            # (elementwise comparison of the byte-tensor states).
            if not all(my_rng_state == rng_state):
                logger.warning(
                    "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
                    "This behavior is deprecated. "
                    "Please call `manual_seed()` on every rank that participates in SPMD DTensor Operations with "
                    "the same seed. If using Pipeline Parallelism, each pipelining state would use a different seed, "
                    "but all ranks belonging to one pipeline stage would use the same seed."
                )
            self._set_device_state(rng_state)

    def _get_device_state(self):
        # NOTE(review): the state is moved onto self._device, presumably so it
        # can participate in collective comms (see broadcast in __init__) —
        # confirm against platform.broadcast's requirements.
        rng_state = self._device_handle.get_rng_state().to(self._device)
        return rng_state

    def _set_device_state(self, state: Tensor):
        # It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state`
        # to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug
        # for now, we just convert back to cpu here to make sure it always works.
        self._device_handle.set_rng_state(state.to("cpu"))

    @contextlib.contextmanager
    def _distribute_region(
        self, device_mesh, placements, global_shape, generator=None
    ):
        """Context manager that runs a random op under a rank-specific RNG offset.

        Snapshots the philox state (from ``generator`` if given, otherwise the
        device default), advances its offset for this rank's shard before
        yielding, and writes the synchronized post-op state back afterwards.
        """

        # regular (non-LocalTensor) mode
        if generator is not None:
            # This is a little hacky, but for any user-passed generator, we store its state under a unique key,
            # not because we need to keep a copy of it but because its the easiest way to make it work with the
            # existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region.
            state = _PhiloxState(generator.get_state())
        else:
            state = _PhiloxState(self._get_device_state())

        if self.distribute_region_enabled:
            old_offset = state.offset
            self._set_pre_op_offset(state, device_mesh, placements, global_shape)
            # fork_rng restores the surrounding CPU/device RNG states on exit,
            # so only the in-place mutations of `state` survive this region.
            with fork_rng(
                devices=[self._device], device_type=platform.device_type()
            ):
                self._device_handle.set_rng_state(state.state)
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(state, global_shape, old_offset)

        else:
            yield

        if generator is not None:
            # ensure we (a) propagate the state advancement back to the user's RNG so its visible and impacts any future
            # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates
            # the seed value in their rng and uses it with DTensor again, we always use the latest value
            generator.set_state(state.state)
        else:
            self._set_device_state(state.state)

    def compute_offset_incr(self, device_mesh, placements, global_shape) -> int:
        """Compute the per-shard RNG offset increment for the current rank.

        Based on the shard linear index and local shard size, computes how much to
        advance the offset so that each shard gets a unique portion of the random stream.

        Args:
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
            global_shape: input global shape

        Returns:
            int: The offset increment, 4-byte aligned.
        """
        mesh_coordinate = device_mesh.get_coordinate()
        shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info(
            mesh_coordinate, device_mesh, placements
        )
        shard_linear_idx = self._calc_shard_linear_idx(
            shard_idx_by_dim, total_num_shards_by_dim
        )
        local_size_on_rank_0 = _calc_first_shard_size(device_mesh, placements, global_shape)
        local_size = functools.reduce(operator.mul, local_size_on_rank_0, 1)
        # Round up to a multiple of 4 to keep the increment 4-aligned (see Returns).
        return (shard_linear_idx * local_size + 3) // 4 * 4

    def _set_pre_op_offset(self, state: _PhiloxState, device_mesh, placements, global_shape) -> None:
        """Set the starting random number generator (RNG) offset for the local shard
        on the current process before operation execution.The offset value begins from
        the current accumulated position and increments by the local shard size until
        covering the total elements of the global distributed tensor. Multiple processes
        holding replicas of the same shard will share identical starting offset values.

        Args:
            state (`Tensor`): The generator state to modify
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
                Each element should be a Placement object (Shard, Replicate, Partial, etc.).
            global_shape: input global shape

        Returns:
            None

        .. warning::
            The current implementation does not consider memory layout contiguity.

        Example:
            take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Another concept to introduce besides rank coordinate is shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas and
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) have 1 replica each.
            That being said, the local shard on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have rank coordinate and shard coordinate, we can calculate on each rank
            what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
            of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord
            (x, y, z) is simply (z, x) by taking(rank_coord[dim_map[0]],rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 holds the shard of coord (0, 0);
            rank 1 and rank 3 holds the shard of coord (0, 1);
            rank 4 and rank 6 holds the shard of coord (1, 0);
            rank 5 and rank 7 holds the shard of coord (1, 1);

            The last value to calculate before obtaining the starting offset is the shard linear index.
            The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
        """
        current_offset = state.offset
        offset_incr = self.compute_offset_incr(device_mesh, placements, global_shape)
        state.offset = current_offset + offset_incr

    def _set_post_op_offset(
        self, state: _PhiloxState, global_shape, old_offset: int
    ) -> None:
        """Sets the RNG to a synchronized state after running the local random op.
        Restores the random number generator to a globally consistent state following
        local shard execution. Each process must advance its offset by the total element
        count of the distributed tensor, measured from the offset value recorded before
        the operation began.

        Args:
            state (`Tensor`): The generator state to modify.
            global_shape: The global shape of the distributed tensor.
            old_offset (int): The RNG offset before the operation.

        Returns:
            None
        """
        numel = functools.reduce(operator.mul, global_shape, 1)
        # Round up to a multiple of 4, matching compute_offset_incr's alignment,
        # so every rank lands on the same post-op offset.
        numel = (numel + 3) // 4 * 4
        state.offset = old_offset + numel

    def _calc_shard_linear_idx(
        self, shard_coord: list[int], shard_size: list[int]
    ) -> int:
        # Thin wrapper around the module-level helper of the same name.
        return _calc_shard_linear_idx(shard_coord, shard_size)

303 

304 

305def _calc_first_shard_size(device_mesh, placements, global_shape) -> list[int]: 

306 """Calculate the size of the first shard on rank 0. 

307 

308 Args: 

309 device_mesh: The device mesh describing the device topology. 

310 placements: Sequence of Placement objects (Shard, Replicate, etc.). 

311 global_shape: input global shape 

312 

313 Returns: 

314 list[int]: Shape of rank 0's local shard. 

315 """ 

316 local_size_on_rank_0 = list(global_shape) 

317 for idx, placement in enumerate(placements): 

318 if isinstance(placement, Shard): 

319 mesh_dim_size = device_mesh.size(idx) 

320 shard_dim = placement.dim 

321 local_size_on_rank_0[shard_dim], _ = local_shard_size_and_offset( 

322 global_shape[shard_dim], 

323 mesh_dim_size, 

324 0, 

325 ) 

326 return local_size_on_rank_0 

327 

328 

329def _calc_shard_info( 

330 mesh_coordinate, device_mesh, placements 

331): 

332 """Calculate shard information for a specific rank.""" 

333 mesh_size = device_mesh.mesh_shape 

334 # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP 

335 # case. Replace the custom logic with dim_map once we support it. 

336 dim_map = [-1] * device_mesh.ndim 

337 for i, placement in enumerate(placements): 

338 if isinstance(placement, Shard): 

339 shard_dim = placement.dim 

340 if dim_map[shard_dim] == -1: 

341 dim_map[shard_dim] = [i] 

342 else: 

343 mesh_dim_list = dim_map[shard_dim] 

344 if not isinstance(mesh_dim_list, list): 

345 raise TypeError(f"Expected mesh_dim_list to be a list, got {type(mesh_dim_list)}") 

346 mesh_dim_list.append(i) 

347 

348 # Compute shard coordinate: 

349 # The coordinate on each tensor dim is a tuple (idx, range) 

350 # If a DTensor is partitioned on its dim i into n shards, and the current rank 

351 # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i 

352 if mesh_coordinate is None: 

353 raise ValueError("mesh_coordinate must not be None") 

354 shard_idx_by_dim = [] 

355 total_num_shards_by_dim = [] # total number of shards on each tensor dim 

356 for mesh_dim in dim_map: 

357 shard_idx = 0 

358 total_num_shards = 1 

359 # the tensor dim is sharded on more than 1 mesh dim 

360 if isinstance(mesh_dim, list): 

361 rank_coord = [mesh_coordinate[d] for d in mesh_dim] 

362 num_shards = [mesh_size[d] for d in mesh_dim] 

363 # compute the shard idx and total number of shards 

364 for idx, size in zip(rank_coord, num_shards): 

365 shard_idx = shard_idx * size + idx 

366 total_num_shards *= size 

367 

368 shard_idx_by_dim.append(shard_idx) 

369 total_num_shards_by_dim.append(total_num_shards) 

370 return shard_idx_by_dim, total_num_shards_by_dim 

371 

372 

373def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: 

374 # compute shard linear index 

375 shard_linear_idx = 0 

376 shard_coord_stride = 1 

377 for idx, size in zip(reversed(shard_coord), reversed(shard_size)): 

378 shard_linear_idx += idx * shard_coord_stride 

379 shard_coord_stride *= size 

380 

381 return shard_linear_idx 

382 

383 

def _resolve_device():
    """Resolve the device object this rank should use.

    Maps the global rank onto a local device index (round-robin over the
    visible devices) and returns the corresponding platform device object.
    """
    device_handle = platform.get_device_handle()
    device_idx = platform.get_rank() % platform.device_count(device_handle)
    # The original wrapped this call in a needless one-shot closure; call directly.
    return platform.device(device_idx)

392 

393 

def local_shard_size_and_offset(
    curr_local_size: int,
    num_chunks: int,
    rank,
):
    """
    Given the size of the current local tensor (which may already be sharded on some dimensions),
    computes the new local shard size and offset given the desired number of chunks
    (num_chunks is generally equal to the size of the current sharding dim).

    Note: new local shard offset is relative to the current sharded tensor, not the global tensor.
    See `_utils.compute_local_shape_and_global_offset` for computing global offset.

    Returns (new local shard size, offset)

    """
    quotient, remainder = divmod(curr_local_size, num_chunks)

    # Even split: every rank gets exactly the same chunk.
    if remainder == 0:
        return quotient, quotient * rank

    # Uneven split: chunks are ceil-sized, so trailing ranks may be short or empty.
    chunk = quotient + 1
    start = chunk * rank

    if start > curr_local_size:
        # This rank's chunk begins past the end of the tensor: empty shard.
        return 0, curr_local_size

    end = min(start + chunk, curr_local_size)
    return end - start, start

427 

428 

# One-shot latch consumed by fork_rng when called without an explicit `devices`
# list on a multi-device machine, so its handling triggers once per process.
_fork_rng_warned_already = False

430 

431 

@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    device_type="npu",
):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `npu`. As for supported device,
            see details in :ref:`accelerator<accelerators>`

    Raises:
        RuntimeError: if the platform has no device handle registered for
            ``device_type``.
    """

    device_mod = platform.get_device_handle()
    if device_mod is None:
        # Fix: the original message was truncated mid-sentence.
        raise RuntimeError(
            f"{platform} has no module of `{device_type}`, you should register "
            f"a device module for `{device_type}` before calling fork_rng"
        )
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = platform.device_count(device_mod)
        if num_devices > 1 and not _fork_rng_warned_already:
            _fork_rng_warned_already = True
            # Fix: the latch was set but the warning promised by the docstring
            # was never actually emitted.
            logger.warning(
                "fork_rng was called without explicit `devices` and %d %s devices "
                "are visible; forking the RNG state of every device can be very "
                "slow. Pass `devices` explicitly to suppress this warning.",
                num_devices,
                device_type,
            )
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    # Snapshot CPU and per-device states before yielding control.
    cpu_rng_state = platform.get_rng_state()
    device_rng_states = [platform.get_rng_state(device, device_mod) for device in devices]

    try:
        yield
    finally:
        # Restore everything even if the body raised.
        platform.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            platform.set_rng_state(device_rng_state, device, device_mod)