Coverage for hyper_parallel / core / shard / ops / parallel_embedding.py: 96%
26 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Embedding operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class EmbeddingDistributedOp(DistributedOp):
    """Distributed implementation for Embedding operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Embedding operator.

        Args:
            layouts (tuple): Layouts of input tensors.
            extra_args (dict): Additional arguments.

        Returns:
            Layout: Layout for output tensor
        """
        # Guard first: everything below indexes layouts[0] and layouts[1].
        if len(layouts) < 2:
            raise ValueError(f"Embedding requires at least 2 layouts (input, weight), but got {len(layouts)}")

        # layouts[1] is the weight (embedding table) in both the MS and the
        # Torch calling conventions: embedding(input, weight).
        weight_info = layouts[1].to_dict()
        weight_map = weight_info["tensor_map"]
        aliases = weight_info["alias_name"]

        def resolve_alias(dim_idx):
            # -1 denotes an unsharded dimension; any other index selects an
            # alias name counting from the tail of the alias list.
            if dim_idx == -1:
                return "None"
            return aliases[len(aliases) - dim_idx - 1]

        # The output inherits the sharding of the index tensor
        # (batch / sequence dimensions) ...
        result_map = () + layouts[0].tensor_map

        # ... followed by the embedding-dimension sharding taken from the
        # last dimension of the weight table, when one exists.
        if len(weight_map) > 0:
            result_map += (weight_map[-1],)

        # Rebuild a Layout on the same mesh / alias / rank configuration as
        # the weight, then apply the resolved alias names per output dim.
        alias_args = tuple(resolve_alias(dim_idx) for dim_idx in result_map)
        base_layout = Layout(
            mesh_shape=weight_info["mesh_shape"],
            alias_name=aliases,
            rank_list=weight_info["rank_list"],
        )
        return base_layout(*alias_args)