Coverage for hyper_parallel / core / layout.py: 91%

336 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""layout""" 

16 

17import copy 

18import functools 

19import numpy as np 

20 

21from hyper_parallel.core.placement_types import Placement, Shard, Replicate, Partial 

22from hyper_parallel.core.device_mesh import DeviceMesh, _create_device_mesh 

23from hyper_parallel.platform import get_platform 

24 

# Module-level platform handle; used by Layout.repeat_num to query the world size.
platform = get_platform()

26 

27 

28def _infer_slice_area_by_rank(mesh_shape, tensor_map, rank_id: int, full_shape: tuple): # -> tuple[tuple[int]]: 

29 """Return the range of each axis from full tensor for slice in current rank.""" 

30 

31 def _get_dev_num_alone_dim(mesh_shape, dim): 

32 """_get_dev_num_alone_dim.""" 

33 return mesh_shape[-dim - 1] if dim != -1 else 1 

34 

35 def _rank_id_to_dev_id_list(mesh_shape, rank_id): 

36 """Infer dev id list by rank_id and mesh_shape""" 

37 dims = len(mesh_shape) 

38 dev_id_list = [0] * dims 

39 for i in range(dims - 1, -1, -1): 

40 dev_id_list[i] = rank_id % mesh_shape[i] 

41 rank_id = rank_id // mesh_shape[i] 

42 return dev_id_list 

43 

44 dev_id_list = _rank_id_to_dev_id_list(mesh_shape, rank_id) 

45 

46 dims = len(full_shape) 

47 area = [] 

48 for axis in range(dims): 

49 mapping = tensor_map[axis] 

50 if isinstance(mapping, int): 

51 mapping = (mapping,) 

52 split_num = 1 

53 for dim in mapping: 

54 split_num *= _get_dev_num_alone_dim(mesh_shape, dim) 

55 

56 slice_id = 0 

57 coef = 1 

58 for dim in reversed(mapping): 

59 if dim == -1: 

60 continue 

61 slice_id += dev_id_list[-dim - 1] * coef 

62 coef *= _get_dev_num_alone_dim(mesh_shape, dim) 

63 slice_size = full_shape[axis] // split_num 

64 start = slice_id * slice_size 

65 end = start + slice_size 

66 area.append((start, end)) 

67 return area 

68 

69 

def _get_slice_tensor_by_layout(global_tensor, layout):
    """Extract this rank's local shard of ``global_tensor`` according to ``layout``."""
    # Position of the current rank inside the layout's rank list.
    local_rank = layout.rank_list.index(layout.mesh.rank)
    offsets = _infer_slice_area_by_rank(
        layout.mesh_shape, layout.tensor_map, local_rank, global_tensor.shape)

    # Turn the per-axis (begin, end) pairs into slice objects and index once.
    index = tuple(slice(begin, end) for begin, end in offsets)
    return global_tensor[index]

83 

84 

85def _infer_slice_shape_by_layout(global_shape, layout): 

86 """Infer slice shape from global_shape and layout""" 

87 slice_shape = list(global_shape) 

88 alias_tensor_map = layout.alias_tensor_map 

89 for i in range(len(global_shape)): 

90 axis_name = alias_tensor_map[i] 

91 if isinstance(axis_name, str): 

92 axis_name = (axis_name,) 

93 for sub_axis_name in axis_name: 

94 if sub_axis_name != "None": 

95 slice_shape[i] = slice_shape[i] // layout.mesh.get_device_num_along_axis(sub_axis_name) 

96 return slice_shape 

97 

98 

class Layout:
    """
    Topological abstraction describing cluster devices for tensor slice placement on the cluster.

    Note:
        - It is valid only in semi auto parallel or auto parallel mode.
        - The multiplication result of the `mesh_shape` must be equal to the device count in a pipeline stage.
        - When the layout function is invoked to constructs a sharding strategy, each alias name is only allowed to be
          used once to shard a tensor.

    Args:
        mesh_shape (tuple): Describe the shape of devices arrangement, its element type is int.
        alias_name (tuple): The alias name for each axis of mesh_shape; its length should equal the length of
            `mesh_shape` and its element type is string.
            When using "interleaved_parallel" as an alias name, the tensor would be split into multiple
            copies on the corresponding partition dimension on a single card.
        rank_list (tuple, optional): Data is allocated to the device according to rank_list. Default: ``None``.

    Raises:
        TypeError: `mesh_shape` is not a tuple type.
        TypeError: `alias_name` is not a tuple type.
        TypeError: 'rank_list' is not a list type.
        ValueError: `mesh_shape` length is not equal to `alias_name` length.
        TypeError: The element of `mesh_shape` is not int type.
        TypeError: The element of `alias_name` is not a str type.
        TypeError: The element of `rank_list` is not int type.
        ValueError: The element of `alias_name` is an empty str.
        ValueError: The element of `alias_name` is "None".
        ValueError: `alias_name` contains repeated element.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.parallel import Layout
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
        >>> layout0 = layout("dp", "mp")
        >>> print(layout0.to_dict())
        {"mesh_shape": (2, 2, 2), "tensor_map": (2, 0), "interleaved_parallel": False,
        'alias_name': {'dp', 'sp', 'mp'}, "rank_list": [0, 1, 2, 3, 4, 5, 6, 7]}
        >>> layout = Layout((2, 2, 2), ("dp", "sp", "interleaved_parallel"))
        >>> layout1 = layout(("dp", "interleaved_parallel"), "sp")
    """

    def __init__(self, mesh_shape, alias_name, rank_list=None, init_backend=True):
        self._alias_name = alias_name
        self._tensor_map = None
        # An empty or None rank_list defaults to the natural ordering 0..prod(mesh_shape)-1.
        if not rank_list:
            self._rank_list = tuple(range(np.prod(np.array(mesh_shape))))
        else:
            self._rank_list = tuple(rank_list)
        self._partial = [None] * len(mesh_shape)  # partial status for each dev dim
        self._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
        self._alias_tensor_map = None
        self._mesh = _create_device_mesh("npu", mesh_shape, mesh_dim_names=alias_name, rank_list=self._rank_list,
                                         init_backend=init_backend)
        self._compact_str = self._to_compact_string()
        self._placements = None

    @classmethod
    def from_device_mesh(cls, device_mesh: "DeviceMesh") -> 'Layout':
        """
        Create a Layout from an existing DeviceMesh.

        Args:
            device_mesh (DeviceMesh): The device mesh to create layout from.

        Returns:
            Layout: A new Layout instance initialized with the properties of the provided device mesh.

        Examples:
            >>> from hyper_parallel.core.layout import Layout, DeviceMesh
            >>> device_mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "mp"))
            >>> layout = Layout.from_device_mesh(device_mesh)
        """
        # Bypass __init__ so the already-built mesh is reused instead of creating a new one.
        obj = cls.__new__(cls)
        obj._mesh = device_mesh
        obj._alias_name = device_mesh.mesh_dim_names
        obj._rank_list = device_mesh.rank_list
        obj._tensor_map = None
        obj._partial = [None] * len(device_mesh.mesh_shape)
        obj._support_partial_op = ['sum', 'max', 'min', 'avg', 'prod', 'all', None]
        obj._alias_tensor_map = None
        obj._placements = None
        obj._compact_str = obj._to_compact_string()
        return obj

    def __call__(self, *alias_tensor_map):
        """Build a concrete sharding layout from Placement objects or axis alias names."""
        obj = copy.deepcopy(self)

        # Single list/tuple of Placement objects, e.g. layout([Shard(0), Replicate()]).
        if len(alias_tensor_map) == 1 and isinstance(alias_tensor_map[0], (list, tuple)):
            if len(alias_tensor_map[0]) > 0 and isinstance(alias_tensor_map[0][0], Placement):
                return self._process_placement_layout(obj, alias_tensor_map[0])

        # Placement objects passed positionally, e.g. layout(Shard(0), Replicate()).
        if len(alias_tensor_map) > 0 and isinstance(alias_tensor_map[0], Placement):
            return self._process_placement_layout(obj, alias_tensor_map)

        # Otherwise alias-name strings, e.g. layout("dp", "None").
        return self._process_alias_layout(obj, alias_tensor_map)

    def _process_placement_layout(self, obj, placements):
        """Process layout defined by Placement types."""
        obj.set_placements(placements)
        return copy.deepcopy(obj)

    def _process_alias_layout(self, obj, alias_tensor_map):
        """Process layout defined by alias strings."""

        def _axis_index(item, written):
            # Map one alias name to its tensor-map index, validating that it exists
            # and has not been used before (each alias may shard at most one axis).
            if item not in obj.alias_name:
                raise ValueError(f'The axis {item} is not found in {obj.alias_name}')
            if item in written:
                raise ValueError(f'The axis {item} has been set more than one in {obj.alias_name}')
            return len(obj.alias_name) - 1 - obj.alias_name.index(item)

        obj.set_alias_tensor_map(alias_tensor_map)
        tensor_map = ()
        writed_map = ()
        for ele in alias_tensor_map:
            if isinstance(ele, tuple):
                # A tensor axis sharded over several mesh axes at once.
                ele_map = ()
                for item in ele:
                    if item == "None":
                        ele_map += (-1,)
                        continue
                    ele_map += (_axis_index(item, writed_map),)
                    writed_map += (item,)
                tensor_map += (ele_map,)
                continue
            if ele == "None":
                tensor_map += (-1,)
                continue
            tensor_map += (_axis_index(ele, writed_map),)
            writed_map += (ele,)
        obj.set_tensor_map(tensor_map)
        obj.tensor_map_to_placement()
        obj.update_compact_str()
        return copy.deepcopy(obj)

    def to_dict(self):
        """
        Transform layout to a dictionary.

        Raises:
            ValueError: If the mesh shape or tensor map is not configured.
        """
        if self._mesh.mesh_shape is None:
            raise ValueError("The device_shape of layout is None")
        if self._tensor_map is None:
            raise ValueError("The tensor_map of layout is None")
        interleaved_parallel = "interleaved_parallel" in self._mesh.mesh_dim_names
        return {"mesh_shape": self._mesh.mesh_shape, "tensor_map": self._tensor_map,
                "interleaved_parallel": interleaved_parallel, "alias_name": self._mesh.mesh_dim_names,
                "rank_list": self._rank_list}

    def placement_to_tensor_map(self, dim):
        """
        Transform placement to tensor map.

        This method converts the `placements` configuration (consisting of Shard, Replicate, Partial)
        into a `tensor_map` representation used for distributed tensor operations.

        Args:
            dim (int): The dimension of the tensor. Must be a non-negative integer.

        Returns:
            tuple: A tuple representing the tensor map, where each element corresponds to a tensor dimension.
                A value of -1 indicates the dimension is not sharded, while other values indicate
                the mesh dimension index along which the tensor dimension is sharded.

        Raises:
            ValueError: If `dim` is negative.
            ValueError: If a shard dimension in `placements` is out of bounds for the given tensor dimension.
            ValueError: If a tensor dimension is sharded by multiple mesh axes.
        """
        if dim < 0:
            # dim == 0 (scalar tensor) is valid and handled below, so the check is "non-negative".
            raise ValueError(f"Tensor dimension must be non-negative, but got {dim}")
        if dim == 0:
            return self._handle_zero_dim_placement()

        dim_map = self._build_dim_map_from_placements(dim)
        tensor_map = self._convert_dim_map_to_tensor_map(dim_map)
        self.set_tensor_map(tuple(tensor_map))
        self._alias_tensor_map = self._build_readable_tensor_map()
        self.update_compact_str()
        return tensor_map

    def _handle_zero_dim_placement(self):
        """Handle the special case of zero-dimensional tensor."""
        self.set_tensor_map(())
        self._alias_tensor_map = ()
        # A scalar cannot be sharded; only Partial placements carry information.
        for mesh_idx, placement in enumerate(self.placements):
            if isinstance(placement, Partial):
                self._partial[mesh_idx] = self._extract_reduce_op(placement)
        return []

    def _build_dim_map_from_placements(self, dim):
        """Build dimension map (tensor dim -> mesh dim index, -1 if unsharded) from placements."""
        dim_map = [-1] * dim
        # NOTE(review): partial_ops appears unused elsewhere in this file; kept for compatibility.
        self.partial_ops = {}
        for mesh_idx, placement in enumerate(self.placements):
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                if shard_dim < -dim or shard_dim >= dim:
                    raise ValueError(f"Shard dimension {shard_dim} is out of bounds for tensor of dimension {dim}")
                if shard_dim < 0:
                    # Normalize negative tensor dims.
                    shard_dim += dim
                if dim_map[shard_dim] != -1:
                    raise ValueError(f"Dimension {shard_dim} has been sharded by Mesh axis {dim_map[shard_dim]}")
                dim_map[shard_dim] = mesh_idx
            elif isinstance(placement, Partial):
                self._partial[mesh_idx] = self._extract_reduce_op(placement)
        return dim_map

    def _extract_reduce_op(self, placement):
        """Extract the (lower-cased) reduce operation name from a Partial placement; defaults to 'sum'."""
        op_name = getattr(placement, "reduce_op", "sum")
        if isinstance(op_name, str):
            op_name = op_name.lower()
        return op_name

    def _convert_dim_map_to_tensor_map(self, dim_map):
        """Convert dimension map to tensor map format (mesh index counted from the last mesh dim)."""
        device_dim_count = len(self.mesh_shape)
        return [
            device_dim_count - 1 - mesh_idx if mesh_idx != -1 else -1
            for mesh_idx in dim_map
        ]

    def _build_readable_tensor_map(self):
        """Build human-readable alias tensor map from tensor_map."""
        readable_map = []
        for item in self._tensor_map:
            if self._mesh.mesh_dim_names is None:
                readable_map.append("None")
            elif isinstance(item, tuple):
                # Handle a tensor axis sharded over several mesh axes.
                mapped_tuple = tuple(
                    self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - dim] if dim != -1 else "None"
                    for dim in item
                )
                readable_map.append(mapped_tuple)
            else:
                readable_map.append(
                    self._mesh.mesh_dim_names[len(self._mesh.mesh_dim_names) - 1 - item] if item != -1 else "None"
                )
        return tuple(readable_map)

    def tensor_map_to_placement(self):
        """
        Transform tensor map to placement.

        This method converts the existing `tensor_map` and `partial` status into a list of `Placement` objects
        (Shard, Replicate, Partial). This is the inverse operation of `placement_to_tensor_map`.

        Returns:
            list[Placement]: A list of Placement objects describing the distribution strategy for each
                dimension of the device mesh.

        Raises:
            ValueError: If `tensor_map` is not configured (None).
        """
        if self._tensor_map is None:
            raise ValueError("The tensor_map is None, cannot transform to placements.")
        mesh_ndim = len(self.mesh_shape)
        placements = [Replicate()] * mesh_ndim
        for tensor_dim, mapping in enumerate(self._tensor_map):
            mapping_list = mapping if isinstance(mapping, tuple) else (mapping,)
            for map_val in mapping_list:
                if map_val != -1:
                    # tensor_map counts mesh dims from the last one; convert back.
                    root_mesh_idx = mesh_ndim - 1 - map_val
                    placements[root_mesh_idx] = Shard(dim=tensor_dim)
        # Partial status overrides any other placement on the same mesh dim.
        for mesh_idx, op in enumerate(self.partial):
            if op is not None:
                placements[mesh_idx] = Partial(reduce_op=op)
        self.set_placements(placements)
        return placements

    def __setstate__(self, state):
        # After unpickling, the device mesh must be rebuilt on this process.
        self.__dict__.update(state)
        self.update_mesh()

    @property
    def mesh(self):
        return self._mesh

    def update_mesh(self):
        """Recreate the underlying device mesh from the current shape/alias/rank settings."""
        self._mesh = _create_device_mesh("npu", self.mesh_shape, mesh_dim_names=self.alias_name,
                                         rank_list=self.rank_list)

    @property
    def rank_list(self):
        """rank list"""
        return self._rank_list

    @rank_list.setter
    def rank_list(self, val):
        self._rank_list = val

    @property
    def mesh_shape(self):
        """mesh shape"""
        return self._mesh.mesh_shape

    @property
    def alias_name(self):
        """alias name"""
        return self._mesh.mesh_dim_names

    @property
    def alias_tensor_map(self):
        return self._alias_tensor_map

    def set_alias_tensor_map(self, alias_tensor_map):
        """Set alias_tensor_map"""
        self._alias_tensor_map = alias_tensor_map

    @property
    def placements(self):
        """placements"""
        return self._placements

    def set_placements(self, placements):
        """Set placements."""
        self._placements = placements

    @property
    def tensor_map(self):
        """tensor map"""
        return self._tensor_map

    def set_tensor_map(self, tensor_map):
        """Set tensor_map."""
        self._tensor_map = tensor_map

    @property
    def partial(self):
        """partial status"""
        return self._partial

    def set_partial_by_dev_axis(self, axis, op):
        """Set the partial status for the specified dev axis, meaning a reduce by `op` is pending.

        Raises:
            ValueError: If `op` is unsupported or the axis is already used for sharding.
        """
        if op not in self._support_partial_op:
            raise ValueError(f"Partial op must be one of {self._support_partial_op}, but got {op}")
        if self.is_dev_axis_apply_shard(axis):
            raise ValueError("Partial dim must be replicate.")
        self._partial[self._mesh.axis_index(axis)] = op
        self.tensor_map_to_placement()
        self.update_compact_str()

    def get_partial_by_dev_id(self, axis):
        """Get the partial status for the specified dev axis."""
        return self.partial[self._mesh.axis_index(axis)]

    def is_dev_axis_apply_shard(self, axis):
        """Return true if device axis is applying shard"""
        axis_id = self._mesh.axis_id(axis)

        def flatten(input_x):
            # Flatten nested tuples in the tensor map into a single list of ids.
            flatten_res = []
            for item in input_x:
                if isinstance(item, tuple):
                    flatten_res.extend(flatten(item))
                else:
                    flatten_res.append(item)
            return flatten_res

        flatten_tensor_map = flatten(self.tensor_map)
        return axis_id in flatten_tensor_map

    def get_dev_axis_apply_shard_axis(self, axis):
        """Return the tensor axis which is split by `axis`. If `axis` does not shard anything, return None."""
        for dim, dim_map in enumerate(self.alias_tensor_map):
            if (isinstance(dim_map, tuple) and axis in dim_map) or axis == dim_map:
                return dim
        return None

    def reset_partial(self):
        """Clear all partial status and refresh placements and the compact key."""
        self._partial = [None] * len(self.mesh_shape)
        self.tensor_map_to_placement()
        self.update_compact_str()

    def is_partial(self):
        """Return true if any dim in mesh_shape is partial"""
        return any(self.partial)

    def get_global_shape(self, slice_shape):
        """get global shape"""
        return self._mesh.get_global_shape(slice_shape, self._tensor_map)

    def get_devices_for_axis(self, axis, rank):
        """
        Get the repeat rank list when the axis is not shard.

        Args:
            axis (str): Axis name.
            rank (int): Global rank

        Returns:
            list: reduce rank list
        """
        return self._mesh.get_devices_for_axis(axis, rank)

    def get_comm_group_by_axis(self, axis):
        """Return the communication group of the given mesh axis."""
        return self._mesh.get_comm_group_by_axis(axis)

    def repeat_num(self):
        """
        Number of repeated placements.
        In pipeline parallel, only the last stage return repeat num, other stages return -1.
        For example:
            layout = Layout((2, 4), ("dp", "mp"))
            x_layout = layout("dp", "None")
            The repeat_num is equal to all device num 8 divided by device num corresponding to used axis 2, that is 4.
        """
        if self._tensor_map is None:
            raise ValueError(f"The tensor_map is None, the mesh_shape is {self._mesh.mesh_shape},"
                             f" alias_name is {self._mesh.mesh_dim_names}")

        # if it is not the last stage, return -1
        group_size = platform.get_world_size()
        if self._rank_list[-1] != (group_size - 1):
            return -1

        all_device_num = functools.reduce(lambda x, y: x * y, self._mesh.mesh_shape)
        # Multiply the device counts of every mesh dim actually used for sharding.
        used_dev_num = 1
        for ele in self._tensor_map:
            if isinstance(ele, tuple):
                for item in ele:
                    if item >= 0:
                        used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - item - 1]
                continue
            if ele >= 0:
                used_dev_num *= self._mesh.mesh_shape[len(self._mesh.mesh_shape) - ele - 1]

        return all_device_num // used_dev_num

    def _to_compact_string(self):
        """
        generate dict key

        Returns:
            str: string for compact
        """
        mesh_key = self._mesh.to_hash()
        hash_key = (self._tensor_map, self.partial)
        hash_key += mesh_key
        return str(hash_key)

    @property
    def compact_str(self):
        return self._compact_str

    def update_compact_str(self):
        """Refresh the cached compact key after tensor_map/partial changes."""
        self._compact_str = self._to_compact_string()

    def to_string(self):
        """
        layout dump

        Returns:
            str: layout string
        """
        device_info = f"Mesh shape: {self._mesh.mesh_shape}"
        alias_info = f"Alias Names: {self._mesh.mesh_dim_names}"
        rank_info = f"Rank List: {self._rank_list}"
        partial_info = f"Partial: {self.partial}"

        if self._tensor_map is None:
            tensor_info = "Tensor Map: Not configured"
        else:
            # Reuse the shared rendering helper so the dump stays consistent
            # with _build_readable_tensor_map.
            tensor_info = f"Tensor Map: {self._build_readable_tensor_map()}"

        interleaved = "Yes" if "interleaved_parallel" in self._mesh.mesh_dim_names else "No"
        interleaved_info = f"Interleaved Parallel: {interleaved}"

        return (
            f"Layout Configuration:\n"
            f" {device_info}\n"
            f" {alias_info}\n"
            f" {partial_info}\n"
            f" {tensor_info}\n"
            f" {interleaved_info}\n"
            f" {rank_info}"
        )

    def __str__(self):
        """__str__"""
        return self.to_string()

    def __repr__(self):
        """__repr__"""
        return f"<Layout at {hex(id(self))}>"

    def __eq__(self, other):
        """
        Two layouts are equal when mesh shape, alias names, partial status, rank list
        and tensor map all match.

        NOTE(review): defining __eq__ without __hash__ makes Layout unhashable; kept as-is.
        """
        if not isinstance(other, Layout):
            return False

        if (self.mesh_shape != other.mesh_shape or
                self.alias_name != other.alias_name or
                self.partial != other.partial or
                self.rank_list != other.rank_list):
            return False

        # None tensor maps compare equal only to each other.
        if self._tensor_map is None or other.tensor_map is None:
            return self._tensor_map is other.tensor_map
        return self._tensor_map == other.tensor_map