Coverage for hyper_parallel / platform / torch / hsdp / param.py: 100%

22 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""HSDP parameter""" 

16import torch 

17import torch.distributed as dist 

18from hyper_parallel.core.hsdp.hsdp_param import HSDPParam 

19from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel 

20 

21 

class TorchHSDPParam(HSDPParam):
    """
    Torch HSDP parameter.

    Manages the sharded / unsharded views of a single parameter for
    hybrid-sharded data parallelism (HSDP) on the torch platform.
    """

    def _init_sharded_param(self):
        """Replace the parameter's local data with this rank's shard.

        The local tensor is split into ``self.shard_size`` chunks along dim 0
        and the chunk owned by this rank (``hsdp_rank % shard_size``) becomes
        both the parameter data and ``self.sharded_param``.
        """
        slice_index = self.hsdp_rank % self.shard_size
        local_param = self.platform.get_param_local_data(self.param)
        # torch.chunk returns views of `local_param`; `.clone()` materializes
        # the selected chunk into its own storage so the parameter no longer
        # aliases the full tensor. (The original used `+ 0` for the same
        # effect, which additionally promotes bool tensors to int64 — clone
        # preserves dtype.)
        param_slice = torch.chunk(local_param, self.shard_size, 0)[slice_index].clone()
        self.platform.update_param_data(self.param, param_slice)
        self.sharded_param = param_slice

    def _init_unsharded_param(self):
        """
        Init the persistent unsharded buffer only at non-parameter shard level.

        At SHARD_OPT_GRAD_PARAM level the full parameter is never kept
        resident, so no buffer is pre-allocated; one is created on demand in
        `_get_unsharded_param_data` instead.
        """
        if self.config.shard_level != OptimizerLevel.SHARD_OPT_GRAD_PARAM:
            self.unsharded_param = torch.empty(self.param_shape, dtype=self.param.dtype,
                                               device=self.param.device)
        else:
            self.unsharded_param = None

    def _get_unsharded_param_data(self, async_op):
        """All-gather the local shard into the full (unsharded) parameter.

        Args:
            async_op (bool): if True, launch the all-gather asynchronously.

        Returns:
            tuple: ``(unsharded_param, handle)`` where ``handle`` is the async
            work object returned by ``dist.all_gather_into_tensor`` (``None``
            when ``async_op`` is False).
        """
        local_param = self.platform.get_param_local_data(self.param)
        unsharded_param = self.unsharded_param
        if unsharded_param is None:
            # SHARD_OPT_GRAD_PARAM level keeps no persistent buffer
            # (see _init_unsharded_param): allocate a transient one per gather.
            unsharded_param = torch.empty(self.param_shape, dtype=self.param.dtype,
                                          device=self.param.device)
        handle = dist.all_gather_into_tensor(unsharded_param, local_param,
                                             group=self.sharded_group_info.group,
                                             async_op=async_op)
        return unsharded_param, handle