Coverage for hyper_parallel / core / shard / local_func.py: 83%

48 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""shard""" 

16 

17import queue 

18from typing import Callable, Tuple, Optional 

19from hyper_parallel.core.layout import DeviceMesh 

20from hyper_parallel.core.dtensor import DTensor 

21from hyper_parallel.core.placement_types import Placement 

22from hyper_parallel.platform import get_platform 

# get_platform() is called once, at import time; the returned object is kept
# as a module-level global.
platform = get_platform()

# Alias the platform's tensor class. custom_shard() below uses it in an
# isinstance() check to decide which outputs get wrapped back into DTensors.
Tensor = platform.Tensor

25 

def custom_shard(
    func: Callable,
    device_mesh: DeviceMesh,
    out_placements: Tuple[Tuple[Placement, ...], ...],
    in_placements: Optional[Tuple[Optional[Tuple[Placement, ...]], ...]] = None,
    redistribute_inputs: bool = True,
) -> Callable:
    """
    Wrap ``func`` so it runs on local tensors while callers pass and receive DTensors.

    The returned callable converts every positional ``DTensor`` argument to its
    local tensor (optionally redistributing it first), calls ``func`` with the
    local values, and wraps each tensor output back into a ``DTensor`` using
    ``out_placements``. Keyword arguments are forwarded to ``func`` unchanged.
    If no positional argument is a ``DTensor``, the result of ``func`` is
    returned as-is and ``out_placements`` is not consulted.

    Args:
        func (Callable): The function to be wrapped.
        device_mesh (DeviceMesh): The device mesh for sharding.
        out_placements (Tuple[Tuple[Placement, ...], ...]): Placements for each
            output, in order. An entry must be None for a non-tensor output and
            non-None for a tensor output.
        in_placements (Optional[Tuple[Optional[Tuple[Placement, ...]], ...]], optional):
            Placements for each positional input argument. None entries indicate
            non-DTensor inputs. When given, its length must equal the number of
            positional arguments.
        redistribute_inputs (bool): Whether to redistribute DTensor inputs to the
            required placements before taking their local tensors.

    Returns:
        Callable: Wrapped function that handles distributed tensors. It returns
        a single value when ``func`` returns a non-tuple, otherwise a tuple.

    Raises:
        AssertionError: If ``in_placements`` length differs from the number of
            positional arguments, or the number of outputs differs from the
            length of ``out_placements``.
        RuntimeError: If a DTensor input is found but ``in_placements`` is None.
        TypeError: If a placement entry is None for a tensor (input or output),
            or non-None for a non-tensor (input or output).

    Examples:
        >>> mesh = DeviceMesh("npu", (2, 2), mesh_dim_names=("dp", "tp"))
        >>> def my_func(x, y):
        ...     return x + y
        >>> sharded_func = custom_shard(
        ...     my_func,
        ...     device_mesh=mesh,
        ...     out_placements=((Shard(0), Replicate()),),
        ...     in_placements=((Shard(0), Replicate()), (Replicate(), Shard(1))),
        ... )
    """
    # Imported at function scope so this block does not depend on a change to
    # the module's import section.
    import functools

    # Preserve func's name/docstring on the wrapper for debugging and introspection.
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would turn a length mismatch into an IndexError
        # (or a silent pass) further down. The exception type is unchanged.
        if in_placements is not None and len(in_placements) != len(args):
            raise AssertionError(
                f"in_placements length {len(in_placements)} does not match "
                f"the number of input args {len(args)}!"
            )

        local_args = []
        contain_distributed_arg = False

        for i, arg in enumerate(args):
            if isinstance(arg, DTensor):
                if in_placements is None:
                    raise RuntimeError("Found Tensor input but in_placements is None")

                required_in_placement = in_placements[i]
                if required_in_placement is None:
                    raise TypeError(
                        f"Tensor input at position {i} requires Placement, "
                        "but corresponding in_placements entry is None!"
                    )

                if redistribute_inputs:
                    arg = arg.redistribute(device_mesh, required_in_placement)

                local_args.append(arg.to_local())
                contain_distributed_arg = True
            else:
                if in_placements is not None and in_placements[i] is not None:
                    raise TypeError(
                        f"Non-DTensor input at position {i} requires None in_placements, "
                        f"but received {in_placements[i]}!"
                    )
                local_args.append(arg)

        out = func(*local_args, **kwargs)

        # Nothing distributed went in, so hand the raw result back untouched.
        if not contain_distributed_arg:
            return out

        # Normalize to a tuple so single and multiple outputs share one code path.
        out_is_tuple = isinstance(out, tuple)
        out_tuple = out if out_is_tuple else (out,)

        # Explicit raise for the same reason as above: zip() below would
        # otherwise silently truncate on a count mismatch under `python -O`.
        if len(out_tuple) != len(out_placements):
            raise AssertionError(
                f"Output count {len(out_tuple)} does not match "
                f"out_placements count {len(out_placements)}!"
            )

        dist_output = []
        for item, out_placement in zip(out_tuple, out_placements):
            if isinstance(item, Tensor):
                if out_placement is None:
                    raise TypeError(
                        "Tensor output requires non-None out_placements!"
                    )
                dist_output.append(
                    DTensor.from_local(item, device_mesh=device_mesh, placements=out_placement)
                )
            else:
                if out_placement is not None:
                    raise TypeError(
                        f"Non-tensor output requires None out_placements, got {out_placement}!"
                    )
                dist_output.append(item)

        # Mirror the shape of func's return value: scalar in, scalar out.
        return tuple(dist_output) if out_is_tuple else dist_output[0]

    return wrapped