Coverage for hyper_parallel / core / shard / ops / parallel_norm.py: 46%

70 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Norm operators. 

17""" 

18 

19from hyper_parallel.core.layout import Layout 

20from .parallel_ops import DistributedOp 

21 

class NormDistributedOp(DistributedOp):
    """Distributed implementation for Norm operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for normalization operator (e.g., RmsNorm).

        This method determines the proper output layout for normalization operations
        based on the input layouts, ensuring that the normalization operation is
        compatible with the distributed training setup.

        Args:
            layouts (tuple): A tuple of Layout objects representing the input tensor layouts.
                Expected to contain at least three layouts: input tensor, gamma parameter, and beta parameter.
            extra_args (dict, optional): Additional arguments that might be needed for layout inference.
                Defaults to None.

        Returns:
            tuple: A tuple containing two Layout objects:
                - First layout: Layout for the input gradient tensor
                - Second layout: Layout for the output tensor

        Raises:
            ValueError: If the number of input layouts is less than 3.
            ValueError: If input layouts are inconsistent.
            ValueError: If device matrices of input layouts don't match.
            ValueError: If normalization axis is sharded, which is not supported.
            ValueError: If gamma parameter layout doesn't match the input layout in normalization dimensions.
            ValueError: If input layouts have partial status.
        """
        if len(layouts) < 3:
            raise ValueError(f"RmsNorm input layouts size {len(layouts)} is less than 3.")
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)
        x_layout = layouts[0]
        gamma_layout = layouts[-2]
        x_mesh_shape = x_layout.mesh_shape
        # All data inputs (everything before gamma/beta) must share one layout.
        for i, layout in enumerate(layouts[:-2]):
            if layout != x_layout:
                raise ValueError(f"RmsNorm inputs must have same layout, but input 0 layout is: {x_layout},"
                                 f"input {i} layout is: {layout}.")
        gamma_mesh_shape = gamma_layout.mesh_shape
        if x_mesh_shape != gamma_mesh_shape:
            raise ValueError("RmsNorm inputs must have same mesh_shape")
        x_tensor_map = x_layout.tensor_map
        gamma_tensor_map = gamma_layout.tensor_map
        # gamma covers the trailing (normalized) dims of x; everything before
        # begin_norm_axis is the non-normalized prefix.
        begin_norm_axis = len(x_tensor_map) - len(gamma_tensor_map)
        # Normalized dimensions must not be sharded: an entry is either -1
        # (replicated), a mesh-dim index, or a tuple of mesh-dim indices.
        # Mesh-dim indices count from the *last* mesh axis, hence the
        # len(x_mesh_shape) - 1 - idx lookup.
        for axis in x_tensor_map[begin_norm_axis:]:
            if axis == -1:
                continue
            if isinstance(axis, tuple):
                for iaxis in axis:
                    if iaxis == -1:
                        continue
                    if x_mesh_shape[len(x_mesh_shape) - 1 - iaxis] > 1:
                        raise ValueError(f"RmsNorm is disabled to support the splitting after "
                                         f"begin_norm_axis {begin_norm_axis} for input 0.")
                # BUG FIX: the tuple case previously fell through to the scalar
                # check below, where `len(x_mesh_shape) - 1 - axis` with a tuple
                # axis raises TypeError. Skip the scalar check for tuple axes.
                continue
            if x_mesh_shape[len(x_mesh_shape) - 1 - axis] > 1:
                raise ValueError(f"RmsNorm is disabled to support the splitting after "
                                 f"begin_norm_axis {begin_norm_axis} for input 0.")
        if x_tensor_map[begin_norm_axis:] != gamma_tensor_map:
            raise ValueError(f"The input sharding in the first {begin_norm_axis} dimensions "
                             f"{x_layout.alias_tensor_map[begin_norm_axis:]} should equal to"
                             f" the gamma sharding {gamma_layout.alias_tensor_map}")
        # Output keeps the input's sharding on the prefix dims and is
        # replicated ("None") on every normalized dim.
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map = x_layout.alias_tensor_map[:begin_norm_axis] + ("None",) * len(gamma_tensor_map)
        out_layout = output_layout(*output_map)
        return x_layout, out_layout

95 

class LayerNormDistributedOp(DistributedOp):
    """Distributed implementation for torch.nn.functional.layer_norm."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for layer_norm.

        PyTorch semantics: ``normalized_shape`` names the trailing N input
        dimensions that are normalized over. Every one of those dimensions
        must be unsharded; the output keeps the input's shape, so the output
        layout mirrors the input layout.

        Args:
            layouts (tuple): Input layouts; ``layouts[0]`` (Layout) is the
                input tensor layout and is required.
            extra_args (tuple): ``extra_args[0]`` (int | list | tuple) is the
                ``normalized_shape`` whose dims must be unsharded.

        Returns:
            Layout: Output tensor layout, identical in sharding to the input.

        Raises:
            ValueError: If the input layout or normalized_shape is missing,
                of the wrong type, has more dims than the input, or if any
                normalized dimension is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError("layer_norm requires a valid input tensor layout.")
        src_layout = layouts[0]
        tensor_map = src_layout.tensor_map  # e.g., (-1, 0, -1) for a 3D tensor

        if not extra_args or extra_args[0] is None:
            raise ValueError("layer_norm requires normalized_shape in extra_args.")
        norm_shape = extra_args[0]

        # Canonicalize normalized_shape to a tuple.
        if isinstance(norm_shape, (list, tuple)):
            norm_shape = tuple(norm_shape)
        elif isinstance(norm_shape, int):
            norm_shape = (norm_shape,)
        else:
            raise ValueError(f"normalized_shape must be int, list, or tuple, got {type(norm_shape)}")

        input_ndim = len(tensor_map)
        norm_ndim = len(norm_shape)
        if norm_ndim > input_ndim:
            raise ValueError(
                f"normalized_shape {norm_shape} (dims={norm_ndim}) is larger than input ndim={input_ndim}."
            )

        # The last `norm_ndim` dims are normalized over and must be unsharded.
        for dim in range(input_ndim - norm_ndim, input_ndim):
            if tensor_map[dim] == -1:
                continue
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on normalized dimension {dim}, "
                f"but found sharding assignment: {tensor_map[dim]}"
            )

        # Translate each mesh-dim index back to its alias name; -1 means
        # replicated. Indices count from the last mesh axis, hence the
        # reversed lookup into alias_name.
        aliases = src_layout.alias_name
        last = len(aliases) - 1
        alias_map = tuple(
            "None" if idx == -1 else aliases[last - idx]
            for idx in tensor_map
        )

        out_layout = Layout(
            mesh_shape=src_layout.mesh_shape,
            alias_name=aliases,
            rank_list=src_layout.rank_list
        )
        return out_layout(*alias_map)