Coverage for hyper_parallel / core / shard / ops / parallel_expand.py: 86%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Expand operator.
"""

from hyper_parallel.core.layout import Layout
from .parallel_ops import DistributedOp

class ExpandDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.Tensor.expand.

        PyTorch semantics:
            - Expands singleton dimensions (size 1) to larger sizes
            - Passing -1 preserves the original size of that dimension
            - Only dimensions with global size 1 can be expanded
            - Existing dimensions being expanded MUST be unsharded

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple): The desired expanded sizes, one int per
                output dimension (required).

        Returns:
            Layout: Output tensor layout with:
                - New dimensions: unsharded (-1)
                - Expanded existing dimensions: unsharded (-1)
                - Preserved dimensions (-1 in sizes): original sharding preserved

        Raises:
            ValueError: If the input layout is missing, 'sizes' is absent or
                contains non-int elements, the requested rank is smaller than
                the input rank, -1 is used for a new leading dimension, or a
                sharded dimension is expanded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: expand requires a valid input tensor layout."
            )
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        # `not extra_args` already covers the empty-tuple case.
        if not extra_args:
            raise ValueError(
                f"Operation {self.op_name}: expand requires 'sizes' parameter in extra_args."
            )
        output_ndim = len(extra_args)

        # Normalize sizes to a tuple, rejecting non-int entries up front.
        if not all(isinstance(size, int) for size in extra_args):
            raise ValueError(
                f"Operation {self.op_name}: elements in 'sizes' parameter must be int."
            )
        sizes = tuple(extra_args)

        num_new_dims = output_ndim - input_ndim

        # PyTorch only allows prepending new dimensions (not inserting in middle).
        if num_new_dims < 0:
            raise ValueError(
                f"Operation {self.op_name}: Cannot reduce dimensions with expand. "
                f"Input has {input_ndim} dims, requested {output_ndim} dims."
            )

        # Build output tensor map.
        output_map = []

        # Rule 1: For the new (prepended) dimensions, the size cannot be -1.
        for i in range(num_new_dims):
            if sizes[i] == -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot use -1 for new dimension at position {i}. "
                )
            output_map.append(-1)  # New dimensions are always unsharded.

        # Rule 2: Process existing dimensions.
        for i in range(input_ndim):
            output_dim_idx = num_new_dims + i
            requested_size = sizes[output_dim_idx]

            if requested_size == -1:
                # -1 keeps the original size, so the original sharding is safe.
                output_map.append(in_tensor_map[i])
            else:
                # A sharded dimension cannot be broadcast/expanded.
                if in_tensor_map[i] != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand dimension {i} which is sharded."
                    )
                # Expanded dimension becomes unsharded in output.
                output_map.append(-1)

        # Construct output layout on the same device mesh as the input.
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        def idx_to_alias(idx, aliases):
            # tensor_map counts mesh axes from the rightmost alias backwards;
            # -1 marks a replicated (unsharded) tensor dimension.
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_alias_map)
        return output_layout

130 

131 

class ExpandAsDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand_as."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for expand_as.

        PyTorch semantics:
            - Only dimensions with global size == 1 can be expanded to larger sizes
            - Dimensions with size > 1 must exactly match between input and target
            - Broadcast replicates a single value across the expanded dimension

        Critical sharding constraints:
            - When a dimension is broadcast (size 1 -> N), it must be unsharded
              (-1) in the input layout, and it stays unsharded in the output
            - Dimensions that keep their size keep their input sharding pattern

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
                layouts[1] (Layout): Target tensor layout (not used).
            extra_args (tuple): Must contain shape information. Expected:
                extra_args[0][0] (tuple of int): Input global shape.
                extra_args[0][1] (tuple of int): Target global shape.

        Returns:
            Layout: Output tensor layout; sharding is preserved for unchanged
            dimensions and cleared for broadcast dimensions.

        Raises:
            ValueError: If the input layout or shape information is missing or
                malformed, the target rank is smaller than the input rank, or
                an invalid/sharded dimension would be expanded.
        """
        # Reject a missing/empty input layout before touching its attributes.
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: expand requires a valid input tensor layout."
            )
        src_layout = layouts[0]
        src_map = src_layout.tensor_map
        src_ndim = len(src_map)

        # Shape information travels in extra_args[0] as (input_shape, target_shape).
        if not extra_args or extra_args[0] is None or len(extra_args[0]) < 2:
            raise ValueError(
                f"Operation {self.op_name}: expand requires (input_global_shape, target_shape) "
                f"in extra_args."
            )
        input_global_shape = extra_args[0][0]
        target_shape = extra_args[0][1]

        if not isinstance(target_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: target_shape must be tuple/list, got {type(target_shape)}."
            )
        if not isinstance(input_global_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: input_global_shape must be tuple/list, got {type(input_global_shape)}."
            )

        target_shape = tuple(target_shape)
        input_global_shape = tuple(input_global_shape)
        target_ndim = len(target_shape)

        # PyTorch rule: the target rank can never drop below the input rank.
        if target_ndim < src_ndim:
            raise ValueError(
                f"Operation {self.op_name}: target shape {target_shape} (ndim={target_ndim}) cannot be "
                f"smaller than input shape {input_global_shape} (ndim={src_ndim})."
            )

        # Left-pad the input with implicit singleton (unsharded) dimensions so
        # both shapes line up with the target rank.
        pad = target_ndim - src_ndim
        padded_shape = (1,) * pad + input_global_shape
        padded_map = (-1,) * pad + src_map

        # Walk every output dimension, validating the expansion rules.
        out_map = []
        for dim in range(target_ndim):
            src_size = padded_shape[dim]
            dst_size = target_shape[dim]
            mesh_axis = padded_map[dim]

            if src_size == dst_size:
                # Size unchanged: the input sharding carries over untouched.
                out_map.append(mesh_axis)
                continue

            if src_size == 1 and dst_size > 1:
                # Broadcast: the dimension must not be sharded beforehand.
                if mesh_axis != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand sharded dimension {dim} which is going to broadcast "
                        f"(global size 1 → {dst_size})."
                    )
                out_map.append(-1)
                continue

            # Any other size change is illegal for expand_as.
            raise ValueError(
                f"Operation {self.op_name}: Cannot expand dimension {dim} from size {src_size} "
                f"to {dst_size}."
            )

        # Reuse the input's mesh configuration for the output layout.
        aliases = src_layout.alias_name

        def to_alias(mesh_axis):
            # -1 marks a replicated tensor dimension; otherwise tensor_map
            # counts mesh axes from the rightmost alias backwards.
            if mesh_axis == -1:
                return "None"
            return aliases[len(aliases) - mesh_axis - 1]

        alias_map = tuple(to_alias(axis) for axis in out_map)

        out_layout = Layout(
            mesh_shape=src_layout.mesh_shape,
            alias_name=aliases,
            rank_list=src_layout.rank_list
        )
        return out_layout(*alias_map)