Coverage for hyper_parallel / core / shard / ops / parallel_squeeze.py: 78%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2026 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15""" 

16Distributed implementation for Squeeze operator. 

17""" 

18from hyper_parallel.core.layout import Layout 

19from .parallel_ops import DistributedOp 

20 

21 

class SqueezeDistributedOp(DistributedOp):
    """Distributed implementation for the Squeeze operator.

    Infers the output tensor layout produced by removing size-1,
    non-sharded dimensions from a distributed input layout.
    """

25 def infer_layout(self, layouts, extra_args): 

26 """ 

27 Infer output layout for Squeeze. 

28 

29 Args: 

30 layouts (tuple): Tuple containing input layout. 

31 extra_args: Extra arguments containing axis and input_shapes. 

32 Can be dict or list/tuple where last element is input_shapes. 

33 

34 Returns: 

35 Layout: Output layout with squeezed dimensions removed. 

36 """ 

37 if not layouts: 

38 raise ValueError( 

39 f"For {self.op_name}, layouts should contain at least one input layout, " 

40 f"but got empty layouts." 

41 ) 

42 

43 x_layout = layouts[0] 

44 if x_layout.mesh_shape is None: 

45 raise ValueError( 

46 f"For {self.op_name}, input layout mesh_shape should not be None, " 

47 f"but got None." 

48 ) 

49 

50 axis, input_shape = self._extract_args(extra_args) 

51 if input_shape is None: 

52 raise ValueError( 

53 f"For {self.op_name}, input_shapes should be provided in extra_args, " 

54 f"but got None." 

55 ) 

56 

57 return self._compute_squeeze_layout(x_layout, axis, input_shape) 

58 

59 def _extract_args(self, extra_args): 

60 """Extract axis and input_shape from extra_args.""" 

61 if isinstance(extra_args, dict): 

62 input_shapes = extra_args.get("input_shapes", None) 

63 axis = extra_args.get("axis", None) 

64 elif isinstance(extra_args, (list, tuple)) and extra_args: 

65 # Last element is input_shapes 

66 input_shapes = extra_args[-1] 

67 if not isinstance(input_shapes, (list, tuple)): 

68 raise ValueError( 

69 f"For {self.op_name}, input_shapes should be list or tuple, " 

70 f"but got {type(input_shapes)}." 

71 ) 

72 # First element is axis (if available) 

73 axis = extra_args[0] if len(extra_args) > 1 else None 

74 else: 

75 raise ValueError( 

76 f"For {self.op_name}, extra_args should be dict or list/tuple, " 

77 f"but got {type(extra_args)}." 

78 ) 

79 

80 # Get input shape (first element of input_shapes) 

81 if input_shapes: 

82 input_shape = input_shapes[0] if isinstance(input_shapes[0], (list, tuple)) else input_shapes 

83 else: 

84 input_shape = None 

85 

86 return axis, input_shape 

87 

88 def _compute_squeeze_layout(self, x_layout, axis, input_shape): 

89 """Compute the squeezed layout.""" 

90 # Handle scalar case 

91 if not input_shape: 

92 return self._handle_scalar_case(x_layout, axis) 

93 

94 # Validate input_shape matches layout rank 

95 self._validate_input_shape(x_layout, input_shape) 

96 

97 # Find dimensions to squeeze 

98 dims_to_squeeze = self._get_dims_to_squeeze(x_layout, axis, input_shape) 

99 

100 # Create output layout 

101 return self._create_output_layout(x_layout, dims_to_squeeze) 

102 

103 def _handle_scalar_case(self, x_layout, axis): 

104 """Handle scalar input case.""" 

105 if axis is not None and axis != [] and axis != (): 

106 raise ValueError( 

107 f"For {self.op_name}, axis should be None for scalar input, " 

108 f"but got {axis}." 

109 ) 

110 

111 # Return scalar layout 

112 output_layout = Layout( 

113 mesh_shape=x_layout.mesh_shape, 

114 alias_name=x_layout.alias_name, 

115 rank_list=x_layout.rank_list 

116 ) 

117 output_layout = output_layout() 

118 return output_layout 

119 

120 def _validate_input_shape(self, x_layout, input_shape): 

121 """Validate that input shape matches layout rank.""" 

122 x_map = list(x_layout.alias_tensor_map) 

123 in_rank = len(x_map) 

124 

125 if len(input_shape) != in_rank: 

126 raise ValueError( 

127 f"For {self.op_name}, input shape rank should match layout rank, " 

128 f"but got {len(input_shape)} and {in_rank}." 

129 ) 

130 

131 def _get_dims_to_squeeze(self, x_layout, axis, input_shape): 

132 """Get list of dimensions to squeeze.""" 

133 x_map = list(x_layout.alias_tensor_map) 

134 in_rank = len(x_map) 

135 

136 if axis is None: 

137 return self._get_all_squeezable_dims(x_map, input_shape) 

138 return self._get_specified_dims_to_squeeze(x_map, axis, input_shape, in_rank) 

139 

140 def _get_all_squeezable_dims(self, x_map, input_shape): 

141 """Get all squeezable dimensions when axis is None.""" 

142 dims_to_squeeze = [] 

143 for i, shape in enumerate(input_shape): 

144 if shape == 1 and x_map[i] == "None": 

145 dims_to_squeeze.append(i) 

146 return dims_to_squeeze 

147 

148 def _get_specified_dims_to_squeeze(self, x_map, axis, input_shape, in_rank): 

149 """Get dimensions to squeeze when axis is specified.""" 

150 # Convert axis to list if it's a single integer 

151 if isinstance(axis, int): 

152 axis = [axis] 

153 

154 # Convert negative indices to positive 

155 axis = [ax if ax >= 0 else ax + in_rank for ax in axis] 

156 

157 # Validate axis range 

158 self._validate_axis_range(axis, in_rank) 

159 

160 # Check all specified axes 

161 for ax in axis: 

162 self._validate_axis_for_squeeze(x_map, input_shape, ax) 

163 

164 # Return sorted unique axes 

165 return sorted(set(axis)) 

166 

167 def _validate_axis_range(self, axis, in_rank): 

168 """Validate axis values are within range.""" 

169 for ax in axis: 

170 if ax < 0 or ax >= in_rank: 

171 raise ValueError( 

172 f"For {self.op_name}, axis should be in range [{-in_rank}, {in_rank-1}], " 

173 f"but got {ax}." 

174 ) 

175 

176 def _validate_axis_for_squeeze(self, x_map, input_shape, ax): 

177 """Validate a specific axis can be squeezed.""" 

178 # Check shape == 1 

179 if input_shape[ax] != 1: 

180 raise ValueError( 

181 f"For {self.op_name}, dimension should have size 1, " 

182 f"but got shape {input_shape[ax]} at dimension {ax}." 

183 ) 

184 

185 # Check mapping is "None" (not distributed) 

186 if x_map[ax] != "None": 

187 raise ValueError( 

188 f"For {self.op_name}, dimension should not be distributed, " 

189 f"but got dimension {ax} mapped to device axis {x_map[ax]}." 

190 ) 

191 

192 def _create_output_layout(self, x_layout, dims_to_squeeze): 

193 """Create output layout after squeezing dimensions.""" 

194 # Get current alias tensor map 

195 x_map = list(x_layout.alias_tensor_map) 

196 

197 # Sort in descending order for safe removal 

198 dims_to_squeeze = sorted(set(dims_to_squeeze), reverse=True) 

199 

200 # Remove specified dimensions 

201 for dim in dims_to_squeeze: 

202 del x_map[dim] 

203 

204 new_map = x_map 

205 

206 # Create output layout with new mapping 

207 output_layout = Layout( 

208 mesh_shape=x_layout.mesh_shape, 

209 alias_name=x_layout.alias_name, 

210 rank_list=x_layout.rank_list 

211 ) 

212 

213 if new_map: 

214 output_layout = output_layout(*new_map) 

215 else: 

216 # For scalar result 

217 output_layout = output_layout() 

218 

219 # Copy partial operations from input layout 

220 self._copy_partial_operations(x_layout, output_layout, new_map) 

221 

222 return output_layout 

223 

224 def _copy_partial_operations(self, x_layout, output_layout, new_map): 

225 """Copy partial operations from input to output layout.""" 

226 for i, partial_op in enumerate(x_layout.partial): 

227 if partial_op is not None: 

228 dev_axis_name = x_layout.alias_name[i] 

229 # Check if this device axis is still used in the output 

230 if dev_axis_name in new_map: 

231 output_layout.set_partial_by_dev_axis(dev_axis_name, partial_op)