Coverage for hyper_parallel / core / fully_shard / hsdp_param.py: 55%
60 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP parameter"""
class HSDPParamV2:
    """
    HSDP parameter.

    Abstract interface for a single parameter managed under HSDP (hybrid
    sharded data parallelism). Every method in this class raises
    ``NotImplementedError``; concrete subclasses are expected to implement
    the full sharding lifecycle suggested by the method names —
    construction/metadata setup, sharded/unsharded state transitions,
    all-gather buffer management, and gradient reduction.

    NOTE(review): the semantics below are inferred from the method names and
    signatures only — the visible bodies are all stubs. Confirm details
    against the concrete subclass implementations.
    """

    def __init__(
        self,
        param,
        module_info,
        mesh_info,
        post_forward_mesh_info,
        shard_placement_fn,
        mp_policy,
        offload_policy,
        threshold,
    ):
        # The base class is never constructed directly; subclasses own
        # construction entirely (including how the eight arguments are used).
        raise NotImplementedError("HSDP param subclasses must implement __init__")

    def _init_sharded_param(self, param, shard_placement_fn):
        """Add and init sharded param (subclass hook)."""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param")

    def _init_sharded_post_forward_param_metadata(self, param):
        """Subclass hook: record metadata for the post-forward sharded form of *param*."""
        raise NotImplementedError("HSDP param subclasses must implement _init_sharded_post_forward_param_metadata")

    def init_dtype_attrs(self, mp_policy):
        """Subclass hook: set dtype attributes from the mixed-precision policy."""
        raise NotImplementedError("HSDP param subclasses must implement init_dtype_attrs")

    def _init_extensions(self):
        """Subclass hook: initialize optional extensions, if any."""
        raise NotImplementedError("HSDP param subclasses must implement _init_extensions")

    def init_all_gather_outputs(self, all_gather_input_numels, all_gather_input_dtypes, world_size, device, force_recreate=False):
        """Subclass hook: allocate the all-gather output buffers.

        ``force_recreate`` presumably forces reallocation even when buffers
        already exist — confirm against the concrete implementation.
        """
        raise NotImplementedError("HSDP param subclasses must implement init_all_gather_outputs")

    def init_unsharded_param(self):
        """Subclass hook: create the unsharded view of the parameter."""
        raise NotImplementedError("HSDP param subclasses must implement init_unsharded_param")

    def to_sharded(self):
        """Subclass hook: transition the parameter to its sharded state."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded")

    def to_sharded_post_forward(self):
        """Subclass hook: transition to the post-forward sharded state."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_post_forward")

    def to_unsharded(self):
        """Subclass hook: transition the parameter to its unsharded state."""
        raise NotImplementedError("HSDP param subclasses must implement to_unsharded")

    def to_sharded_dtensor(self, tensor):
        """Subclass hook: wrap *tensor* as a sharded DTensor."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_dtensor")

    def to_sharded_post_forward_dtensor(self, tensor):
        """Subclass hook: wrap *tensor* as a post-forward sharded DTensor."""
        raise NotImplementedError("HSDP param subclasses must implement to_sharded_post_forward_dtensor")

    def to_accumulated_grad_if_needed(self):
        """Subclass hook: move the gradient into the accumulation buffer when required."""
        raise NotImplementedError("HSDP param subclasses must implement to_accumulated_grad_if_needed")

    def accumulate_unsharded_grad_if_needed(self):
        """Subclass hook: accumulate into the unsharded gradient when required."""
        raise NotImplementedError("HSDP param subclasses must implement accumulate_unsharded_grad_if_needed")

    def alloc_all_gather_outputs(self):
        """Subclass hook: (re)allocate storage for the all-gather outputs."""
        raise NotImplementedError("HSDP param subclasses must implement alloc_all_gather_outputs")

    def free_unsharded_param(self):
        """Subclass hook: release the unsharded parameter's storage."""
        raise NotImplementedError("HSDP param subclasses must implement free_unsharded_param")

    @property
    def all_gather_inputs(self):
        """Subclass property: the input tensors to contribute to the all-gather."""
        raise NotImplementedError("HSDP param subclasses must implement all_gather_inputs")

    @property
    def unsharded_param(self):
        """Subclass property: the unsharded parameter."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_param")

    @property
    def unsharded_grad_data(self):
        """Subclass property: the unsharded gradient data."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_grad_data")

    @property
    def unsharded_accumulated_grad_data(self):
        """Subclass property: the unsharded accumulated-gradient data."""
        raise NotImplementedError("HSDP param subclasses must implement unsharded_accumulated_grad_data")

    @property
    def _sharded_local_tensor(self):
        """Subclass property: the local shard's tensor."""
        raise NotImplementedError("HSDP param subclasses must implement _sharded_local_tensor")

    def _get_unsharded_param_data(self, async_op=False):
        """Subclass hook: fetch unsharded parameter data (optionally asynchronously)."""
        raise NotImplementedError("HSDP param subclasses must implement _get_unsharded_param_data")

    def unshard(self, async_op=False):
        """Subclass hook: launch the unshard (all-gather); ``async_op`` presumably
        makes it non-blocking — confirm against the concrete implementation."""
        raise NotImplementedError("HSDP param subclasses must implement unshard")

    def wait_for_unshard(self):
        """Subclass hook: block until a pending unshard completes."""
        raise NotImplementedError("HSDP param subclasses must implement wait_for_unshard")

    def shard(self):
        """Subclass hook: shard the parameter."""
        raise NotImplementedError("HSDP param subclasses must implement shard")

    def reduce_scatter_grad(self):
        """Subclass hook: reduce-scatter the gradient across the shard group."""
        raise NotImplementedError("HSDP param subclasses must implement reduce_scatter_grad")

    def all_reduce_grad(self):
        """Subclass hook: all-reduce the gradient (replicate dimension)."""
        raise NotImplementedError("HSDP param subclasses must implement all_reduce_grad")