Coverage for hyper_parallel/platform/mindspore/hsdp/state.py: 97%

36 statements

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""MindSpore HSDP cell state."""
from mindspore import jit_class
from mindspore.common.api import _no_grad

from hyper_parallel.core.hsdp.hsdp_state import HSDPState
from hyper_parallel.platform.mindspore.hsdp.param import MindSporeHSDPParam

20 

21 

@jit_class
class MindSporeHSDPState(HSDPState):
    """Per-cell HSDP state for the MindSpore platform.

    Wraps each parameter of the owning cell in a ``MindSporeHSDPParam`` and
    exposes the shard/unshard/prefetch/grad operations of ``HSDPState`` with
    gradient tracking disabled via ``_no_grad``.
    """

    def _init_hsdp_params(self):
        """Wrap every parameter of the cell and its sub-cells in an HSDP param."""
        for _, sub_cell in self.cell.cells_and_names():
            for name, param in sub_cell._params.items():  # pylint: disable=W0212
                # The flag marks parameters already wrapped, so a parameter
                # reachable through several sub-cells is only wrapped once.
                if hasattr(param, "has_hsdp_param"):
                    continue
                wrapped = MindSporeHSDPParam(sub_cell, name, param, self.config, self.platform)
                param.has_hsdp_param = True
                self.hsdp_params.append(wrapped)
                if wrapped.sharded:
                    self.sharded_hsdp_params.append(wrapped)

    @_no_grad()
    def shard(self):
        """Switch parameters to their sharded state (gradient tracking off)."""
        super().shard()

    @_no_grad()
    def unshard(self):
        """Switch parameters to their unsharded state (gradient tracking off)."""
        super().unshard()

    @_no_grad()
    def prefetch(self):
        """Prefetch unsharded parameters ahead of use."""
        super().prefetch()

    @_no_grad()
    def zero_grads(self):
        """Clear gradients (or the gradient buffer)."""
        super().zero_grads()

    @_no_grad()
    def set_grad_ready(self, hsdp_param):
        """Mark the gradient of ``hsdp_param`` as ready."""
        super().set_grad_ready(hsdp_param)

    @_no_grad()
    def set_requires_grad_sync(self, requires_grad_sync):
        """Set the flag that controls whether gradients are synchronized."""
        super().set_requires_grad_sync(requires_grad_sync)