# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation of the Flatten operator.
"""

from hyper_parallel.core.shard.ops.parallel_reshape import ReshapeDistributedOp


class FlattenDistributedOp(ReshapeDistributedOp):
    """Distributed implementation for torch.flatten."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.flatten.

        PyTorch semantics:
            - flatten(input, start_dim=0, end_dim=-1)
            - Flattens the dimensions of `input` from `start_dim` through
              `end_dim` into a single dimension.

        Args:
            layouts (tuple): Layouts of inputs.
            extra_args (tuple): Contains the scalar arguments (start_dim, end_dim),
                with the list of input shapes appended as the last item by the
                WithShape suffix.

        Returns:
            Layout: Output tensor layout.
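
        Example (PyTorch semantics modeled here; shapes are illustrative):
            Flattening an input of shape (2, 3, 4) with start_dim=1 and
            end_dim=-1 produces an output of shape (2, 12).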

39 """ 

        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: flatten requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        # The WithShape suffix appends a list of input shapes as the last item in extra_args.
        input_shapes = extra_args[-1]
        input_shape = input_shapes[0]

        # PyTorch flatten defaults: start_dim=0, end_dim=-1
        start_dim = 0
        end_dim = -1

        # Parse the scalar args (start_dim, end_dim) from extra_args.
        num_scalar_args = len(extra_args) - 1
        if num_scalar_args > 0:
            start_dim = extra_args[0]
        if num_scalar_args > 1:
            end_dim = extra_args[1]
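        # For example (illustrative values): extra_args = (1, -1, [(2, 3, 4)])
        # parses to start_dim=1, end_dim=-1, with input_shape = (2, 3, 4).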

        ndim = len(input_shape)

        # Handle the 0-D tensor case explicitly to avoid indexing errors in the
        # parent class.
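        # For example (illustrative): torch.flatten(torch.tensor(5.0)) returns a
        # 1-D tensor of shape (1,), so the output tensor map must cover one dim.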
        if ndim == 0:
            out_layout = layouts[0].__class__.from_device_mesh(input_layout.mesh)
            out_layout.set_placements(input_layout.placements)
            # A flattened 0-D tensor becomes a 1-D tensor of shape (1,).
            out_layout.placement_to_tensor_map(1)
            return out_layout

        # Handle negative dimensions.
        if start_dim < 0:
            start_dim += ndim
        if end_dim < 0:
            end_dim += ndim
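        # For example (illustrative): with ndim=3, end_dim=-1 normalizes to end_dim=2.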

        # Validate dimension bounds.
        if start_dim < 0 or start_dim >= ndim or end_dim < 0 or end_dim >= ndim:
            raise ValueError(
                f"Dimension out of range for flatten: start_dim={start_dim}, end_dim={end_dim}"
            )

        # When start_dim == end_dim (e.g., flatten(1, 1)), the flattened range is a
        # single dimension, so dst_shape would equal input_shape and flatten is an
        # identity operation. (PyTorch itself rejects start_dim > end_dim, so a valid
        # call with start_dim >= end_dim can only be the identity case.) The
        # `_merge_unshared_axis` logic in ReshapeDistributedOp is not designed for an
        # identity mapping and introduces inconsistencies, so for an identity flatten
        # simply return the original layout.
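        # For example (illustrative): flatten(x, 1, 1) on a tensor of shape
        # (2, 3, 4) leaves both the shape and the layout unchanged.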
        if start_dim >= end_dim:
            return input_layout

        # Compute the target flattened shape.
        flattened_size = 1
        for i in range(start_dim, end_dim + 1):
            flattened_size *= input_shape[i]
        dst_shape = list(input_shape[:start_dim]) + [flattened_size] + list(input_shape[end_dim + 1:])
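        # For example (illustrative): input_shape = (2, 3, 4, 5) with start_dim=1,
        # end_dim=2 gives flattened_size = 12 and dst_shape = [2, 12, 5].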

        # Prepare args for ReshapeDistributedOp.infer_layout.
        # Since self.op_name is 'flatten' (not in ["reshape", "view"]), the parent
        # class will treat this as the MindSpore Reshape branch, which expects:
        # extra_args = [dst_shape, input_shape]
        mock_extra_args = [dst_shape, input_shape]

        # Reuse the parent class logic to safely resolve Shard tensor-map merges.
        out_layout, _ = super().infer_layout(layouts, mock_extra_args)

        # Return only the Layout object, which is what the WithShape suffix expects.
        return out_layout