Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/dtensor/random.py: 28%
187 statements
# Copyright (c) Meta Platforms, Inc. and affiliates
"""RNG state management for distributed tensor operations.

Provides utilities for tracking and synchronizing random number generator states
across multiple devices in distributed training scenarios.
"""

__all__ = [
    "is_rng_supported_mesh",
    "OffsetBasedRNGTracker",
]

import contextlib
from logging import getLogger
import typing
from typing import Optional
import functools
import operator

from hyper_parallel.core.dtensor.placement_types import Shard
from hyper_parallel.platform import get_platform

platform = get_platform()
DTensorBase = platform.DTensorBase
Tensor = platform.tensor

logger = getLogger(__name__)

_rng_tracker: Optional["_RNGStateTracker"] = None


def is_rng_supported_mesh() -> bool:
    """Check if the device mesh supports DTensor random operations.

    Currently, DTensor random operations are only supported on CUDA and CUDA-like
    devices. Users should call this function before using DTensor random APIs to
    verify compatibility.

    Returns:
        bool: ``True`` if the device mesh supports DTensor random operations,
        ``False`` otherwise.
    """
    device_handle = platform.get_device_handle()
    if device_handle and hasattr(device_handle, "set_rng_state"):
        return True
    return False
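
# Illustrative usage sketch (assumption, not part of this module): callers would
# typically gate DTensor random APIs on this check, e.g.
#
#     if is_rng_supported_mesh():
#         _rng_tracker = OffsetBasedRNGTracker()  # hypothetical wiring
#
# Only the capability check itself is provided here; installing the tracker is
# left to the caller.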


class _PhiloxState:
    """
    Convenience accessor for the packed bits of the Philox state
    (seed: uint64, offset: uint64), which the generator exposes as a
    size-16 uint8 tensor.

    The state is always moved to CPU, since it must be a CPU tensor before it can
    be applied back to a generator.
    """

    def __init__(self, state: Tensor):
        self._state = state.to("cpu")

    @property
    def state(self):
        return self._state

    @property
    def offset(self) -> int:
        return int(self._state[8:].view(dtype=platform.tensor_dtype.int64).item())

    @offset.setter
    def offset(self, offset: int) -> None:
        # pack the new offset into the upper 8 bytes of the CPU state tensor
        offset_tensor = Tensor([offset], dtype=platform.tensor_dtype.uint64).view(
            platform.tensor_dtype.uint8
        )
        self._state[8:] = offset_tensor

    @property
    def seed(self) -> int:
        return int(self._state[:8].view(dtype=platform.tensor_dtype.uint64).item())

    @seed.setter
    def seed(self, seed: int) -> None:
        # pack the new seed into the lower 8 bytes of the CPU state tensor
        seed_tensor = Tensor([seed], dtype=platform.tensor_dtype.uint64).view(
            platform.tensor_dtype.uint8
        )
        self._state[:8] = seed_tensor
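
# Byte-layout note with a usage sketch (illustrative; the generator `g` is an
# assumption): the packed state is 16 uint8 bytes, seed in bytes [0:8] and
# offset in bytes [8:16], interpreted as native-endian uint64.
#
#     state = _PhiloxState(g.get_state())
#     state.offset += 4          # rewrites bytes [8:16] of the CPU copy
#     g.set_state(state.state)   # apply the modified state back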


class _RNGStateTracker:
    """
    Tracks and manages RNG states for DTensor random operations.

    Maintains a mapping from operation tags to RNG state tensors (ByteTensor),
    providing standardized interfaces for state access and modification.

    The core method `_distribute_region` establishes the proper RNG context
    when DTensor executes random operators across distributed devices.
    """

    def __init__(self, device):
        self._device = device
        self._device_handle = platform.get_device_handle()
        if not self._device_handle:
            raise RuntimeError(
                f"{self.__class__.__name__} instantiation requires the presence of "
                "a device handle for the current platform"
            )
        self._use_distribute_region = True

    @property
    def distribute_region_enabled(self) -> bool:
        return self._use_distribute_region

    @distribute_region_enabled.setter
    def distribute_region_enabled(self, value) -> None:
        self._use_distribute_region = value

    def _distribute_region(
        self, device_mesh, placements, global_shape, generator=None
    ):
        pass

    def _manual_seed(self, parallel_seed: int) -> None:
        pass


class OffsetBasedRNGTracker(_RNGStateTracker):
    """
    This subclass of ``_RNGStateTracker`` defines the default policy for how RNG states
    should be shared and synchronized among all ranks to respect the semantics of
    DTensor random operators.
    """

    def __init__(
        self,
        run_state_sync: bool = True,
    ):
        super().__init__(_resolve_device())
        rng_state = self._get_device_state()
        if run_state_sync:
            # synchronize RNG state using rank 0's current one
            platform.broadcast(rng_state, 0)
            my_rng_state = self._get_device_state()
            if not all(my_rng_state == rng_state):
                logger.warning(
                    "DTensor is synchronizing RNG states of every rank with the state from rank 0. "
                    "This behavior is deprecated. "
                    "Please call `manual_seed()` on every rank that participates in SPMD DTensor operations with "
                    "the same seed. If using pipeline parallelism, each pipeline stage would use a different seed, "
                    "but all ranks belonging to one pipeline stage would use the same seed."
                )
        self._set_device_state(rng_state)

    def _get_device_state(self):
        rng_state = self._device_handle.get_rng_state().to(self._device)
        return rng_state

    def _set_device_state(self, state: Tensor):
        # The underlying generator expects a CPU tensor, while `_get_device_state`
        # returns a device tensor (likely so it can be used with backend
        # communication for sync/debug). Convert back to CPU here so that setting
        # the state always works.
        self._device_handle.set_rng_state(state.to("cpu"))

    @contextlib.contextmanager
    def _distribute_region(
        self, device_mesh, placements, global_shape, generator=None
    ):
        # regular (non-LocalTensor) mode
        if generator is not None:
            # For a user-passed generator, wrap its state in a _PhiloxState so it can
            # be manipulated through the same interface as the default device state.
            # The advanced state is written back to the generator at the end of the
            # region (see below).
            state = _PhiloxState(generator.get_state())
        else:
            state = _PhiloxState(self._get_device_state())

        if self.distribute_region_enabled:
            old_offset = state.offset
            self._set_pre_op_offset(state, device_mesh, placements, global_shape)
            with fork_rng(
                devices=[self._device], device_type=platform.device_type()
            ):
                self._device_handle.set_rng_state(state.state)
                try:
                    yield  # execute the region code
                finally:
                    # update offset to synchronize among ranks
                    self._set_post_op_offset(state, global_shape, old_offset)
        else:
            yield

        if generator is not None:
            # Propagate the state advancement back to the user's RNG so it is visible
            # to, and affects, any future use of that RNG (DTensor or not), and so a
            # later change to its seed is picked up the next time it is used with DTensor.
            generator.set_state(state.state)
        else:
            self._set_device_state(state.state)
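
    # Sketch of how a random op is expected to flow through the region above
    # (illustrative; the dispatch call site lives outside this file):
    #
    #     with _rng_tracker._distribute_region(mesh, placements, global_shape):
    #         local_tensor.uniform_()  # each rank draws its own slice of the stream
    #
    # Inside the region the offset is first advanced by shard_linear_idx * local_numel
    # (pre-op), and on exit it is set to old_offset + global_numel (post-op), so all
    # ranks leave the region with the same offset.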

    def compute_offset_incr(self, device_mesh, placements, global_shape) -> int:
        """Compute the per-shard RNG offset increment for the current rank.

        Based on the shard linear index and the local shard size, computes how much to
        advance the offset so that each shard consumes a unique portion of the random
        stream.

        Args:
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
            global_shape: The global shape of the distributed tensor.

        Returns:
            int: The offset increment, rounded up to a multiple of 4.
        """
        mesh_coordinate = device_mesh.get_coordinate()
        shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info(
            mesh_coordinate, device_mesh, placements
        )
        shard_linear_idx = self._calc_shard_linear_idx(
            shard_idx_by_dim, total_num_shards_by_dim
        )
        local_size_on_rank_0 = _calc_first_shard_size(device_mesh, placements, global_shape)
        local_size = functools.reduce(operator.mul, local_size_on_rank_0, 1)
        return (shard_linear_idx * local_size + 3) // 4 * 4
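
    # Worked example (illustrative numbers): for a global shape of [8, 16] sharded
    # over a 2-rank mesh with [Shard(0)], rank 0's local shard is [4, 16], so
    # local_size = 64. Rank 0 has shard_linear_idx 0 and rank 1 has 1, giving
    # increments of (0 * 64 + 3) // 4 * 4 = 0 and (1 * 64 + 3) // 4 * 4 = 64.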

    def _set_pre_op_offset(self, state: _PhiloxState, device_mesh, placements, global_shape) -> None:
        """Set the starting RNG offset for the local shard on the current process
        before the random op executes. Starting from the current accumulated offset,
        each rank advances by shard_linear_index * local_shard_numel, so the shards
        together cover the total number of elements of the global distributed tensor.
        Processes holding replicas of the same shard share an identical starting offset.

        Args:
            state (_PhiloxState): The generator state to modify.
            device_mesh (DeviceMesh): The device mesh describing the device topology.
            placements (Sequence[Placement]): The placement strategy for each mesh dimension.
                Each element should be a Placement object (Shard, Replicate, Partial, etc.).
            global_shape: The global shape of the distributed tensor.

        Returns:
            None

        .. warning::
            The current implementation does not consider memory layout contiguity.

        Example:
            Take a DTensor of shape [8, 16] as an example. Assume that the DTensor
            is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]),
            and the mesh is:
                [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
            ``mesh.get_coordinate()`` provides the coordinate of the current rank
            in the mesh. For example, the coordinate of rank 5 is (1, 0, 1).

            Besides the rank coordinate, we also need the shard coordinate.
            Each rank holds a local shard of the DTensor. In the example, the DTensor
            is partitioned into 4 [4, 8] shards. The first shard has 2 replicas:
            rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) hold 1 replica each.
            That is, the local shards on rank 0 and rank 2 correspond to the same
            shard of the DTensor. To denote each DTensor shard, we use a shard coordinate
            (in the example, it is a tuple (i, j) where shard (i, j) has the slice
            DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2).

            Once we have the rank coordinate and the shard coordinate, we can calculate
            on each rank which shard of the DTensor the rank holds, with the help of
            dim_map. The dim_map of the above DTensor is [2, 0], so the shard coordinate
            of a rank with rank coord (x, y, z) is simply (z, x), obtained by taking
            (rank_coord[dim_map[0]], rank_coord[dim_map[1]]).
            Following this calculation,
            rank 0 and rank 2 hold the shard of coord (0, 0);
            rank 1 and rank 3 hold the shard of coord (0, 1);
            rank 4 and rank 6 hold the shard of coord (1, 0);
            rank 5 and rank 7 hold the shard of coord (1, 1).

            The last value to calculate before obtaining the starting offset is the
            shard linear index. The starting offset for each rank is then
            shard_linear_index * local_tensor_numel.
        """
        current_offset = state.offset
        offset_incr = self.compute_offset_incr(device_mesh, placements, global_shape)
        state.offset = current_offset + offset_incr
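
    # Continuing the docstring example (illustrative): each [4, 8] shard has numel 32,
    # so with a starting offset of 0 the pre-op offsets are 0, 32, 64 and 96 for shard
    # linear indices 0, 1, 2 and 3; replicas of the same shard (e.g. rank 0 and rank 2)
    # get the same value.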

    def _set_post_op_offset(
        self, state: _PhiloxState, global_shape, old_offset: int
    ) -> None:
        """Set the RNG to a synchronized state after running the local random op.

        Restores the random number generator to a globally consistent state after the
        local shard's execution. Every process advances its offset by the total element
        count of the distributed tensor (rounded up to a multiple of 4), measured from
        the offset value recorded before the operation began.

        Args:
            state (_PhiloxState): The generator state to modify.
            global_shape: The global shape of the distributed tensor.
            old_offset (int): The RNG offset before the operation.

        Returns:
            None
        """
        numel = functools.reduce(operator.mul, global_shape, 1)
        # round the global numel up to a multiple of 4 to keep the offset aligned
        numel = (numel + 3) // 4 * 4
        state.offset = old_offset + numel
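
    # Worked example (illustrative): for global_shape = [8, 16], numel is 128, already
    # a multiple of 4, so every rank leaves the op with offset = old_offset + 128
    # regardless of which shard it produced.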

    def _calc_shard_linear_idx(
        self, shard_coord: list[int], shard_size: list[int]
    ) -> int:
        return _calc_shard_linear_idx(shard_coord, shard_size)


def _calc_first_shard_size(device_mesh, placements, global_shape) -> list[int]:
    """Calculate the size of the first shard, i.e. the local shard on rank 0.

    Args:
        device_mesh: The device mesh describing the device topology.
        placements: Sequence of Placement objects (Shard, Replicate, etc.).
        global_shape: The global shape of the distributed tensor.

    Returns:
        list[int]: Shape of rank 0's local shard.
    """
    local_size_on_rank_0 = list(global_shape)
    for idx, placement in enumerate(placements):
        if isinstance(placement, Shard):
            mesh_dim_size = device_mesh.size(idx)
            shard_dim = placement.dim
            local_size_on_rank_0[shard_dim], _ = local_shard_size_and_offset(
                global_shape[shard_dim],
                mesh_dim_size,
                0,
            )
    return local_size_on_rank_0
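
# Worked example (illustrative): for global_shape = [8, 16] on a 2 x 2 mesh with
# placements [Shard(0), Shard(1)], rank 0's shard shape works out as
#
#     _calc_first_shard_size(mesh, [Shard(0), Shard(1)], [8, 16])  # -> [4, 8]
#
# where `mesh` stands for any 2 x 2 DeviceMesh; only mesh.size(idx) is consulted here.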


def _calc_shard_info(
    mesh_coordinate, device_mesh, placements
):
    """Calculate shard information for a specific rank."""
    mesh_size = device_mesh.mesh_shape
    # note: dim_map does not allow double sharding, which is the FSDP(fully_shard)+TP
    # case. Replace the custom logic with dim_map once we support it.
    dim_map = [-1] * device_mesh.ndim
    for i, placement in enumerate(placements):
        if isinstance(placement, Shard):
            shard_dim = placement.dim
            if dim_map[shard_dim] == -1:
                dim_map[shard_dim] = [i]
            else:
                mesh_dim_list = dim_map[shard_dim]
                if not isinstance(mesh_dim_list, list):
                    raise TypeError(f"Expected mesh_dim_list to be a list, got {type(mesh_dim_list)}")
                mesh_dim_list.append(i)

    # Compute shard coordinate:
    # The coordinate on each tensor dim is a tuple (idx, range).
    # If a DTensor is partitioned on its dim i into n shards, and the current rank
    # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i.
    if mesh_coordinate is None:
        raise ValueError("mesh_coordinate must not be None")
    shard_idx_by_dim = []
    total_num_shards_by_dim = []  # total number of shards on each tensor dim
    for mesh_dim in dim_map:
        shard_idx = 0
        total_num_shards = 1
        # the tensor dim is sharded on one or more mesh dims
        if isinstance(mesh_dim, list):
            rank_coord = [mesh_coordinate[d] for d in mesh_dim]
            num_shards = [mesh_size[d] for d in mesh_dim]
            # compute the shard idx and total number of shards
            for idx, size in zip(rank_coord, num_shards):
                shard_idx = shard_idx * size + idx
                total_num_shards *= size
        shard_idx_by_dim.append(shard_idx)
        total_num_shards_by_dim.append(total_num_shards)
    return shard_idx_by_dim, total_num_shards_by_dim
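
# Worked example (illustrative): on the 2 x 2 x 2 mesh from the `_set_pre_op_offset`
# docstring, with placements [Shard(1), Replicate(), Shard(0)], dim_map is built as
# [[2], [0], -1]; rank 5 (mesh coordinate (1, 0, 1)) then gets
# shard_idx_by_dim = [1, 1, 0] and total_num_shards_by_dim = [2, 2, 1].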


def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int:
    # compute shard linear index
    shard_linear_idx = 0
    shard_coord_stride = 1
    for idx, size in zip(reversed(shard_coord), reversed(shard_size)):
        shard_linear_idx += idx * shard_coord_stride
        shard_coord_stride *= size

    return shard_linear_idx
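
# Worked example (illustrative): continuing the case above, shard coordinates are
# enumerated in row-major order, so
#
#     _calc_shard_linear_idx([1, 1, 0], [2, 2, 1])  # -> 1 * 2 + 1 * 1 + 0 = 3
#
# i.e. rank 5's shard (coord (1, 1)) is the last of the four [4, 8] shards.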


def _resolve_device():
    device_handle = platform.get_device_handle()
    device_idx = platform.get_rank() % platform.device_count(device_handle)
    return platform.device(device_idx)


def local_shard_size_and_offset(
    curr_local_size: int,
    num_chunks: int,
    rank,
):
    """
    Given the size of the current local tensor (which may already be sharded on some
    dimensions), computes the new local shard size and offset for the desired number
    of chunks (num_chunks is generally equal to the size of the mesh dimension being
    sharded over).

    Note: the new local shard offset is relative to the current sharded tensor, not
    the global tensor. See `_utils.compute_local_shape_and_global_offset` for
    computing the global offset.

    Returns (new local shard size, offset)
    """
    # Compute the chunk size inline
    if curr_local_size % num_chunks == 0:
        full_chunk_size = curr_local_size // num_chunks
        shard_starting_idx = full_chunk_size * rank
        return full_chunk_size, shard_starting_idx

    # uneven sharding case
    full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks
    shard_starting_idx = full_chunk_size * rank

    if curr_local_size < shard_starting_idx:
        return 0, typing.cast(int, curr_local_size)
    local_shard_size = (
        min(curr_local_size, shard_starting_idx + full_chunk_size)
        - shard_starting_idx
    )
    return local_shard_size, shard_starting_idx
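
# Worked examples (illustrative):
#
#     local_shard_size_and_offset(16, 2, 1)  # even case   -> (8, 8)
#     local_shard_size_and_offset(10, 4, 2)  # uneven case -> (3, 6)
#     local_shard_size_and_offset(10, 4, 3)  # uneven case -> (1, 9)
#
# In the uneven case the chunk size is ceil(10 / 4) = 3, so rank 3 starts at index 9
# and only one element remains.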


_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    device_type="npu",
):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork the RNG.
            The CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has many devices,
            since this function runs very slowly in that case. If you explicitly
            specify devices, this warning is suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `npu`. For supported devices,
            see details in :ref:`accelerator<accelerators>`.
    """
    device_mod = platform.get_device_handle()
    if device_mod is None:
        raise RuntimeError(
            f"{platform} has no module of `{device_type}`, you should register "
            "the device module first"
        )
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = platform.device_count(device_mod)
        if num_devices > 1 and not _fork_rng_warned_already:
            # emit the warning described in the docstring above
            logger.warning(
                "fork_rng is forking the RNG state of all %d visible devices, which "
                "may be slow; pass `devices` explicitly to suppress this warning.",
                num_devices,
            )
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against the user passing us a generator; we need to traverse this
        # multiple times, but a generator would be exhausted upon the first traversal
        devices = list(devices)

    cpu_rng_state = platform.get_rng_state()
    device_rng_states = [platform.get_rng_state(device, device_mod) for device in devices]

    try:
        yield
    finally:
        platform.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            platform.set_rng_state(device_rng_state, device, device_mod)
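
# Usage sketch (illustrative; the seeding call is an assumption about the platform API):
#
#     with fork_rng(devices=[0], device_type=platform.device_type()):
#         platform.manual_seed(0)   # hypothetical seeding call
#         ...                       # draw temporary random numbers
#     # on exit, both the CPU and device-0 RNG states are restored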