Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_embedding.py: 65%

77 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the embedding operator.
This module handles Tensor Parallelism for Embedding layers (Column/Row Parallel).
"""


from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class EmbeddingDistributedOp(DistributedOp):
    """
    Distributed implementation for embedding operators.
    Supports Column Parallelism (CP) and Row Parallelism (RP).
    """

    def infer_layout(self, layouts, extra_args=None):
        """
        Infers the output layout (sharding state) based on input and weight layouts.
        """
        if not layouts or len(layouts) < 2:
            raise ValueError(
                f"Operation {self.op_name}: requires both input and weight layouts."
            )

        input_layout = layouts[0]
        weight_layout = layouts[1]
        weight_tensor_map = weight_layout.tensor_map

        # weight_tensor_map: [vocab_size_dim, embed_dim_dim]
        w_shard_vocab_axis = weight_tensor_map[0]
        w_shard_embed_axis = weight_tensor_map[1]

        # Output shape is [*input_shape, embed_dim]
        output_tensor_map = list(input_layout.tensor_map)
        output_tensor_map.append(w_shard_embed_axis)
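        # Illustrative (values assumed): with an unsharded 2-D input map (-1, -1) and the
        # embedding dim mapped to mesh axis 0, the output map becomes (-1, -1, 0), i.e. the
        # lookup result is sharded only along its last dimension.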


        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        output_layout.set_tensor_map(tuple(output_tensor_map))

        # If vocab is sharded (Row Parallelism), the output is in a Partial Sum state
        if w_shard_vocab_axis != -1:
            # pylint: disable=protected-access
            output_layout._partial = list(input_layout.partial)
            vocab_axis_name = input_layout.alias_name[len(input_layout.alias_name) - 1 - w_shard_vocab_axis]
            output_layout.set_partial_by_dev_axis(vocab_axis_name, 'sum')
            # pylint: disable=protected-access
            output_layout._alias_tensor_map = output_layout._build_readable_tensor_map()
            # pylint: disable=protected-access
            output_layout.tensor_map_to_placement()
            output_layout.update_compact_str()
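
        # Note (behavior implied by the RP handling below): with the vocab dim sharded, each
        # rank only produces embeddings for the token ids it owns and zeros elsewhere, so the
        # output is marked as a partial sum ('sum') over the vocab mesh axis; the framework is
        # then expected to reduce it across that axis to recover the full result.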


        return output_layout

    def _parse_params(self, args, kwargs):
        """
        Extracts padding_idx, max_norm, and scale_grad from args or kwargs.
        F.embedding signature: (input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq...)
        """
        padding_idx = args[2] if len(args) > 2 else kwargs.get('padding_idx', None)
        max_norm = args[3] if len(args) > 3 else kwargs.get('max_norm', None)
        scale_grad = args[5] if len(args) > 5 else kwargs.get('scale_grad_by_freq', False)
        return padding_idx, max_norm, scale_grad

    def _handle_rp_input(self, input_tensor, weight_tensor, weight_layout, w_shard_vocab_axis,
                         new_args, kwargs, is_args_pad, padding_idx):
        """
        Processes Row Parallelism input: shifts indices, handles padding, and generates masks.
        """
        mesh = weight_layout.mesh
        mesh_dim_idx = len(mesh.mesh_shape) - 1 - w_shard_vocab_axis
        vocab_coord = mesh.get_local_rank(mesh_dim_idx)

        vocab_size_per_partition = weight_tensor.shape[0]
        vocab_start_index = int(vocab_coord * vocab_size_per_partition)
        vocab_end_index = int(vocab_start_index + vocab_size_per_partition)
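        # Illustrative numbers (assumed): a 32000-row vocab sharded over 4 ranks gives
        # vocab_size_per_partition = 8000, so rank coordinate 2 owns global ids [16000, 24000).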


        # Map the global padding_idx into this rank's local vocabulary range
        if padding_idx is not None:
            if vocab_start_index <= padding_idx < vocab_end_index:
                mapped_padding_idx = int(padding_idx - vocab_start_index)
                if is_args_pad:
                    new_args[2] = mapped_padding_idx
                else:
                    kwargs['padding_idx'] = mapped_padding_idx
            else:
                if is_args_pad:
                    new_args[2] = None
                else:
                    kwargs.pop('padding_idx', None)

        # Mask marking indices that fall inside this rank's vocabulary partition
        mask = (input_tensor >= vocab_start_index) & (input_tensor < vocab_end_index)

        # Cross-platform cast to a matching integer dtype
        mask_int = mask.to(input_tensor.dtype) if hasattr(mask, "to") else mask.astype(input_tensor.dtype)

        # Shift global indices into the local range using native scalar broadcast
        local_input = input_tensor - vocab_start_index

        # Zero out invalid indices arithmetically instead of using .where() or clamp().
        # This prevents NPU out-of-bounds memory access during the embedding lookup
        # while keeping the code backend-neutral.
        local_input = local_input * mask_int
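
        # Worked example (values assumed): on a rank owning ids [8000, 16000), token id 5
        # gives mask_int = 0 and local index (5 - 8000) * 0 = 0, a safe in-range row whose
        # embedding is later zeroed by the mask; token id 9000 gives mask_int = 1 and
        # local index 1000, the correct row of this rank's partition.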


        return local_input, mask_int

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Returns the execution implementation wrapper.
        Helper functions are used to keep Cyclomatic Complexity (CCN) low.
        """
        weight_layout = layouts[1]
        w_shard_vocab_axis = weight_layout.tensor_map[0]
        w_shard_embed_axis = weight_layout.tensor_map[1]

        # Use native implementation if no weight sharding is applied
        if w_shard_vocab_axis == -1 and w_shard_embed_axis == -1:
            return None

        def distributed_embedding_impl(*args, **kwargs):
            input_tensor, weight_tensor = args[0], args[1]
            new_args, new_kwargs = list(args), kwargs.copy()

            # 1. Parameter extraction and validation
            padding_idx, max_norm, scale_grad = self._parse_params(args, kwargs)

            # Check for max_norm with specific error messages for CP and RP
            if max_norm is not None:
                if w_shard_embed_axis != -1:
                    raise ValueError("Column-Parallel Embedding does not support `max_norm` parameter.")
                if w_shard_vocab_axis != -1:
                    raise ValueError("Row-Parallel Embedding does not support `max_norm` parameter.")

            # Check for scale_grad_by_freq with RP
            if scale_grad and w_shard_vocab_axis != -1:
                raise ValueError("Row-Parallel Embedding does not support `scale_grad_by_freq=True`.")

            # 2. Row Parallel Processing
            input_mask_int = None
            if w_shard_vocab_axis != -1:
                is_args_pad = len(args) > 2
                mapped_input, input_mask_int = self._handle_rp_input(
                    input_tensor, weight_tensor, weight_layout, w_shard_vocab_axis,
                    new_args, new_kwargs, is_args_pad, padding_idx
                )
                new_args[0] = mapped_input

            # 3. Native Operator Execution
            output = func(*new_args, **new_kwargs)

            # 4. Erase invalid partial embeddings (Row-Parallel only)
            if w_shard_vocab_axis != -1 and input_mask_int is not None:
                expanded_mask = input_mask_int[..., None].to(output.dtype)
                output = output * expanded_mask
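
            # With RP, rows for token ids this rank does not own are zeroed here; the
            # partial-sum layout produced by infer_layout signals that the full embedding is
            # recovered by summing these outputs across the vocab mesh axis (the reduction
            # itself is assumed to be handled by the surrounding framework, not this op).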


            return output

        return distributed_embedding_impl