Coverage for hyper_parallel/platform/torch/hsdp/state.py: 93%
14 statements
« prev ^ index » next — coverage.py v7.13.1, created at 2026-03-01 07:33 +0800
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Torch HSDP cell state"""
16from hyper_parallel.core.hsdp.hsdp_state import HSDPState
17from hyper_parallel.platform.torch.hsdp.param import TorchHSDPParam
18from hyper_parallel.platform.torch.platform import TorchPlatform
class TorchHSDPState(HSDPState):
    """Torch HSDP cell state"""

    def _init_hsdp_params(self):
        """Wrap each of the cell's parameters in a TorchHSDPParam.

        Parameters already carrying the ``has_hsdp_param`` marker are
        skipped, so calling this method more than once is idempotent.
        Every new wrapper is recorded in ``self.hsdp_params``; wrappers
        whose ``sharded`` flag is set are additionally recorded in
        ``self.sharded_hsdp_params``.
        """
        for name, parameter in self.cell.named_parameters():
            # Already wrapped on a previous pass — nothing to do.
            if hasattr(parameter, "has_hsdp_param"):
                continue
            wrapper = TorchHSDPParam(self.cell, name, parameter, self.config, self.platform)
            # Mark the raw parameter so a repeated call skips it above.
            parameter.has_hsdp_param = True
            self.hsdp_params.append(wrapper)
            if wrapper.sharded:
                self.sharded_hsdp_params.append(wrapper)