# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for the embedding operator.
This module handles Tensor Parallelism for Embedding layers (Column/Row Parallel).
"""
from hyper_parallel.core.dtensor.layout import Layout

from .parallel_ops import DistributedOp
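
# Sharding cheat-sheet (illustrative, derived from how the code below reads the
# weight layout): the weight's tensor_map is [vocab_axis, embed_axis], where -1
# means "replicated" and a non-negative value names a mesh dim counted from the
# right. On a 1-D mesh ("tp",), tensor_map (0, -1) is Row Parallel (vocab
# sharded on "tp") and (-1, 0) is Column Parallel (embed dim sharded on "tp").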


class EmbeddingDistributedOp(DistributedOp):
    """
    Distributed implementation for embedding operators.
    Supports Column Parallelism (CP) and Row Parallelism (RP).
    """

    def infer_layout(self, layouts, extra_args=None):
        """
        Infers the output layout (sharding state) based on input and weight layouts.
        """
        if not layouts or len(layouts) < 2:
            raise ValueError(
                f"Operation {self.op_name}: requires both input and weight layouts."
            )

        input_layout = layouts[0]
        weight_layout = layouts[1]
        weight_tensor_map = weight_layout.tensor_map

        # weight_tensor_map: [vocab_size_dim, embed_dim_dim]
        w_shard_vocab_axis = weight_tensor_map[0]
        w_shard_embed_axis = weight_tensor_map[1]

        # Output shape is [*input_shape, embed_dim]
        output_tensor_map = list(input_layout.tensor_map)
        output_tensor_map.append(w_shard_embed_axis)

        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        output_layout.set_tensor_map(tuple(output_tensor_map))

        # If vocab is sharded (Row Parallelism), output is in Partial Sum state
        if w_shard_vocab_axis != -1:
            # pylint: disable=protected-access
            output_layout._partial = list(input_layout.partial)
            alias_name = input_layout.alias_name
            vocab_axis_name = alias_name[len(alias_name) - 1 - w_shard_vocab_axis]
            output_layout.set_partial_by_dev_axis(vocab_axis_name, 'sum')

        # pylint: disable=protected-access
        output_layout._alias_tensor_map = output_layout._build_readable_tensor_map()
        output_layout.tensor_map_to_placement()
        output_layout.update_compact_str()

        return output_layout
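
    # Worked example (illustrative values): on a 2-D mesh with
    # alias_name = ("dp", "tp"), an input tensor_map of (1, -1) (batch dim
    # sharded on "dp") and a weight tensor_map of (0, -1) (vocab sharded on
    # "tp") infer an output tensor_map of (1, -1, -1) marked partial 'sum' on
    # "tp"; a weight tensor_map of (-1, 0) instead infers (1, -1, 0) with no
    # partial state, since only the embed dim is sharded.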

    def _parse_params(self, args, kwargs):
        """
        Extracts padding_idx, max_norm, and scale_grad from args or kwargs.
        F.embedding signature: (input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, ...)
        """
        padding_idx = args[2] if len(args) > 2 else kwargs.get('padding_idx', None)
        max_norm = args[3] if len(args) > 3 else kwargs.get('max_norm', None)
        scale_grad = args[5] if len(args) > 5 else kwargs.get('scale_grad_by_freq', False)
        return padding_idx, max_norm, scale_grad
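
    # Worked example (illustrative): a positional call such as
    #     F.embedding(ids, weight, 0, None, 2.0, True)
    # parses to (padding_idx=0, max_norm=None, scale_grad=True), while the
    # keyword form F.embedding(ids, weight, padding_idx=0) parses to
    # (0, None, False) because the missing keywords fall back to defaults.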

    def _handle_rp_input(self, input_tensor, weight_tensor, weight_layout, w_shard_vocab_axis,
                         new_args, kwargs, is_args_pad, padding_idx):
        """
        Processes Row Parallelism input: shifts indices, handles padding, and generates masks.
        """
        mesh = weight_layout.mesh
        mesh_dim_idx = len(mesh.mesh_shape) - 1 - w_shard_vocab_axis
        vocab_coord = mesh.get_local_rank(mesh_dim_idx)

        vocab_size_per_partition = weight_tensor.shape[0]
        vocab_start_index = int(vocab_coord * vocab_size_per_partition)
        vocab_end_index = int(vocab_start_index + vocab_size_per_partition)

        # Map the global padding_idx into this rank's local vocab range
        if padding_idx is not None:
            if vocab_start_index <= padding_idx < vocab_end_index:
                mapped_padding_idx = int(padding_idx - vocab_start_index)
                if is_args_pad:
                    new_args[2] = mapped_padding_idx
                else:
                    kwargs['padding_idx'] = mapped_padding_idx
            else:
                if is_args_pad:
                    new_args[2] = None
                else:
                    kwargs.pop('padding_idx', None)

        # In-bounds mask: True where the index falls inside this rank's vocab range
        mask = (input_tensor >= vocab_start_index) & (input_tensor < vocab_end_index)

        # Cross-platform cast to the matching int dtype
        mask_int = mask.to(input_tensor.dtype) if hasattr(mask, "to") else mask.astype(input_tensor.dtype)

        # Shift global indices into the local range using native scalar broadcast
        local_input = input_tensor - vocab_start_index

        # Zero out invalid indices arithmetically instead of using .where() or clamp().
        # This keeps every lookup index in range, preventing NPU out-of-bounds memory
        # access during the embedding lookup while staying backend-neutral.
        local_input = local_input * mask_int

        return local_input, mask_int
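
    # Worked example (illustrative values): with a vocab of 8 split over 2
    # ranks, rank 1 owns rows [4, 8). For input [1, 5, 7]:
    #     mask_int    = [0, 1, 1]
    #     local_input = [(1 - 4) * 0, (5 - 4) * 1, (7 - 4) * 1] = [0, 1, 3]
    # The out-of-range id 1 collapses to a safe row-0 lookup whose result is
    # later zeroed by the mask multiply in distributed_embedding_impl.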

    def get_expand_impl(self, func, infer_result, layouts, extra_args=None):
        """
        Returns the execution implementation wrapper.
        Helper functions are used to keep Cyclomatic Complexity (CCN) low.
        """
        weight_layout = layouts[1]
        w_shard_vocab_axis = weight_layout.tensor_map[0]
        w_shard_embed_axis = weight_layout.tensor_map[1]

        # Use the native implementation if no weight sharding is applied
        if w_shard_vocab_axis == -1 and w_shard_embed_axis == -1:
            return None

        def distributed_embedding_impl(*args, **kwargs):
            input_tensor, weight_tensor = args[0], args[1]
            new_args, new_kwargs = list(args), kwargs.copy()

            # 1. Parameter extraction and validation
            padding_idx, max_norm, scale_grad = self._parse_params(args, kwargs)

            # Reject max_norm, with specific error messages for CP and RP
            if max_norm is not None:
                if w_shard_embed_axis != -1:
                    raise ValueError("Column-Parallel Embedding does not support `max_norm` parameter.")
                if w_shard_vocab_axis != -1:
                    raise ValueError("Row-Parallel Embedding does not support `max_norm` parameter.")

            # Reject scale_grad_by_freq under RP
            if scale_grad and w_shard_vocab_axis != -1:
                raise ValueError("Row-Parallel Embedding does not support `scale_grad_by_freq=True`.")

            # 2. Row Parallel processing
            input_mask_int = None
            if w_shard_vocab_axis != -1:
                is_args_pad = len(args) > 2
                mapped_input, input_mask_int = self._handle_rp_input(
                    input_tensor, weight_tensor, weight_layout, w_shard_vocab_axis,
                    new_args, new_kwargs, is_args_pad, padding_idx
                )
                new_args[0] = mapped_input

            # 3. Native operator execution
            output = func(*new_args, **new_kwargs)

            # 4. Erase invalid partial embeddings (Row-Parallel only)
            if w_shard_vocab_axis != -1 and input_mask_int is not None:
                expanded_mask = input_mask_int[..., None]
                # Cross-platform cast, mirroring the hasattr hedge in _handle_rp_input
                expanded_mask = (expanded_mask.to(output.dtype) if hasattr(expanded_mask, "to")
                                 else expanded_mask.astype(output.dtype))
                output = output * expanded_mask

            return output

        return distributed_embedding_impl
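

# The following self-check is a minimal, self-contained sketch of the
# row-parallel masking math above, using NumPy stand-ins rather than the real
# Layout / DistributedOp machinery; every value in it is illustrative only.
if __name__ == "__main__":
    import numpy as np

    vocab, embed, world = 8, 4, 2
    weight = np.arange(vocab * embed, dtype=np.float32).reshape(vocab, embed)
    ids = np.array([1, 5, 7])

    per_rank = vocab // world
    partials = []
    for rank in range(world):
        start = rank * per_rank
        shard = weight[start:start + per_rank]
        mask = ((ids >= start) & (ids < start + per_rank)).astype(ids.dtype)
        local = (ids - start) * mask                   # out-of-range ids collapse to safe row 0
        partials.append(shard[local] * mask[:, None])  # erase the bogus row-0 lookups

    # Summing the per-rank partials (the runtime's all-reduce of the Partial
    # Sum output) matches a single-device embedding lookup.
    assert np.array_equal(sum(partials), weight[ids])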