Coverage for hyper_parallel / core / shard / ops / parallel_norm.py: 46%
70 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for TopK operator.
17"""
19from hyper_parallel.core.layout import Layout
20from .parallel_ops import DistributedOp
class NormDistributedOp(DistributedOp):
    """Distributed implementation for Norm operators (e.g., RmsNorm)."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for normalization operator (e.g., RmsNorm).

        This method determines the proper output layout for normalization operations
        based on the input layouts, ensuring that the normalization operation is
        compatible with the distributed training setup.

        Args:
            layouts (tuple): A tuple of Layout objects representing the input tensor layouts.
                Expected to contain at least three layouts; by convention the last two are
                the gamma and beta parameter layouts, the preceding ones are data inputs.
            extra_args (dict, optional): Additional arguments that might be needed for layout inference.
                Defaults to None. Currently unused here.

        Returns:
            tuple: A tuple containing two Layout objects:
                - First layout: the (unchanged) input layout, reused for the input tensor
                - Second layout: layout for the output tensor, with the normalized
                  (trailing) dimensions forced to be replicated ("None")

        Raises:
            ValueError: If the number of input layouts is less than 3.
            ValueError: If input layouts are inconsistent.
            ValueError: If device matrices of input layouts don't match.
            ValueError: If a normalization axis is sharded, which is not supported.
            ValueError: If the gamma layout doesn't match the input layout in the normalized dimensions.
            ValueError: If input layouts have partial status (unless allowed).
        """
        if len(layouts) < 3:
            raise ValueError(f"RmsNorm input layouts size {len(layouts)} is less than 3.")
        # Reject partial (pending-reduce) inputs unless the op explicitly allows them.
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)
        x_layout = layouts[0]
        gamma_layout = layouts[-2]
        x_mesh_shape = x_layout.mesh_shape
        # All data inputs (everything before gamma/beta) must share one layout.
        for i, layout in enumerate(layouts[:-2]):
            if layout != x_layout:
                raise ValueError(f"RmsNorm inputs must have same layout, but input 0 layout is: {x_layout},"
                                 f"input {i} layout is: {layout}.")
        gamma_mesh_shape = gamma_layout.mesh_shape
        if x_mesh_shape != gamma_mesh_shape:
            raise ValueError("RmsNorm inputs must have same mesh_shape")
        x_tensor_map = x_layout.tensor_map
        gamma_tensor_map = gamma_layout.tensor_map
        # gamma covers the trailing dims of x; everything from begin_norm_axis on is normalized.
        begin_norm_axis = len(x_tensor_map) - len(gamma_tensor_map)
        for axis in x_tensor_map[begin_norm_axis:]:
            if axis == -1:
                continue
            # A tuple entry maps one tensor dim onto several mesh dims; normalize to a
            # tuple so both cases share one check. (The previous code fell through to
            # `x_mesh_shape[... - axis]` with a tuple `axis`, raising TypeError.)
            mesh_axes = axis if isinstance(axis, tuple) else (axis,)
            for iaxis in mesh_axes:
                if iaxis == -1:
                    continue
                # tensor_map indexes mesh dims from the last one; a mesh dim of size 1
                # is effectively unsharded and therefore allowed.
                if x_mesh_shape[len(x_mesh_shape) - 1 - iaxis] > 1:
                    raise ValueError(f"RmsNorm is disabled to support the splitting after "
                                     f"begin_norm_axis {begin_norm_axis} for input 0.")
        if x_tensor_map[begin_norm_axis:] != gamma_tensor_map:
            # Message fixed: the compared slice is the trailing (normalized) dims,
            # not the first begin_norm_axis dims.
            raise ValueError(f"The input sharding in the dimensions after begin_norm_axis "
                             f"{begin_norm_axis} {x_layout.alias_tensor_map[begin_norm_axis:]} should equal to"
                             f" the gamma sharding {gamma_layout.alias_tensor_map}")
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        # Output keeps the batch-dim sharding but replicates every normalized dim.
        output_map = x_layout.alias_tensor_map[:begin_norm_axis] + ("None",) * len(gamma_tensor_map)
        out_layout = output_layout(*output_map)
        return x_layout, out_layout
class LayerNormDistributedOp(DistributedOp):
    """Distributed implementation for torch.nn.functional.layer_norm."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for layer_norm.

        PyTorch semantics: ``normalized_shape`` names the last N dimensions of the
        input, all of which must be unsharded; the output has the same shape (and
        therefore the same layout) as the input.

        Args:
            layouts (tuple): Input layouts; layouts[0] is the required input tensor layout.
            extra_args (tuple): extra_args[0] is ``normalized_shape`` (int, list, or tuple).

        Returns:
            Layout: The output tensor layout, identical to the input layout.

        Raises:
            ValueError: If the input layout or normalized_shape is missing, if
                normalized_shape has a bad type or more dims than the input, or
                if any normalized dimension is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("layer_norm requires a valid input tensor layout.")
        in_layout = layouts[0]
        tensor_map = in_layout.tensor_map  # one entry per tensor dim; -1 = replicated

        if not extra_args or extra_args[0] is None:
            raise ValueError("layer_norm requires normalized_shape in extra_args.")
        shape_arg = extra_args[0]

        # Normalize the argument to a tuple, mirroring PyTorch's accepted forms.
        if isinstance(shape_arg, int):
            norm_shape = (shape_arg,)
        elif isinstance(shape_arg, (list, tuple)):
            norm_shape = tuple(shape_arg)
        else:
            raise ValueError(f"normalized_shape must be int, list, or tuple, got {type(shape_arg)}")

        ndim = len(tensor_map)
        n_norm = len(norm_shape)
        if n_norm > ndim:
            raise ValueError(
                f"normalized_shape {norm_shape} (dims={n_norm}) is larger than input ndim={ndim}."
            )

        # The trailing n_norm dims are normalized and must all be replicated.
        for dim in range(ndim - n_norm, ndim):
            if tensor_map[dim] != -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot perform sharding on normalized dimension {dim}, "
                    f"but found sharding assignment: {tensor_map[dim]}"
                )

        aliases = in_layout.alias_name
        # tensor_map indexes mesh dims from the end of the mesh; -1 maps to "None".
        alias_map = tuple(
            "None" if idx == -1 else aliases[len(aliases) - idx - 1]
            for idx in tensor_map
        )
        base_layout = Layout(
            mesh_shape=in_layout.mesh_shape,
            alias_name=aliases,
            rank_list=in_layout.rank_list
        )
        return base_layout(*alias_map)