Coverage for hyper_parallel / core / random.py: 77%
185 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright (c) Meta Platforms, Inc. and affiliates
2"""RNG state management for distributed tensor operations.
4Provides utilities for tracking and synchronizing random number generator states
5across multiple devices in distributed training scenarios.
6"""
7import contextlib
8from logging import getLogger
9import typing
10from typing import Optional
11import functools
12import operator
14from hyper_parallel.core.placement_types import Shard
15from hyper_parallel.platform import get_platform
17platform = get_platform()
18DTensorBase = platform.DTensorBase
19Tensor = platform.tensor
21logger = getLogger(__name__)
# Public API surface of this module.
__all__ = [
    "is_rng_supported_mesh",
    "OffsetBasedRNGTracker",
]

# Process-wide RNG tracker singleton; stays None until a tracker is installed
# (presumably by DTensor runtime code outside this view — confirm with callers).
_rng_tracker: Optional["_RNGStateTracker"] = None
def is_rng_supported_mesh() -> bool:
    """Check whether the current platform supports DTensor random operations.

    Currently, DTensor random operations are only supported on CUDA and
    CUDA-like devices. Users should call this function before using DTensor
    random APIs to verify compatibility.

    Returns:
        bool: ``True`` if DTensor random operations are supported,
        ``False`` otherwise.
    """
    device_handle = platform.get_device_handle()
    # A handle exposing ``set_rng_state`` is the capability DTensor relies on.
    # (Collapsed the redundant `if ...: return True / return False` pattern;
    # also fixed the docstring, which referred to a mesh argument this
    # function does not take.)
    return bool(device_handle and hasattr(device_handle, "set_rng_state"))
class _PhiloxState:
    """
    Convenience accessor for interpreting the packed bits of (seed: uint64, offset: uint64) in the philox state,
    which for some reason is actually exposed as a size-16 uint8 tensor.

    The state is always moved to .cpu since it is necessary for it to be on CPU before applying it back to a generator.
    """

    def __init__(self, state: Tensor):
        # Keep a CPU copy: generators require a CPU tensor when the state is
        # applied back (see class docstring).
        self._state = state.to("cpu")

    @property
    def state(self) -> Tensor:
        # Raw packed 16-byte (uint8) state tensor: bytes [0:8] seed, [8:16] offset.
        return self._state

    @property
    def offset(self) -> int:
        # NOTE(review): offset is read via an int64 view here but written via a
        # uint64 view in the setter (seed uses uint64 both ways). Equivalent for
        # offsets < 2**63 — confirm the asymmetry is intentional.
        return int(self._state[8:].view(dtype=platform.tensor_dtype.int64).item())

    @offset.setter
    def offset(self, offset: int) -> None:
        # Re-pack the integer as 8 uint8 bytes and overwrite the offset half.
        offset_tensor = Tensor([offset], dtype=platform.tensor_dtype.uint64).view(
            platform.tensor_dtype.uint8
        )  # constructed without an explicit device — presumably CPU; confirm
        self._state[8:] = offset_tensor

    @property
    def seed(self) -> int:
        # First 8 bytes of the packed state, reinterpreted as one uint64.
        return int(self._state[:8].view(dtype=platform.tensor_dtype.uint64).item())

    @seed.setter
    def seed(self, seed: int) -> None:
        # Re-pack the integer as 8 uint8 bytes and overwrite the seed half.
        seed_tensor = Tensor([seed], dtype=platform.tensor_dtype.uint64).view(
            platform.tensor_dtype.uint8
        )  # constructed without an explicit device — presumably CPU; confirm
        self._state[:8] = seed_tensor
86class _RNGStateTracker:
87 """
88 Tracks and manages RNG states for DTensor random operations.
90 Maintains a mapping from operation tags to RNG state tensors (ByteTensor),
91 providing standardized interfaces for state access and modification.
93 The core method `_distribute_region` establishes the proper RNG context
94 when DTensor executes random operators across distributed devices.
95 """
97 def __init__(self, device):
98 self._device = device
99 self._device_handle = platform.get_device_handle()
100 if not self._device_handle:
101 raise RuntimeError(
102 f"{self.__class__.__name__} instantiation requires the presence of "
103 )
104 self._use_distribute_region = True
106 @property
107 def distribute_region_enabled(self) -> bool:
108 return self._use_distribute_region
110 @distribute_region_enabled.setter
111 def distribute_region_enabled(self, value) -> None:
112 self._use_distribute_region = value
114 def _distribute_region(
115 self, device_mesh, placements, global_shape, generator = None
116 ):
117 pass
119 def _manual_seed(self, parallel_seed: int) -> None:
120 pass
class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states
    should be shared and synchronized among all ranks to respect the semantics of DTensor
    random operators.
    """

    def __init__(
        self,
        run_state_sync: bool = True,
    ):
        # Bind to this rank's device (rank % device_count — see _resolve_device).
        super().__init__(_resolve_device())
        rng_state = self._get_device_state()
        if run_state_sync:
            # synchronize RNG state using rank 0's current one
            platform.broadcast(rng_state, 0)
            my_rng_state = self._get_device_state()
            # Deprecated implicit-sync path: warn when this rank's state
            # actually differed from rank 0's broadcast state.
            if not all(my_rng_state == rng_state):
                logger.warning(
                    "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
                    "This behavior is deprecated. "
                    "Please call `manual_seed()` on every rank that participates in SPMD DTensor Operations with "
                    "the same seed. If using Pipeline Parallelism, each pipelining state would use a different seed, "
                    "but all ranks belonging to one pipeline stage would use the same seed."
                )
            self._set_device_state(rng_state)

    def _get_device_state(self):
        # Returned on self._device (not CPU) — see the note in _set_device_state
        # for why the device/CPU round-trip exists.
        rng_state = self._device_handle.get_rng_state().to(self._device)
        return rng_state

    def _set_device_state(self, state: Tensor):
        # It seems that the underlying generator wants a cpu tensor but the dtensor code expects `_get_device_state`
        # to convert to a 'device' tensor, probably because we may use it with our backend comms for sync/debug
        # for now, we just convert back to cpu here to make sure it always works.
        self._device_handle.set_rng_state(state.to("cpu"))

    @contextlib.contextmanager
    def _distribute_region(
        self, device_mesh, placements, global_shape, generator = None
    ):
        """Run the enclosed random op with a per-shard RNG offset, leaving all
        ranks with a synchronized offset afterwards (see _set_pre_op_offset /
        _set_post_op_offset for the offset policy).
        """
        # regular (non-LocalTensor) mode
        if generator is not None:
            # This is a little hacky, but for any user-passed generator, we store its state under a unique key,
            # not because we need to keep a copy of it but because its the easiest way to make it work with the
            # existing set/get APIs. We also ensure we remove it from rng_states after each _distribute_region.
            state = _PhiloxState(generator.get_state())
        else:
            state = _PhiloxState(self._get_device_state())

        if self.distribute_region_enabled:
            old_offset = state.offset
            self._set_pre_op_offset(state, device_mesh, placements, global_shape)
            # fork_rng restores the pre-existing CPU/device RNG states on exit,
            # so only `state` carries the advancement out of this region.
            with fork_rng(
                devices=[self._device], device_type=platform.device_type()
            ):
                self._device_handle.set_rng_state(state.state)
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(state, device_mesh, old_offset)
        else:
            yield

        if generator is not None:
            # ensure we (a) propagate the state advancement back to the user's RNG so its visible and impacts any future
            # usage of that RNG (dtensor or non-dtensor), (b) drop it from our own cache so that if the user updates
            # the seed value in their rng and uses it with DTensor again, we always use the latest value
            generator.set_state(state.state)
        else:
            self._set_device_state(state.state)

    def _set_pre_op_offset(self, state: _PhiloxState, device_mesh, placements, global_shape) -> None:
        """Set the starting random number generator (RNG) offset for the local shard
        on the current process before operation execution.The offset value begins from
        the current accumulated position and increments by the local shard size until
        covering the total elements of the global distributed tensor. Multiple processes
        holding replicas of the same shard will share identical starting offset values.

        Args:
            state (`_PhiloxState`): The generator state to modify
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
                Each element should be a Placement object (Shard, Replicate, Partial, etc.).
            global_shape: input global shape

        Returns:
            None

        .. warning::
            The current implementation does not consider memory layout contiguity.

        Example:
            take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
            [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Another concept to introduce besides rank coordinate is shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas and
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) have 1 replica each.
            That being said, the local shard on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it will be a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have rank coordinate and shard coordinate, we can calculate on each rank
            what shard of the DTensor the rank holds, with the help of dim_map. The dim_map
            of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord
            (x, y, z) is simply (z, x) by taking(rank_coord[dim_map[0]],rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 holds the shard of coord (0, 0);
            rank 1 and rank 3 holds the shard of coord (0, 1);
            rank 4 and rank 6 holds the shard of coord (1, 0);
            rank 5 and rank 7 holds the shard of coord (1, 1);

            The last value to calculate before obtaining the starting offset is the shard linear index.
            The starting offset for each rank will be its shard_linear_index * local_tensor_numel.
        """
        mesh = device_mesh
        mesh_coordinate = mesh.get_coordinate()

        # Compute shard index and total number of shards on each tensor dim
        shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info(
            mesh_coordinate, device_mesh, placements
        )

        # compute shard linear index
        shard_linear_idx = self._calc_shard_linear_idx(
            shard_idx_by_dim, total_num_shards_by_dim
        )

        # compute starting offset using the first shard's size
        local_size_on_rank_0 = _calc_first_shard_size(device_mesh, placements, global_shape)

        local_size = functools.reduce(operator.mul, local_size_on_rank_0, 1)

        # get current RNG offset
        current_offset = state.offset

        # Round the per-shard advance up to a multiple of 4 — presumably the
        # philox counter granularity (same rounding as _set_post_op_offset);
        # TODO confirm against the platform's generator contract.
        offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4
        state.offset = current_offset + offset_incr

    def _set_post_op_offset(
        self, state: _PhiloxState, device_mesh, old_offset: int
    ) -> None:
        """Sets the RNG to a synchronized state after running the local random op.
        Restores the random number generator to a globally consistent state following
        local shard execution. Each process must advance its offset by the total element
        count of the distributed tensor, measured from the offset value recorded before
        the operation began.

        Args:
            state (`_PhiloxState`): The generator state to modify.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            old_offset (int): Offset recorded before _set_pre_op_offset ran.

        Returns:
            None
        """
        # NOTE(review): the docstring promises "total element count of the
        # distributed tensor", but this reads `device_mesh.mesh_shape` (the
        # mesh's shape, not the tensor's global shape). Confirm whether the
        # global tensor shape was intended here.
        dtensor_shape = device_mesh.mesh_shape

        numel = functools.reduce(operator.mul, dtensor_shape, 1)
        # Round up to a multiple of 4, matching the pre-op offset granularity.
        numel = (numel + 3) // 4 * 4
        state.offset = old_offset + numel

    def _calc_shard_linear_idx(
        self, shard_coord: list[int], shard_size: list[int]
    ) -> int:
        # Thin wrapper over the module-level helper (kept for subclass override).
        return _calc_shard_linear_idx(shard_coord, shard_size)
300def _calc_first_shard_size(device_mesh, placements, global_shape) -> list[int]:
301 """Calculate the size of the first shard on rank 0.
303 Args:
304 device_mesh: The device mesh describing the device topology.
305 placements: Sequence of Placement objects (Shard, Replicate, etc.).
306 global_shape: input global shape
308 Returns:
309 list[int]: Shape of rank 0's local shard.
310 """
311 local_size_on_rank_0 = list(global_shape)
312 for idx, placement in enumerate(placements):
313 if isinstance(placement, Shard):
314 mesh_dim_size = device_mesh.size(idx)
315 shard_dim = placement.dim
316 local_size_on_rank_0[shard_dim], _ = local_shard_size_and_offset(
317 device_mesh.mesh_shape[shard_dim],
318 mesh_dim_size,
319 0,
320 )
321 return local_size_on_rank_0
324def _calc_shard_info(
325 mesh_coordinate, device_mesh, placements
326):
327 """Calculate shard information for a specific rank."""
328 mesh_size = device_mesh.mesh_shape
329 # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
330 # case. Replace the custom logic with dim_map once we support it.
331 dim_map = [-1] * device_mesh.ndim
332 for i, placement in enumerate(placements):
333 if isinstance(placement, Shard):
334 shard_dim = placement.dim
335 if dim_map[shard_dim] == -1:
336 dim_map[shard_dim] = [i]
337 else:
338 mesh_dim_list = dim_map[shard_dim]
339 assert isinstance(mesh_dim_list, list)
340 mesh_dim_list.append(i)
342 # Compute shard coordinate:
343 # The coordinate on each tensor dim is a tuple (idx, range)
344 # If a DTensor is partitioned on its dim i into n shards, and the current rank
345 # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i
346 assert mesh_coordinate is not None
347 shard_idx_by_dim = []
348 total_num_shards_by_dim = [] # total number of shards on each tensor dim
349 for mesh_dim in dim_map:
350 shard_idx = 0
351 total_num_shards = 1
352 # the tensor dim is sharded on more than 1 mesh dim
353 if isinstance(mesh_dim, list):
354 rank_coord = [mesh_coordinate[d] for d in mesh_dim]
355 num_shards = [mesh_size[d] for d in mesh_dim]
356 # compute the shard idx and total number of shards
357 for idx, size in zip(rank_coord, num_shards):
358 shard_idx = shard_idx * size + idx
359 total_num_shards *= size
361 shard_idx_by_dim.append(shard_idx)
362 total_num_shards_by_dim.append(total_num_shards)
363 return shard_idx_by_dim, total_num_shards_by_dim
366def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int:
367 # compute shard linear index
368 shard_linear_idx = 0
369 shard_coord_stride = 1
370 for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
371 shard_linear_idx += idx * shard_coord_stride
372 shard_coord_stride *= size
374 return shard_linear_idx
def _resolve_device():
    """Pick the device for the current rank: global rank modulo local device count.

    Returns:
        A platform device object for the selected device index.
    """
    device_handle = platform.get_device_handle()
    # Map the global rank onto a local device slot.
    device_idx = platform.get_rank() % platform.device_count(device_handle)
    # Fixed: removed a pointless inner function that only forwarded its argument.
    return platform.device(device_idx)
def local_shard_size_and_offset(
    curr_local_size: int,
    num_chunks: int,
    rank,
):
    """
    Given the size of the current local tensor (which may already be sharded on some dimensions),
    computes the new local shard size and offset given the desired number of chunks
    (num_chunks is generally equal to the size of the current sharding dim).

    Note: new local shard offset is relative to the current sharded tensor, not the global tensor.
    See `_utils.compute_local_shape_and_global_offset` for computing global offset.

    Returns (new local shard size, offset)
    """
    if curr_local_size % num_chunks == 0:
        # Even split: every chunk has the same size.
        chunk = curr_local_size // num_chunks
        return chunk, chunk * rank

    # Uneven split: every rank except possibly the trailing ones gets a
    # ceil-sized chunk.
    chunk = -(-curr_local_size // num_chunks)  # ceiling division
    start = chunk * rank
    if curr_local_size < start:
        # This rank's slot begins past the end: empty shard anchored at the end.
        return 0, typing.cast(int, curr_local_size)
    end = min(curr_local_size, start + chunk)
    return end - start, start
# One-shot guard: fork_rng flags the multi-device path at most once per process.
_fork_rng_warned_already = False
@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    device_type="npu",
):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `npu`. As for supported device,
            see details in :ref:`accelerator<accelerators>`
    """
    device_mod = platform.get_device_handle()
    if device_mod is None:
        # Fixed: the original message was truncated mid-sentence.
        raise RuntimeError(
            f"{platform} has no module of `{device_type}`, you should register "
            f"a device handle for `{device_type}` before calling fork_rng"
        )
    global _fork_rng_warned_already

    if not enabled:
        # Forking disabled: run the region with the live RNG untouched.
        yield
        return

    if devices is None:
        num_devices = platform.device_count(device_mod)
        if num_devices > 1 and not _fork_rng_warned_already:
            # Fixed: the docstring promises a warning on many-device machines,
            # but the original only set the flag without emitting anything.
            logger.warning(
                "fork_rng is forking the RNG state of all %d %s devices, which "
                "can be slow; pass `devices=` explicitly to suppress this warning.",
                num_devices,
                device_type,
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    # Snapshot CPU + per-device RNG states before entering the region.
    cpu_rng_state = platform.get_rng_state()
    device_rng_states = [platform.get_rng_state(device, device_mod) for device in devices]

    try:
        yield
    finally:
        # Restore every snapshot even if the region raised.
        platform.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            platform.set_rng_state(device_rng_state, device, device_mod)