Coverage for hyper_parallel / core / shard / ops / parallel_multinomial.py: 95%

19 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Multinomial operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class MultinomialDistributedOp(DistributedOp):
    """Distributed implementation for Multinomial operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer the output layout for the Multinomial operator.

        Args:
            layouts (tuple): Layouts of the input tensor.
            extra_args (tuple): Arguments for the operator
                (num_samples, replacement, generator). Only the output shape
                depends on num_samples; the sharding pattern does not.

        Returns:
            Layout: Layout for the output tensor.

        Raises:
            ValueError: If the input rank is not 1 or 2, or if the last
                (probability category) dimension is sharded.
        """
        src_layout = layouts[0]
        src_map = src_layout.alias_tensor_map
        rank = len(src_map)

        # PyTorch multinomial accepts only 1D or 2D probability tensors.
        if rank not in (1, 2):
            raise ValueError(f"For 'multinomial', input dimension must be 1 or 2, but got {rank}.")

        # Sampling needs the complete distribution (cumulative sum,
        # normalization) resident on each device, so the category dimension
        # must be replicated. In alias_tensor_map, "None" marks a dimension
        # that is NOT sharded; anything else means it is sharded.
        if src_map[-1] != "None":
            raise ValueError(
                f"For 'multinomial', the last dimension (probability category dimension) "
                f"must not be sharded (Replicated). Got layout map: {src_map}. "
                f"Please redistribute the tensor to replicate the last dimension before calling multinomial."
            )

        # 1D input (C,)   -> output (num_samples,): the replicated input yields
        #                    replicated samples, so the output map is ("None",).
        # 2D input (N, C) -> output (N, num_samples): preserve any sharding on
        #                    the batch dimension; the new sample dimension is
        #                    generated locally and stays unsharded.
        result_map = ("None",) if rank == 1 else (src_map[0], "None")

        # Rebuild a layout on the same device mesh, then apply the inferred
        # tensor-map aliases to produce the output layout.
        fresh_layout = Layout(
            mesh_shape=src_layout.mesh_shape,
            alias_name=src_layout.alias_name,
            rank_list=src_layout.rank_list,
        )
        return fresh_layout(*result_map)