Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_norm.py: 90%
70 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" 

16Distributed implementation for TopK operator. 

17""" 

18 

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp

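# NOTE: a sketch of the Layout conventions these ops rely on, inferred from how Layout
# is used below; the concrete mesh and aliases are hypothetical examples.
#   - A Layout is built from a device mesh and then called with one alias per tensor
#     dimension ("None" means the dimension is not sharded), e.g.:
#         layout = Layout(mesh_shape=(2, 4), alias_name=("dp", "tp"), rank_list=list(range(8)))
#         x_layout = layout("dp", "None", "None")
#   - layout.tensor_map holds one entry per tensor dimension: -1 marks an unsharded
#     dimension, while a non-negative entry indexes the mesh from its last dimension,
#     i.e. mesh_shape[len(mesh_shape) - 1 - idx].
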

class NormDistributedOp(DistributedOp):
    """Distributed implementation for the Norm operator."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layouts for a normalization operator (e.g., RmsNorm).

        This method determines the proper output layouts for normalization operations
        based on the input layouts, ensuring that the normalization operation is
        compatible with the distributed training setup.

        Args:
            layouts (tuple): A tuple of Layout objects representing the input tensor layouts.
                Expected to contain at least three layouts: input tensor, gamma parameter,
                and bias parameter.
            extra_args (dict, optional): Additional arguments that might be needed for
                layout inference. Defaults to None.

        Returns:
            tuple: A tuple containing two Layout objects:
                - First layout: Layout for the input gradient tensor.
                - Second layout: Layout for the output tensor.

        Raises:
            ValueError: If the number of input layouts is less than 3.
            ValueError: If input layouts are inconsistent.
            ValueError: If device matrices of input layouts don't match.
            ValueError: If a normalization axis is sharded, which is not supported.
            ValueError: If the gamma parameter layout doesn't match the input layout in the
                normalization dimensions.
            ValueError: If input layouts have partial status.
        """
        if len(layouts) < 3:
            raise ValueError(f"RmsNorm input layouts size {len(layouts)} is less than 3.")
        # Check partial inputs
        if not self._allow_partial_inputs:
            self._check_partial_inputs(layouts)
        x_layout = layouts[0]
        gamma_layout = layouts[-2]
        x_mesh_shape = x_layout.mesh_shape
        # All data inputs (everything except the last two parameter layouts) must share one layout.
        for i, layout in enumerate(layouts[:-2]):
            if layout != x_layout:
                raise ValueError(f"RmsNorm inputs must have the same layout, but input 0 layout is: {x_layout}, "
                                 f"input {i} layout is: {layout}.")
        gamma_mesh_shape = gamma_layout.mesh_shape
        if x_mesh_shape != gamma_mesh_shape:
            raise ValueError("RmsNorm inputs must have the same mesh_shape.")
        x_tensor_map = x_layout.tensor_map
        gamma_tensor_map = gamma_layout.tensor_map
        # The trailing dimensions covered by gamma are the normalized dimensions.
        begin_norm_axis = len(x_tensor_map) - len(gamma_tensor_map)
        # Normalized dimensions must not be sharded: -1 means unsharded; a non-negative
        # entry indexes the mesh from its last dimension.
        for axis in x_tensor_map[begin_norm_axis:]:
            if axis == -1:
                continue
            if isinstance(axis, tuple):
                for iaxis in axis:
                    if iaxis == -1:
                        continue
                    if x_mesh_shape[len(x_mesh_shape) - 1 - iaxis] > 1:
                        raise ValueError(f"RmsNorm does not support splitting input 0 from "
                                         f"begin_norm_axis {begin_norm_axis} onward.")
                continue
            if x_mesh_shape[len(x_mesh_shape) - 1 - axis] > 1:
                raise ValueError(f"RmsNorm does not support splitting input 0 from "
                                 f"begin_norm_axis {begin_norm_axis} onward.")
        if x_tensor_map[begin_norm_axis:] != gamma_tensor_map:
            raise ValueError(f"The input sharding from dimension {begin_norm_axis} onward "
                             f"{x_layout.alias_tensor_map[begin_norm_axis:]} should equal "
                             f"the gamma sharding {gamma_layout.alias_tensor_map}.")
        # The second output keeps the input sharding on the leading dimensions and is
        # replicated on the normalized dimensions.
        output_layout = Layout(
            mesh_shape=x_layout.mesh_shape,
            alias_name=x_layout.alias_name,
            rank_list=x_layout.rank_list
        )
        output_map = x_layout.alias_tensor_map[:begin_norm_axis] + ("None",) * len(gamma_tensor_map)
        out_layout = output_layout(*output_map)
        return x_layout, out_layout
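

# Illustrative sketch (not executed): how NormDistributedOp.infer_layout treats a typical
# RmsNorm call. The mesh, aliases, and the extra bias layout are hypothetical.
#
#     layout = Layout(mesh_shape=(2, 4), alias_name=("dp", "tp"), rank_list=list(range(8)))
#     x_layout = layout("dp", "None", "None")   # [batch, seq, hidden], batch sharded on "dp"
#     gamma_layout = layout("None")             # gamma spans only the hidden dimension
#     # infer_layout((x_layout, gamma_layout, gamma_layout)) returns (x_layout, out_layout):
#     # the first output reuses the input layout; the second keeps "dp" on the batch
#     # dimension and replicates the normalized dimension. Passing an input such as
#     # layout("dp", "None", "tp") would raise a ValueError, because the normalized
#     # (hidden) dimension may not be split.
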

class LayerNormDistributedOp(DistributedOp):
    """Distributed implementation for torch.nn.functional.layer_norm."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer the output layout for layer_norm.

        PyTorch rules:
          - normalized_shape specifies the last N dimensions to normalize over.
          - All dimensions in normalized_shape MUST be unsharded for correctness.
          - The output layout is identical to the input layout (shape unchanged).

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple): Should contain 'normalized_shape'. Expected:
                extra_args[0] (int | list | tuple): Normalized shape; its dimensions must be unsharded.

        Returns:
            Layout: Output tensor layout (same as the input layout if valid).
        """
        if not layouts or layouts[0] is None:
            raise ValueError("layer_norm requires a valid input tensor layout.")
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map  # e.g., (-1, 0, -1) for a 3D tensor

        if not extra_args or extra_args[0] is None:
            raise ValueError("layer_norm requires normalized_shape in extra_args.")
        normalized_shape = extra_args[0]

        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        elif isinstance(normalized_shape, (list, tuple)):
            normalized_shape = tuple(normalized_shape)
        else:
            raise ValueError(f"normalized_shape must be int, list, or tuple, got {type(normalized_shape)}")

        input_ndim = len(in_tensor_map)
        norm_ndim = len(normalized_shape)

        if norm_ndim > input_ndim:
            raise ValueError(
                f"normalized_shape {normalized_shape} (dims={norm_ndim}) is larger than input ndim={input_ndim}."
            )

        # The last `norm_ndim` dimensions are going to be normalized
        dims_to_normalize = list(range(input_ndim - norm_ndim, input_ndim))

        # All normalized dims must be unsharded
        for dim in dims_to_normalize:
            if in_tensor_map[dim] != -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot perform sharding on normalized dimension {dim}, "
                    f"but found sharding assignment: {in_tensor_map[dim]}"
                )

        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        # Create output layout
        def idx_to_alias(idx, aliases):
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]
        output_map = tuple(idx_to_alias(idx, alias_name) for idx in in_tensor_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_map)
        return output_layout
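

# Illustrative sketch (not executed): layer_norm layout inference for hidden-size
# normalization. The mesh and aliases are hypothetical.
#
#     layout = Layout(mesh_shape=(2, 4), alias_name=("dp", "tp"), rank_list=list(range(8)))
#     x_layout = layout("dp", "None", "None")   # [batch, seq, hidden], batch sharded on "dp"
#     # With extra_args = (hidden_size,), the only normalized dimension (the last one) is
#     # unsharded, so infer_layout((x_layout,), (hidden_size,)) returns a layout equivalent
#     # to x_layout. If the last dimension were sharded, e.g. layout("dp", "None", "tp"),
#     # infer_layout would raise a ValueError.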