Coverage for hyper_parallel / core / shard / ops / parallel_multinomial.py: 95%
19 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Multinomial operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class MultinomialDistributedOp(DistributedOp):
    """Distributed implementation for Multinomial operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Multinomial operator.

        Args:
            layouts (tuple): Layouts of input tensor.
            extra_args (tuple): Arguments for the operator (num_samples, replacement, generator).
                Note: logic assumes num_samples affects shape but not the sharding map pattern.

        Returns:
            Layout: Layout for output tensor.

        Raises:
            ValueError: If the input is not 1D/2D, or its last dimension is sharded.
        """
        input_layout = layouts[0]
        tensor_map = input_layout.alias_tensor_map
        ndim = len(tensor_map)

        # PyTorch multinomial only accepts 1D or 2D probability tensors.
        if ndim not in (1, 2):
            raise ValueError(f"For 'multinomial', input dimension must be 1 or 2, but got {ndim}.")

        # Sampling needs the entire probability distribution (normalization and
        # cumulative sums) resident on each device, so the category dimension —
        # the last one — must be replicated. In alias_tensor_map, "None" marks
        # an unsharded dimension; anything else means it is sharded.
        in_tensor_map = tensor_map
        if tensor_map[-1] != "None":
            raise ValueError(
                f"For 'multinomial', the last dimension (probability category dimension) "
                f"must not be sharded (Replicated). Got layout map: {in_tensor_map}. "
                f"Please redistribute the tensor to replicate the last dimension before calling multinomial."
            )

        # Output sharding rule:
        #   1D input (C,)    -> output (num_samples,): fully replicated, since the
        #      source distribution is replicated and samples carry no shard identity.
        #   2D input (N, C)  -> output (N, num_samples): keep the batch-dim sharding
        #      (data parallelism survives) and leave the locally-generated sample
        #      dimension unsharded.
        if ndim == 1:
            out_map = ("None",)
        else:
            out_map = (tensor_map[0], "None")

        # Rebuild a layout over the same device mesh, then bind the inferred
        # tensor-map aliases to produce the output layout.
        result_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        return result_layout(*out_map)