Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_cumsum.py: 97%
30 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Distributed implementation for Cumsum operator.
17"""
19from hyper_parallel.core.dtensor.layout import Layout
20from .parallel_ops import DistributedOp
class CumsumDistributedOp(DistributedOp):
    """Layout inference for the distributed form of torch.cumsum."""

    def infer_layout(self, layouts, extra_args=None):
        """
        Validate the cumsum arguments and build the output layout.

        cumsum keeps the input shape and scans sequentially along `dim`, so
        that dimension has to be unsharded (-1 in the input tensor_map).
        Every other dimension keeps the sharding it has in the input.

        Args:
            layouts (tuple): Layouts of inputs. Expected:
                layouts[0] (Layout): Input tensor layout (required).
            extra_args (tuple): Should contain 'dim'. Expected:
                extra_args[0] (int): Dimension to scan over (required);
                    a negative value is counted from the last dimension.

        Returns:
            Layout: A new layout created from the input layout's mesh_shape,
                alias_name and rank_list, applied to the alias names that
                correspond to the input tensor_map.

        Raises:
            ValueError: If the input layout is missing, 'dim' is missing or
                not an int, 'dim' is out of range, or the cumsum dimension
                is sharded.
        """
        if not layouts or layouts[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: cumsum requires a valid input tensor layout."
            )
        src_layout = layouts[0]
        in_tensor_map = src_layout.tensor_map
        input_ndim = len(in_tensor_map)

        if not extra_args or extra_args[0] is None:
            raise ValueError(
                f"Operation {self.op_name}: cumsum requires 'dim' parameter in extra_args."
            )
        dim = extra_args[0]
        if not isinstance(dim, int):
            raise ValueError(
                f"Operation {self.op_name}: 'dim' must be an integer, got {type(dim)}."
            )

        # Fold a negative index onto the non-negative range before the bounds check.
        normalized_dim = dim
        if dim < 0:
            normalized_dim = dim + input_ndim
        if not 0 <= normalized_dim < input_ndim:
            raise ValueError(
                f"Operation {self.op_name}: Dimension {dim} out of range for "
                f"{input_ndim}-dimensional input tensor."
            )

        # The scan dimension cannot be split across devices.
        if in_tensor_map[normalized_dim] != -1:
            raise ValueError(
                f"Operation {self.op_name}: Cannot perform sharding on normalized dimension {normalized_dim}, "
                f"but found sharding assignment: {in_tensor_map[normalized_dim]}"
            )

        mesh = src_layout.mesh_shape
        aliases = src_layout.alias_name
        ranks = src_layout.rank_list

        # Translate each tensor_map entry back to its alias name: -1 becomes the
        # literal "None", any other index is looked up from the end of the aliases.
        alias_args = tuple(
            "None" if idx == -1 else aliases[len(aliases) - idx - 1]
            for idx in in_tensor_map
        )

        blank_layout = Layout(
            mesh_shape=mesh,
            alias_name=aliases,
            rank_list=ranks
        )
        return blank_layout(*alias_args)