Coverage for  / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_multinomial.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Multinomial operator.
"""

18 

19from hyper_parallel.core.dtensor.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class MultinomialDistributedOp(DistributedOp):
    """Distributed implementation for Multinomial operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for Multinomial operator.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (tuple): Arguments for the operator
                (num_samples, replacement, generator). Note: logic assumes
                num_samples affects shape but not the sharding map pattern.

        Returns:
            Layout: Layout for output tensor.

        Raises:
            ValueError: If the input is not 1-D or 2-D, or if the last
                (probability category) dimension is sharded.
        """
        # Reject partial (pending-reduce) inputs unless this op opts in.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        layout = layouts[0]
        in_tensor_map = layout.alias_tensor_map
        ndim = len(in_tensor_map)

        # PyTorch multinomial supports 1D or 2D tensors
        if ndim not in (1, 2):
            raise ValueError(f"For 'multinomial', input dimension must be 1 or 2, but got {ndim}.")

        # The last dimension (probability dimension) must NOT be sharded.
        # Multinomial sampling requires the full probability distribution
        # (cumulative sum, normalization) to be present on the device. If it's
        # sharded, we cannot sample correctly without communication.
        # In alias_tensor_map, "None" indicates the dimension is NOT sharded.
        if in_tensor_map[-1] != "None":
            raise ValueError(
                f"For 'multinomial', the last dimension (probability category dimension) "
                f"must not be sharded (Replicated). Got layout map: {in_tensor_map}. "
                f"Please redistribute the tensor to replicate the last dimension before calling multinomial."
            )

        if ndim == 1:
            # Input (C,) -> Output (num_samples,). Since the input
            # distribution C is replicated (checked above), the sampled
            # indices have no spatial correlation to shards, so the output
            # is also Replicated ("None").
            out_tensor_map = ("None",)
        else:
            # Input (N, C) -> Output (N, num_samples).
            # 1. Preserve sharding on the batch dimension (dim 0): if the
            #    input is data parallel, the output is too.
            # 2. The new num_samples dimension is created locally and is
            #    not sharded ("None").
            out_tensor_map = (in_tensor_map[0], "None")

        # Rebuild a layout on the same mesh, then bind the inferred
        # tensor-map aliases to produce the output layout.
        output_layout = Layout(
            mesh_shape=layout.mesh_shape,
            alias_name=layout.alias_name,
            rank_list=layout.rank_list
        )
        return output_layout(*out_tensor_map)