Coverage for hyper_parallel / core / hsdp / hsdp_utils.py: 100%
23 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP optimizer shared level"""
16from enum import auto, Enum
class OptimizerLevel(Enum):
    """
    Degree of sharding applied by the HSDP optimizer.

    Members (in increasing order of sharding):
    - SHARD_OPT: only the optimizer state is sharded.
    - SHARD_OPT_GRAD: optimizer state and gradients are sharded.
    - SHARD_OPT_GRAD_PARAM: optimizer state, gradients and weights
      (parameters) are all sharded.
    """
    SHARD_OPT = auto()
    SHARD_OPT_GRAD = auto()
    SHARD_OPT_GRAD_PARAM = auto()
class GroupInfo:
    """
    Lightweight record describing one communication group.

    Attributes:
        group_name: identifier of the communication group.
        group: the communication group handle/object itself.
        rank_size: number of ranks participating in the group.
    """
    def __init__(self, group_name, group, rank_size):
        """
        Initialize the group record.

        Args:
            group_name: identifier of the communication group.
            group: the communication group handle/object.
            rank_size: number of ranks in the group.
        """
        self.group_name = group_name
        self.group = group
        self.rank_size = rank_size

    def __repr__(self):
        # Debug-friendly representation; added for easier logging/inspection.
        return (f"{type(self).__name__}(group_name={self.group_name!r}, "
                f"group={self.group!r}, rank_size={self.rank_size!r})")
class HSDPConfig:
    """Configuration container for HSDP (hybrid sharded data parallel)."""

    def __init__(self, shard_size, threshold, requires_acc_grad, grad_scale, shard_level, use_eager_hook,
                 reduce_dtype=None, comm_async=False, comm_fusion=False, bucket_size=-1):
        """
        Build an HSDP configuration.

        Args:
            shard_size: sharded size for optimizer weights.
            threshold: minimum weight size eligible for sharding.
            requires_acc_grad: whether gradient accumulation is required.
            grad_scale: scale factor applied to gradients.
            shard_level: optimizer shard level (see OptimizerLevel).
            use_eager_hook: implement HSDP via eager hook (True) or graph hook (False).
            reduce_dtype: dtype used for gradient reduction; None keeps the default.
            comm_async: whether gradient-reduction communication ops run asynchronously.
            comm_fusion: whether communication ops are fused to reduce op count.
            bucket_size: size of the communication-fusion buffer.
        """
        # Sharding-related settings.
        self.shard_size = shard_size
        self.threshold = threshold
        self.shard_level = shard_level
        # Gradient handling.
        self.requires_acc_grad = requires_acc_grad
        self.grad_scale = grad_scale
        self.reduce_dtype = reduce_dtype
        # Hook / communication behavior.
        self.use_eager_hook = use_eager_hook
        self.comm_async = comm_async
        self.comm_fusion = comm_fusion
        self.bucket_size = bucket_size
        # Gradient fusion is active only when fusion is requested and the
        # bucket size is non-zero (the default -1 counts as "enabled").
        self.grad_fusion = comm_fusion and bucket_size != 0