Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/shard/ops/parallel_stack.py: 98%

45 statements  

coverage.py v7.13.1, created at 2026-05-11 07:26 +0800

# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Distributed implementation of the Stack operator.
"""

from hyper_parallel.core.dtensor.layout import Layout
from .parallel_ops import DistributedOp


# pylint: disable=unused-argument
def _normalize_stack_args(tensors, dim=0, **kwargs):
    """
    Normalize arguments for torch.stack.
    """
    return (tensors,), {'dim': dim}
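# Example (illustrative): _normalize_stack_args((t0, t1), dim=1) returns
# ((t0, t1),), {'dim': 1}, mirroring the torch.stack(tensors, dim) signature.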

28 

29 

30class StackDistributedOp(DistributedOp): 

31 """Distributed implementation for Stack operator.""" 

32 

33 def preprocess(self, args, kwargs): 

34 """ 

35 Preprocess input arguments and extract local components. 

36 

37 Normalizes args, explicitly extracts parameters, and prepares 

38 local tensors and cache values without validation logic. 

39 """ 

40 args, kwargs = _normalize_stack_args(*args, **kwargs) 

41 

42 # Explicit parameter extraction 

43 tensors = args[0] 

44 dim = kwargs['dim'] 

45 

46 # Extract local tensors and layouts 

47 local_tensors = tuple(t.to_local() if hasattr(t, "to_local") else t for t in tensors) 

48 layouts = [getattr(t, "layout", None) for t in tensors] 
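        # Plain (non-distributed) tensors have neither to_local() nor a
        # DTensor layout, so they pass through unchanged with a None layout.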

        # Construct local args and kwargs for the inner op execution
        local_args = (local_tensors,)
        local_kwargs = {'dim': dim}

        # Flatten layouts and append dim for caching and inference
        cache_values = layouts + [dim]

        return local_args, local_kwargs, cache_values

    def infer_layout(self, cache_values):
        """
        Infer the output layout from cache values for torch.stack.

        All validation logic (e.g., empty checks, layout consistency,
        and dimension bounds) is handled here.
        """
        layouts = cache_values[:-1]
        dim = cache_values[-1]

        # 1. Validation logic
        if not layouts:
            raise ValueError(f"Operation {self.op_name}: stack requires at least one input tensor.")

        valid_layouts = [lyt for lyt in layouts if lyt is not None]

        if not valid_layouts:
            raise ValueError(f"Operation {self.op_name}: stack requires at least one input DTensor.")

        # Reference layout to validate consistency across all input tensors
        base_layout = valid_layouts[0]
        for layout in valid_layouts[1:]:
            if layout != base_layout:
                raise ValueError(
                    f"Operation {self.op_name}: All input tensors must have the same layout. "
                    f"Expected layout: {base_layout}, mismatched layout: {layout}"
                )

        ndim = len(base_layout.tensor_map)

        # Normalize and validate the dimension. Because stack adds one output
        # dimension, the valid range is [-ndim - 1, ndim].
        actual_dim = dim if dim >= 0 else dim + ndim + 1
        if actual_dim < 0 or actual_dim > ndim:
            raise ValueError(
                f"Operation {self.op_name}: Dimension out of range (expected to be in range of "
                f"[{-ndim - 1}, {ndim}], but got {dim})"
            )
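        # Example: for rank-2 inputs (ndim = 2), dim = -1 normalizes to
        # actual_dim = -1 + 2 + 1 = 2, appending the new axis last, as
        # torch.stack does.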

        # 2. Layout inference logic
        in_tensor_map = base_layout.tensor_map

        # Insert an unsharded mapping (-1) at the newly created dimension
        output_tensor_map = list(in_tensor_map)
        output_tensor_map.insert(actual_dim, -1)

        mesh_shape = base_layout.mesh_shape
        alias_name = base_layout.alias_name
        rank_list = base_layout.rank_list

        def idx_to_alias(idx, aliases):
            """Map a tensor_map index back to its alias string."""
            if idx == -1:
                return "None"
            # Reverse indexing, following the framework's layout convention
            return aliases[len(aliases) - idx - 1]
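        # Example (hypothetical aliases): with aliases = ("dp", "mp"),
        # idx_to_alias(1, aliases) returns aliases[0] == "dp" and
        # idx_to_alias(0, aliases) returns aliases[1] == "mp".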

        output_alias_map = tuple(idx_to_alias(idx, alias_name) for idx in output_tensor_map)

        # Reconstruct the output layout
        output_layout = Layout(
            mesh_shape=mesh_shape,
            alias_name=alias_name,
            rank_list=rank_list
        )

        # Apply the placement mappings via the __call__ method
        output_layout = output_layout(*output_alias_map)

        # Returns a tuple of output layouts and None for extra_args
        return ((output_layout,), None)
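
# --- Illustrative sketch (not part of the library; names are assumptions) ---
# The inference rule above reduces to inserting -1 (unsharded) into the input
# tensor_map at the stack dimension, then naming each entry via reverse alias
# lookup. A standalone check of that arithmetic, using a hypothetical
# tensor_map and hypothetical aliases ("dp", "mp"):
if __name__ == "__main__":
    in_tensor_map = (1, -1)         # dim 0 sharded on mesh axis 1, dim 1 replicated
    actual_dim = 0                  # stacking in front
    out_map = list(in_tensor_map)
    out_map.insert(actual_dim, -1)  # the new stacked axis is never sharded
    assert tuple(out_map) == (-1, 1, -1)

    aliases = ("dp", "mp")
    # Reverse lookup as in idx_to_alias: -1 -> "None", 1 -> "dp", 0 -> "mp".
    names = tuple("None" if i == -1 else aliases[len(aliases) - i - 1] for i in out_map)
    assert names == ("None", "dp", "None")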