Coverage for hyper_parallel / platform / torch / hsdp / param.py: 100%
22 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
1# Copyright 2025 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""HSDP parameter"""
16import torch
17import torch.distributed as dist
18from hyper_parallel.core.hsdp.hsdp_param import HSDPParam
19from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel
class TorchHSDPParam(HSDPParam):
    """
    Torch implementation of an HSDP-managed parameter.

    Slices the parameter's local data into per-rank shards and gathers the
    full (unsharded) tensor back through collective communication.
    """

    def _init_sharded_param(self):
        """Replace the parameter's local data with this rank's shard and record it."""
        my_slice = self.hsdp_rank % self.shard_size
        full_data = self.platform.get_param_local_data(self.param)
        # `+ 0` materializes a fresh tensor so we never hand a view tensor
        # to update_param_data (views cause errors downstream).
        shard = torch.chunk(full_data, self.shard_size, 0)[my_slice] + 0
        self.platform.update_param_data(self.param, shard)
        self.sharded_param = shard

    def _init_unsharded_param(self):
        """
        Pre-allocate the full-size gather buffer, except at the
        SHARD_OPT_GRAD_PARAM level where no persistent buffer is kept.
        """
        fully_sharded = self.config.shard_level == OptimizerLevel.SHARD_OPT_GRAD_PARAM
        if fully_sharded:
            self.unsharded_param = None
        else:
            self.unsharded_param = torch.empty(
                self.param_shape, dtype=self.param.dtype, device=self.param.device
            )

    def _get_unsharded_param_data(self, async_op):
        """
        All-gather the local shard into a full-size tensor.

        Reuses the persistent buffer when one was allocated, otherwise
        allocates a temporary one. Returns the gathered tensor together with
        the communication handle from all_gather_into_tensor (usable to wait
        on when async_op is True).
        """
        shard = self.platform.get_param_local_data(self.param)
        gather_buf = self.unsharded_param
        if gather_buf is None:
            gather_buf = torch.empty(
                self.param_shape, dtype=self.param.dtype, device=self.param.device
            )
        handle = dist.all_gather_into_tensor(
            gather_buf, shard, group=self.sharded_group_info.group, async_op=async_op
        )
        return gather_buf, handle