Coverage for hyper_parallel / platform / torch / hsdp / state.py: 93%

14 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""Torch HSDP cell state""" 

16from hyper_parallel.core.hsdp.hsdp_state import HSDPState 

17from hyper_parallel.platform.torch.hsdp.param import TorchHSDPParam 

18from hyper_parallel.platform.torch.platform import TorchPlatform 

19 

class TorchHSDPState(HSDPState):
    """HSDP cell state specialised for the Torch platform.

    Only the parameter-initialisation hook is overridden here. Every other
    piece of state (``cell``, ``config``, ``platform``, ``hsdp_params``,
    ``sharded_hsdp_params``) is presumably set up by ``HSDPState`` -- TODO
    confirm against the base class.
    """

    def _init_hsdp_params(self):
        """Wrap each unclaimed parameter of ``self.cell`` in a ``TorchHSDPParam``.

        Parameters are walked in the order ``self.cell.named_parameters()``
        yields them. A parameter already carrying a ``has_hsdp_param``
        attribute (with any value) is skipped. Every newly wrapped parameter
        is tagged with ``has_hsdp_param = True`` and appended to
        ``self.hsdp_params``. It is additionally appended to
        ``self.sharded_hsdp_params`` when the wrapper reports ``sharded``.
        """
        for name, tensor in self.cell.named_parameters():
            # The marker attribute stops the same parameter object from being
            # wrapped twice; presumably another HSDP state (e.g. one on an
            # enclosing or nested cell) may have claimed it first -- verify.
            if not hasattr(tensor, "has_hsdp_param"):
                wrapped = TorchHSDPParam(
                    self.cell, name, tensor, self.config, self.platform
                )
                tensor.has_hsdp_param = True
                self.hsdp_params.append(wrapped)
                if wrapped.sharded:
                    self.sharded_hsdp_params.append(wrapped)