Coverage for / home / jenkins / .local / lib / python3.10 / site-packages / hyper_parallel / core / shard / ops / parallel_norm.py: 90%
70 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
Distributed implementation for normalization operators (Norm/RmsNorm and layer_norm).
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp
class NormDistributedOp(DistributedOp):
    """Distributed implementation for Norm operator (e.g., RmsNorm)."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for a normalization operator (e.g., RmsNorm).

        The trailing ``len(gamma_layout.tensor_map)`` dimensions of input 0 are the
        normalized dimensions. They must not be split over any mesh axis whose size
        is greater than 1, and their tensor map must equal gamma's tensor map.

        Args:
            layouts (tuple): Layouts of the operator inputs; at least three entries.
                ``layouts[-2]`` is treated as the gamma layout, and every entry before
                it (``layouts[:-2]``) must equal ``layouts[0]``. The last entry is not
                inspected by this method apart from the partial-input check.
            extra_args (optional): Not used by this method. Defaults to None.

        Returns:
            tuple: ``(x_layout, reduced_layout)`` where
                - ``x_layout`` is ``layouts[0]`` returned unchanged;
                - ``reduced_layout`` is a new Layout on the same mesh/alias names/rank
                  list that keeps input 0's sharding on the leading dimensions and marks
                  every normalized dimension as unsharded ("None").
                NOTE(review): presumably these correspond to the main output and the
                reduced statistic (e.g. rstd) of the op - confirm the output order.

        Raises:
            ValueError: If fewer than 3 layouts are given.
            ValueError: If any layout before gamma differs from input 0's layout.
            ValueError: If input 0 and gamma have different mesh shapes.
            ValueError: If gamma has more dimensions than input 0.
            ValueError: If a normalized dimension is split over a mesh axis of size > 1.
            ValueError: If input 0's normalized-dimension sharding differs from gamma's.
            Whatever ``self._check_partial_inputs`` raises when partial inputs are
            not allowed (``self._allow_partial_inputs`` is falsy).
        """
        if len(layouts) < 3:
            raise ValueError(f"RmsNorm input layouts size {len(layouts)} is less than 3.")
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)

        x_layout = layouts[0]
        gamma_layout = layouts[-2]
        x_mesh_shape = x_layout.mesh_shape

        # Every input preceding gamma must share input 0's layout.
        for i, layout in enumerate(layouts[:-2]):
            if layout != x_layout:
                raise ValueError(f"RmsNorm inputs must have same layout, but input 0 layout is: {x_layout}, "
                                 f"input {i} layout is: {layout}.")
        if x_mesh_shape != gamma_layout.mesh_shape:
            raise ValueError("RmsNorm inputs must have same mesh_shape")

        x_tensor_map = x_layout.tensor_map
        gamma_tensor_map = gamma_layout.tensor_map
        begin_norm_axis = len(x_tensor_map) - len(gamma_tensor_map)
        if begin_norm_axis < 0:
            # Without this guard the negative index would silently slice from the end.
            raise ValueError(f"RmsNorm gamma rank {len(gamma_tensor_map)} is larger than "
                             f"input 0 rank {len(x_tensor_map)}.")

        mesh_ndim = len(x_mesh_shape)
        for axis in x_tensor_map[begin_norm_axis:]:
            # An entry is either a single mesh axis or a tuple of mesh axes
            # (multi-axis sharding); -1 marks "not sharded". Normalizing to a tuple
            # ensures tuple entries are never used as an integer index.
            mesh_axes = axis if isinstance(axis, tuple) else (axis,)
            for mesh_axis in mesh_axes:
                if mesh_axis == -1:
                    continue
                # tensor_map values count mesh axes from the last one, hence the reversal.
                if x_mesh_shape[mesh_ndim - 1 - mesh_axis] > 1:
                    raise ValueError(f"RmsNorm is disabled to support the splitting after "
                                     f"begin_norm_axis {begin_norm_axis} for input 0.")

        if x_tensor_map[begin_norm_axis:] != gamma_tensor_map:
            raise ValueError(f"The input sharding of the dimensions from begin_norm_axis {begin_norm_axis} onward "
                             f"{x_layout.alias_tensor_map[begin_norm_axis:]} should equal to"
                             f" the gamma sharding {gamma_layout.alias_tensor_map}")

        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        # Leading dims keep input 0's sharding; normalized dims become unsharded.
        output_map = x_layout.alias_tensor_map[:begin_norm_axis] + ("None",) * len(gamma_tensor_map)
        return x_layout, output_layout(*output_map)
class LayerNormDistributedOp(DistributedOp):
    """Distributed implementation for torch.nn.functional.layer_norm."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for layer_norm.

        PyTorch rules:
            - normalized_shape specifies the last N dimensions to normalize over.
            - All dimensions in normalized_shape MUST be unsharded for correctness.
            - Output layout is identical to input layout (shape unchanged).

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required). Other entries
                are not inspected by this method.
            extra_args (tuple): Should contain 'normalized_shape'. Expected:
                extra_args[0] (int | list | tuple): Normalized shape; only its length
                is used, to select the trailing dimensions that must be unsharded.

        Returns:
            Layout: A new Layout on the input's mesh/alias names/rank list carrying
            the same sharding as the input.

        Raises:
            ValueError: If the input layout is missing, normalized_shape is missing or
                of an unsupported type, normalized_shape has more dimensions than the
                input, or any normalized dimension is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("layer_norm requires a valid input tensor layout.")
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map  # e.g., (-1, 0, -1) for 3D tensor

        if not extra_args or extra_args[0] is None:
            raise ValueError("layer_norm requires normalized_shape in extra_args.")
        normalized_shape = extra_args[0]

        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        elif isinstance(normalized_shape, (list, tuple)):
            normalized_shape = tuple(normalized_shape)
        else:
            raise ValueError(f"normalized_shape must be int, list, or tuple, got {type(normalized_shape)}")

        input_ndim = len(in_tensor_map)
        norm_ndim = len(normalized_shape)
        if norm_ndim > input_ndim:
            raise ValueError(
                f"normalized_shape {normalized_shape} (dims={norm_ndim}) is larger than input ndim={input_ndim}."
            )

        # The last `norm_ndim` dimensions are normalized and must be unsharded (-1).
        for dim in range(input_ndim - norm_ndim, input_ndim):
            if in_tensor_map[dim] != -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot perform sharding on normalized dimension {dim}, "
                    f"but found sharding assignment: {in_tensor_map[dim]}"
                )

        output_layout = Layout(
            mesh_shape=input_layout.mesh_shape,
            alias_name=input_layout.alias_name,
            rank_list=input_layout.rank_list
        )
        # Shape is unchanged, so the output reuses the input's sharding. The alias
        # tensor map is taken from the layout itself rather than recomputed from
        # integer indices, so multi-axis (tuple) tensor_map entries on non-normalized
        # dims no longer hit integer arithmetic on a tuple.
        # NOTE(review): assumes alias_tensor_map is a tuple of per-dim aliases
        # ("None" for unsharded) accepted by Layout.__call__ - confirm in Layout.
        return output_layout(*input_layout.alias_tensor_map)