Coverage for hyper_parallel/core/shard/ops/parallel_embedding.py: 96%

26 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Embedding operator. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

22 

class EmbeddingDistributedOp(DistributedOp):
    """Distributed implementation for Embedding operator."""

    def infer_layout(self, layouts, extra_args):
        """
        Infer output layout for Embedding operator.

        Args:
            layouts (tuple): Layouts of input tensors.
            extra_args (dict): Additional arguments.

        Returns:
            Layout: Layout for output tensor
        """
        # Guard first: everything below dereferences layouts[0] and layouts[1].
        if len(layouts) < 2:
            raise ValueError(f"Embedding requires at least 2 layouts (input, weight), but got {len(layouts)}")

        # layouts[1] is the weight (embedding table) in both the MS and Torch
        # calling conventions: embedding(input, weight, ...).
        weight_info = layouts[1].to_dict()
        weight_map = weight_info["tensor_map"]
        aliases = weight_info["alias_name"]

        def _alias_of(axis_idx):
            # -1 marks a replicated (unsharded) tensor dimension.
            if axis_idx == -1:
                return "None"
            # tensor_map indices count mesh axes from the right; alias_name
            # lists them from the left, hence the reversed lookup.
            return aliases[len(aliases) - axis_idx - 1]

        # Output layout logic:
        # 1. The output inherits the sharding of the input indices
        #    (batch/sequence dimensions) unchanged.
        result_map = list(layouts[0].tensor_map)

        # 2. The trailing output dimension inherits the embedding-dim
        #    sharding from the last axis of the weight table, when present.
        if weight_map:
            result_map.append(weight_map[-1])

        # Rebuild a Layout on the weight's mesh and bind the computed
        # per-dimension aliases to it.
        alias_names = tuple(_alias_of(idx) for idx in result_map)

        base_layout = Layout(
            mesh_shape=weight_info["mesh_shape"],
            alias_name=aliases,
            rank_list=weight_info["rank_list"],
        )

        return base_layout(*alias_names)