Coverage for hyper_parallel/platform/mindspore/hsdp/param.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-03-01 07:33 +0800

1# Copyright 2025 Huawei Technologies Co., Ltd 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================ 

15"""MindSpore HSDP parameter""" 

16from mindspore import ops, jit_class 

17from mindspore.common.parameter import Parameter 

18from mindspore.common.initializer import initializer 

19from hyper_parallel.core.dtensor import DTensor 

20from hyper_parallel.core.hsdp.hsdp_param import HSDPParam 

21 

22 

@jit_class
class MindSporeHSDPParam(HSDPParam):
    """
    MindSpore HSDP parameter.

    Wraps a MindSpore ``Parameter`` for the HSDP (hybrid sharded data
    parallel) flow: carves out this rank's shard of the parameter data along
    dim 0 and stores it as ``self.sharded_param``.
    """

    def _init_sharded_param(self):
        """add and init sharded param

        Two paths depending on whether the parameter's data is already
        materialized:

        * ``not has_init`` (data exists): split the local data along dim 0
          into ``shard_size`` chunks, keep this rank's chunk, and write it
          back as the parameter's data.
        * ``has_init`` (lazy init pending): do not touch data; instead shrink
          the recorded init shape(s) in place so the eventual initialization
          materializes only this rank's shard, and record which global slice
          this rank owns via ``hsdp_init_index``.
        """
        if not self.param.has_init:
            # Data already materialized: pick this rank's slice along dim 0.
            slice_index = self.hsdp_rank % self.shard_size
            local_param = self.platform.get_param_local_data(self.param)
            # Split dim 0 into equal chunks of size shape[0] // shard_size
            # (assumes dim 0 is divisible by shard_size — TODO confirm callers
            # guarantee this). The trailing `+ 0` presumably forces a copy so
            # the slice is detached from the full tensor — verify against
            # MindSpore's ops.split view/copy semantics.
            param_slice = ops.split(local_param, local_param.shape[0] // self.shard_size)[slice_index] + 0
            self.platform.update_param_data(self.param, param_slice)

            # TODO: refactor, mindspore jit ast does not need this
            self.sharded_param = Parameter(param_slice,
                                           name="sharded_"+self.param.name,
                                           requires_grad=self.param.requires_grad)
        else:
            # Lazy init pending: compute this rank's position inside the
            # shard group, and its global data slice index across tp ranks.
            dp_slice_index = self.hsdp_rank % self.shard_size
            data_slice_index = self.tp_rank * self.shard_size + dp_slice_index
            # A DTensor init_mode exposes its per-rank shape as local_shape;
            # a plain init_mode exposes .shape directly.
            if isinstance(self.param.init_mode, DTensor):
                init_mode_local_shape = self.param.init_mode.local_shape
            else:
                init_mode_local_shape = self.param.init_mode.shape
            # Shrink dim 0 to this rank's shard (integer division — assumes
            # divisibility by shard_size, same caveat as above).
            init_shape = list(init_mode_local_shape)
            init_shape[0] = init_shape[0] // self.shard_size
            if isinstance(self.param.init_mode, DTensor):
                # 'self.param.to_local()' and 'self.param.init_mode.to_local()' is same object.
                self.param.init_mode.to_local().shape = init_shape
            else:
                # 'self.param' and 'self.param.init_mode' is not same object. set 'self.param.shape' manually.
                self.param.init_mode.shape = init_shape
                self.param.shape = init_shape
            # Remember which global slice to materialize at init time.
            self.param.hsdp_init_index = data_slice_index

            # TODO: refactor, mindspore jit ast does not need this
            # Placeholder zeros buffer; real values arrive at lazy init.
            self.sharded_param = Parameter(initializer("zeros", init_shape, self.param.dtype),
                                           name="sharded_"+self.param.name,
                                           requires_grad=self.param.requires_grad)

    def _init_unsharded_param(self):
        # Intentional no-op on this platform; the unsharded view is handled
        # elsewhere (not visible in this file).
        return