# Coverage report residue (coverage.py v7.13.1, created at 2026-03-01 07:33 +0800):
# hyper_parallel/core/hsdp/hsdp_utils.py — 23 statements, 100% covered.

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HSDP optimizer shared level"""
from enum import auto, Enum

class OptimizerLevel(Enum):
    """Granularity of sharding performed by the HSDP optimizer.

    Members (ordered from least to most aggressive sharding):
        SHARD_OPT: split only the optimizer state.
        SHARD_OPT_GRAD: split the optimizer state and the gradients.
        SHARD_OPT_GRAD_PARAM: split the optimizer state, the gradients
            and the weights.
    """
    # Explicit values match what auto() would assign (1, 2, 3).
    SHARD_OPT = 1
    SHARD_OPT_GRAD = 2
    SHARD_OPT_GRAD_PARAM = 3


class GroupInfo:
    """Descriptor for one communication group used by HSDP.

    Original docstring was a vacuous placeholder; documented attributes and
    added a ``__repr__`` for debuggability. Interface is unchanged.

    Attributes:
        group_name: name identifying the communication group.
        group: the communication group handle — presumably a framework
            group object; not constructed here, so confirm with callers.
        rank_size: number of ranks participating in the group.
    """

    def __init__(self, group_name, group, rank_size):
        """Store the group descriptor fields as plain attributes."""
        self.group_name = group_name
        self.group = group
        self.rank_size = rank_size

    def __repr__(self):
        # Debug-friendly representation; added, does not affect callers.
        return (f"{type(self).__name__}(group_name={self.group_name!r}, "
                f"group={self.group!r}, rank_size={self.rank_size!r})")


class HSDPConfig:
    """Container holding all HSDP runtime settings."""

    def __init__(self, shard_size, threshold, requires_acc_grad, grad_scale, shard_level, use_eager_hook,
                 reduce_dtype=None, comm_async=False, comm_fusion=False, bucket_size=-1):
        """Record the HSDP configuration.

        Args:
            shard_size: optimizer weight sharded size.
            threshold: minimum weight size to shard.
            requires_acc_grad: requires gradient accumulation.
            grad_scale: use grad_scale to scale grad.
            shard_level: optimizer shard level.
            use_eager_hook: use eager hook or graph hook to implement hsdp.
            reduce_dtype: set gradient reduce dtype.
            comm_async: use async communication op for grad reduction.
            comm_fusion: use communication op fusion to reduce the number of
                communication op.
            bucket_size: the size of comm fusion buffer.
        """
        # Sharding settings.
        self.shard_size = shard_size
        self.threshold = threshold
        self.shard_level = shard_level

        # Gradient handling.
        self.requires_acc_grad = requires_acc_grad
        self.grad_scale = grad_scale
        self.reduce_dtype = reduce_dtype

        # Hook / communication settings.
        self.use_eager_hook = use_eager_hook
        self.comm_async = comm_async
        self.comm_fusion = comm_fusion
        self.bucket_size = bucket_size

        # Gradient fusion is effective only when comm fusion is requested
        # and the bucket size is non-zero (expression kept verbatim to
        # preserve short-circuit value semantics).
        self.grad_fusion = comm_fusion and bucket_size != 0