Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_expand.py: 92%

88 statements

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation for Expand operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


class ExpandDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.Tensor.expand.

        PyTorch semantics:
        - Expands singleton dimensions (size 1) to larger sizes.
        - Passing -1 preserves the original size of that dimension.
        - Only dimensions with global size 1 can be expanded.
        - Existing dimensions being expanded MUST be unsharded.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple of int): The desired expanded sizes, one int per
                output dimension (required): an explicit size, or -1 to keep
                the size of an existing dimension.

        Returns:
            Layout: Output tensor layout with:
                - New dimensions: unsharded (-1)
                - Expanded existing dimensions: unsharded (-1)
                - Preserved dimensions (-1 in sizes): original sharding preserved
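
        Example (illustrative; the mesh alias "tp" is hypothetical):
            An input of global shape (1, H) with layout map ("None", "tp"),
            expanded with sizes (4, 8, -1):
            - a new leading dimension of size 4 is prepended (unsharded),
            - dimension 0 expands 1 → 8 and must be (and is) unsharded,
            - dimension 1 passes -1, so its "tp" sharding is preserved,
            giving the output map ("None", "None", "tp").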

49 """ 

        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: expand requires a valid input tensor layout."
            )
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        if not extra_args:
            raise ValueError(
                f"Operation {self.op_name}: expand requires 'sizes' parameter in extra_args."
            )
        output_ndim = len(extra_args)

        # Normalize sizes to a tuple, validating element types
        for size in extra_args:
            if not isinstance(size, int):
                raise ValueError(
                    f"Operation {self.op_name}: elements in 'sizes' parameter must be int."
                )
        sizes = tuple(extra_args)

        num_new_dims = output_ndim - input_ndim

        # PyTorch only allows prepending new dimensions (not inserting in the middle)
        if num_new_dims < 0:
            raise ValueError(
                f"Operation {self.op_name}: Cannot reduce dimensions with expand. "
                f"Input has {input_ndim} dims, requested {output_ndim} dims."
            )

        # Build the output tensor map
        output_map = []

        # Rule 1: New (prepended) dimensions must be given an explicit size;
        # -1 is only valid for existing dimensions.
        for i in range(num_new_dims):
            if sizes[i] == -1:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot use -1 for new dimension at position {i}."
                )
            output_map.append(-1)  # New dimensions are always unsharded

        # Rule 2: Process existing dimensions
        for i in range(input_ndim):
            output_dim_idx = num_new_dims + i
            requested_size = sizes[output_dim_idx]

            if requested_size == -1:
                # -1 keeps the original size, so the original sharding is kept too
                output_map.append(in_tensor_map[i])
            else:
                # A sharded dimension cannot be expanded
                if in_tensor_map[i] != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot expand dimension {i} which is sharded."
                    )
                # Expanded dimensions become unsharded in the output
                output_map.append(-1)

        # Construct the output layout with the same mesh configuration
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        # Convert tensor_map indices to alias strings for the Layout constructor
        # (index 0 maps to the last alias, hence the reversed lookup)
        def idx_to_alias(idx, aliases):
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_alias_map)
        return output_layout


class ExpandAsDistributedOp(DistributedOp):
    """Distributed implementation for torch.Tensor.expand_as."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.Tensor.expand_as.

        PyTorch semantics:
        - Only dimensions with global size == 1 can be expanded to larger sizes.
        - Dimensions with size > 1 must match exactly between input and target.
        - Broadcast replicates a single value across the expanded dimension.

        Critical sharding constraints:
        - When expanding a dimension (size 1 → N), that dimension must be
          unsharded (-1) in the input layout.
        - Expanded dimensions become unsharded in the output.
        - Non-expanded dimensions preserve their input sharding pattern.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
                layouts[1] (Layout): Target tensor layout (unused; the target
                    shape is taken from extra_args).
            extra_args (tuple): Must contain shape information. Expected:
                extra_args[0][0] (tuple of int): Input global shape.
                extra_args[0][1] (tuple of int): Target global shape.

        Returns:
            Layout: Output tensor layout with sharding preserved for non-expanded
                dimensions and unsharded for expanded dimensions.
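
        Example (illustrative; the mesh alias "dp" is hypothetical):
            An input of global shape (4, 1) with layout map ("dp", "None"),
            expanded as a target of shape (2, 4, 3):
            - a new leading dimension of size 2 is prepended (unsharded),
            - dimension 0 (size 4) matches the target, keeping its "dp" shard,
            - dimension 1 broadcasts 1 → 3 and must be (and is) unsharded,
            giving the output map ("None", "dp", "None").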

162 """ 

        # Validate input layout
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: expand_as requires a valid input tensor layout."
            )
        input_layout = layouts[0]
        in_tensor_map = input_layout.tensor_map
        input_ndim = len(in_tensor_map)

        # Extract shape information from extra_args
        if not extra_args or extra_args[0] is None or len(extra_args[0]) < 2:
            raise ValueError(
                f"Operation {self.op_name}: expand_as requires (input_global_shape, target_shape) "
                f"in extra_args."
            )
        input_global_shape = extra_args[0][0]
        target_shape = extra_args[0][1]

        if not isinstance(target_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: target_shape must be tuple/list, got {type(target_shape)}."
            )
        if not isinstance(input_global_shape, (tuple, list)):
            raise ValueError(
                f"Operation {self.op_name}: input_global_shape must be tuple/list, got {type(input_global_shape)}."
            )

        target_shape = tuple(target_shape)
        input_global_shape = tuple(input_global_shape)
        target_ndim = len(target_shape)

        # PyTorch rule: target rank cannot be smaller than input rank
        if target_ndim < input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: target shape {target_shape} (ndim={target_ndim}) cannot be "
                f"smaller than input shape {input_global_shape} (ndim={input_ndim})."
            )

        # Align input to target by prepending implicit size-1 dimensions
        num_leading_implicit = target_ndim - input_ndim
        aligned_input_shape = (1,) * num_leading_implicit + input_global_shape
        aligned_tensor_map = (-1,) * num_leading_implicit + in_tensor_map
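        # For example (illustrative): an input of global shape (3, 1) against a
        # target of shape (2, 3, 5) aligns to aligned_input_shape == (1, 3, 1)
        # with -1 prepended to the tensor map, so only the new leading dimension
        # and the trailing size-1 dimension are treated as broadcasts below.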

        # Validate expansion rules and build the output tensor_map
        output_tensor_map = []
        for i, (in_size, tgt_size, shard_spec) in enumerate(
            zip(aligned_input_shape, target_shape, aligned_tensor_map)
        ):
            if in_size == tgt_size:
                # Dimension unchanged - preserve its sharding pattern
                output_tensor_map.append(shard_spec)
            elif in_size == 1 and tgt_size > 1:
                # Dimension is expanded (broadcast) - must be unsharded
                if shard_spec != -1:
                    raise ValueError(
                        f"Operation {self.op_name}: Cannot broadcast sharded dimension {i} "
                        f"(global size 1 → {tgt_size}); it must be unsharded in the input layout."
                    )
                output_tensor_map.append(-1)
            else:
                raise ValueError(
                    f"Operation {self.op_name}: Cannot expand dimension {i} from size {in_size} "
                    f"to {tgt_size}."
                )

        # Construct output layout with the same mesh configuration
        mesh_shape = input_layout.mesh_shape
        alias_name = input_layout.alias_name
        rank_list = input_layout.rank_list

        # Convert tensor_map indices to alias strings for the Layout constructor
        def idx_to_alias(idx, aliases):
            if idx == -1:
                return "None"
            return aliases[len(aliases) - idx - 1]

        output_map = tuple(idx_to_alias(idx, alias_name) for idx in output_tensor_map)

        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )
        output_layout = output_layout(*output_map)
        return output_layout
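

# A minimal usage sketch (illustrative only; the mesh shape, aliases, rank list,
# and the DistributedOp construction arguments are hypothetical):
#
#     layout = Layout(mesh_shape=(2, 4), alias_name=("dp", "tp"), rank_list=list(range(8)))
#     in_layout = layout("None", "tp")  # e.g. a tensor of global shape (1, H)
#     op = ExpandDistributedOp(...)     # constructor arguments defined by DistributedOp
#     out_layout = op.infer_layout((in_layout,), extra_args=(8, -1))
#     # dim 0 expands 1 -> 8 (it is unsharded, as required); dim 1 keeps "tp".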