Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_flatten.py: 94%
36 statements
coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Distributed implementation for Flatten operator.
17"""
18from hyper_parallel.core.shard.ops.parallel_reshape import ReshapeDistributedOp
class FlattenDistributedOp(ReshapeDistributedOp):
    """Distributed implementation for torch.flatten."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Infer output layout for torch.flatten.

        PyTorch semantics:
        - flatten(input, start_dim=0, end_dim=-1)
        - Flattens the input tensor starting from `start_dim` to `end_dim`.

        Args:
            layouts (tuple): Layouts of inputs.
            extra_args (tuple): Contains the scalar arguments (start_dim, end_dim)
                and the list of input shapes appended by the WithShape suffix.

        Returns:
            Layout: Output tensor layout.
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: flatten requires a valid input tensor layout."
            )

        input_layout = layouts[0]
        # WithShape suffix appends a list of input_shapes as the last item in extra_args
        input_shapes = extra_args[-1]
        input_shape = input_shapes[0]

        # PyTorch flatten defaults: start_dim=0, end_dim=-1
        start_dim = 0
        end_dim = -1

        # Parse scalar args (start_dim, end_dim) from extra_args
        num_scalar_args = len(extra_args) - 1
        if num_scalar_args > 0:
            start_dim = extra_args[0]
        if num_scalar_args > 1:
            end_dim = extra_args[1]
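        # Illustrative example (shapes and containers hypothetical): for flatten(x, 1)
        # on an input of shape (2, 3, 4), extra_args is roughly (1, [(2, 3, 4)]), so
        # start_dim becomes 1 and end_dim stays -1; for flatten(x, 1, 2) it is roughly
        # (1, 2, [(2, 3, 4)]), giving start_dim=1 and end_dim=2.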

        ndim = len(input_shape)

        # Handle the 0-D tensor case specifically to avoid parent class indexing errors
        if ndim == 0:
            out_layout = layouts[0].__class__.from_device_mesh(input_layout.mesh)
            out_layout.set_placements(input_layout.placements)
            out_layout.placement_to_tensor_map(1)  # Flattened 0-D becomes 1-D of shape (1,)
            return out_layout

        # Handle negative dimensions
        if start_dim < 0:
            start_dim += ndim
        if end_dim < 0:
            end_dim += ndim

        # Validate dimension bounds
        if start_dim < 0 or start_dim >= ndim or end_dim < 0 or end_dim >= ndim:
            raise ValueError(f"Dimension out of range for flatten: start_dim={start_dim}, end_dim={end_dim}")

        # When start_dim == end_dim (e.g. flatten(1, 1)), the flatten is an identity:
        # the loop `range(start_dim, end_dim + 1)` below covers a single dimension and
        # `dst_shape` equals `input_shape`. The `_merge_unshared_axis` logic in
        # ReshapeDistributedOp is not designed for an identity mapping and introduces
        # inconsistencies, so return the original layout unchanged instead.
        # (PyTorch itself rejects start_dim > end_dim, so treating that case as an
        # identity here is harmless.)
        if start_dim >= end_dim:
            return input_layout
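
        # Identity example (hypothetical shapes): flatten(x, 1, 1) on a (2, 3, 4) input
        # keeps the shape (2, 3, 4), which is why the layout is returned as-is above.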

        # Calculate the target flattened shape
        flattened_size = 1
        for i in range(start_dim, end_dim + 1):
            flattened_size *= input_shape[i]
        dst_shape = list(input_shape[:start_dim]) + [flattened_size] + list(input_shape[end_dim + 1:])
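        # Illustrative (hypothetical shapes): input_shape (2, 3, 4, 5) with start_dim=1,
        # end_dim=2 gives flattened_size = 3 * 4 = 12 and dst_shape = [2, 12, 5];
        # the defaults (0, -1) give dst_shape = [120].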

        # Prepare args for ReshapeDistributedOp.infer_layout.
        # Since self.op_name is 'flatten' (not in ["reshape", "view"]),
        # the parent class will treat this as the MindSpore Reshape branch,
        # which expects: extra_args = [dst_shape, input_shape]
        mock_extra_args = [dst_shape, input_shape]
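        # Continuing the illustration above (shapes hypothetical): mock_extra_args would
        # be [[2, 12, 5], (2, 3, 4, 5)] for flatten(x, 1, 2) on a (2, 3, 4, 5) input.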

        # Reuse parent class logic to safely resolve Shard tensor mapping merges
        out_layout, _ = super().infer_layout(layouts, mock_extra_args)

        # Return only the Layout object, which is what the WithShape suffix expects
        return out_layout
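
# Illustrative sanity check (not part of this module): the dst_shape arithmetic above
# mirrors torch.flatten's shape semantics, e.g.
#     >>> import torch
#     >>> torch.flatten(torch.empty(2, 3, 4, 5), 1, 2).shape
#     torch.Size([2, 12, 5])
#     >>> torch.flatten(torch.empty(2, 3, 4, 5)).shape
#     torch.Size([120])
#     >>> torch.flatten(torch.tensor(1.0)).shape  # a 0-D input becomes 1-D of shape (1,)
#     torch.Size([1])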