# Copyright 2025-2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""shard""" 

import queue
from typing import Callable, Tuple, Optional

from hyper_parallel.core.dtensor.layout import DeviceMesh
from hyper_parallel.core.dtensor.dtensor import DTensor
from hyper_parallel.core.dtensor.placement_types import Placement
from hyper_parallel.platform import get_platform

platform = get_platform()
Tensor = platform.Tensor


def custom_shard(
    func: Callable,
    device_mesh: DeviceMesh,
    out_placements: Tuple[Tuple[Placement, ...], ...],
    in_placements: Optional[Tuple[Optional[Tuple[Placement, ...]], ...]] = None,
    redistribute_inputs: bool = True,
) -> Callable:
    """
    Wraps a function to handle distributed tensor conversions.

    Args:
        func (Callable): The function to be wrapped.
        device_mesh (DeviceMesh): The device mesh for sharding.
        out_placements (Tuple[Tuple[Placement, ...], ...]): Placements for each output tensor.
        in_placements (Optional[Tuple[Optional[Tuple[Placement, ...]], ...]], optional):
            Placements for each input argument. None entries indicate non-tensor inputs.
        redistribute_inputs (bool): Whether to redistribute inputs to required placements.

    Returns:
        Callable: Wrapped function that handles distributed tensors.

    Examples:
        >>> from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> def my_func(x, y):
        ...     return x + y
        >>> sharded_func = custom_shard(
        ...     my_func,
        ...     device_mesh=mesh,
        ...     out_placements=((Shard(0), Replicate()),),
        ...     in_placements=((Shard(0), Replicate()), (Replicate(), Shard(1))),
        ... )
    """
    def wrapped(*args, **kwargs):
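        # When in_placements is given, it must describe every positional argument.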
        if in_placements is not None:
            if len(in_placements) != len(args):
                raise ValueError(
                    f"in_placements length {len(in_placements)} does not match "
                    f"the number of input args {len(args)}!"
                )

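        # Collect the positional arguments in their local (per-rank) form.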
        local_args = []
        contain_distributed_arg = False

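        # Layouts of DTensor inputs are collected here; note that the queue is
        # not read again inside this wrapper.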
        args_layout = queue.Queue(len(args))
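        # Unwrap each DTensor argument to its local shard; non-tensor arguments
        # pass through unchanged.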
        for i, arg in enumerate(args):
            if isinstance(arg, DTensor):
                if in_placements is None:
                    raise RuntimeError("Found DTensor input but in_placements is None!")

                required_in_placement = in_placements[i]
                if required_in_placement is None:
                    raise TypeError(
                        f"DTensor input at position {i} requires a placement, "
                        "but the corresponding in_placements entry is None!"
                    )

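                # Optionally redistribute the input to the required placement
                # before unwrapping it.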
                if redistribute_inputs:
                    arg = arg.redistribute(device_mesh, required_in_placement)

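                # Record the layout and pass the local shard on to the function.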
                args_layout.put(arg.layout)
                local_tensor = arg.to_local()
                local_args.append(local_tensor)
                contain_distributed_arg = True

            else:
                if in_placements is not None and in_placements[i] is not None:
                    raise TypeError(
                        f"Non-DTensor input at position {i} requires a None "
                        f"in_placements entry, but received {in_placements[i]}!"
                    )
                local_args.append(arg)

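        # Invoke the original function on purely local tensors.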
        out = func(*local_args, **kwargs)

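        # No DTensor came in, so return the output as-is.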
        if not contain_distributed_arg:
            return out

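        # Normalize to a tuple so single and multiple outputs share one path.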
        out_is_tuple = isinstance(out, tuple)
        out_tuple = (out,) if not out_is_tuple else out

        if len(out_tuple) != len(out_placements):
            raise ValueError(
                f"Output count {len(out_tuple)} does not match "
                f"out_placements count {len(out_placements)}!"
            )

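        # Wrap each Tensor output back into a DTensor with its declared
        # placement; non-tensor outputs pass through unchanged.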
        dist_output = []
        for item, out_placement in zip(out_tuple, out_placements):
            if isinstance(item, Tensor):
                if out_placement is None:
                    raise TypeError(
                        "Tensor output requires a non-None out_placements entry!"
                    )
                dist_output.append(
                    DTensor.from_local(item, device_mesh=device_mesh, placements=out_placement)
                )
            else:
                if out_placement is not None:
                    raise TypeError(
                        f"Non-tensor output requires a None out_placements entry, "
                        f"got {out_placement}!"
                    )
                dist_output.append(item)

        return dist_output[0] if not out_is_tuple else tuple(dist_output)

    return wrapped
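

# A minimal usage sketch for mixing tensor and non-tensor arguments
# (illustrative only; it assumes `Shard` and `Replicate` are exported from
# hyper_parallel.core.dtensor.placement_types and that a 2x2 mesh is available):
#
#     from hyper_parallel.core.dtensor.placement_types import Shard, Replicate
#
#     mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
#
#     def scale(x, factor):
#         return x * factor
#
#     # `factor` is a plain Python number, so its in_placements entry is None.
#     sharded_scale = custom_shard(
#         scale,
#         device_mesh=mesh,
#         out_placements=((Shard(0), Replicate()),),
#         in_placements=((Shard(0), Replicate()), None),
#     )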