# Copyright 2026 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Distributed-aware gradient clipping utilities.""" 

16from typing import Optional 

17 

18from hyper_parallel.platform import get_platform 

19 

20__all__: list[str] = ["clip_grad_norm_"] 

21 

22 

def clip_grad_norm_(
    parameters,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
):
30 r"""Distributed-aware gradient norm clipping. 

31 

32 Drop-in replacement for the standard ``clip_grad_norm_`` that 

33 correctly handles sharded parameters by deriving communication from 

34 each parameter's DTensor spec (``device_mesh`` + ``placements``). 

35 

36 Supports FSDP, HSDP, TP+FSDP, and any parallelism expressible via 

37 DTensor placements. Plain (non-DTensor) parameters are treated as 

38 replicated and require no communication. 

39 

40 Args: 

41 parameters: An ``nn.Module``, a single ``Tensor``, or an iterable 

42 of ``Tensor`` s whose gradients to clip. When an ``nn.Module`` 

43 is given, ``module.parameters()`` is used. 

44 max_norm (float): max norm of the gradients. 

45 norm_type (float): type of the used p-norm. Can be ``'inf'`` 

46 for infinity norm. Default: 2.0. 

47 error_if_nonfinite (bool): if ``True``, an error is thrown if 

48 the total norm is ``nan``, ``inf``, or ``-inf``. 

49 Default: ``False``. 

50 foreach (bool, optional): Unused, accepted for API compatibility 

51 with the standard ``clip_grad_norm_``. 

52 

53 Returns: 

54 Total norm of the parameter gradients (viewed as a single vector). 
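
    Example (an illustrative sketch; ``model``, ``loss``, and ``optimizer``
    are placeholder names for a typical training step)::

        >>> loss.backward()
        >>> total_norm = clip_grad_norm_(model, max_norm=1.0)
        >>> optimizer.step()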

55 """ 

    # Delegate to the active platform backend, whose ``clip_grad_norm_``
    # implements the distributed-aware norm computation described above.
    platform = get_platform()
    return platform.clip_grad_norm_(
        parameters, max_norm, norm_type,
        error_if_nonfinite=error_if_nonfinite, foreach=foreach,
    )
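

# ---------------------------------------------------------------------------
# Illustrative sketch only -- NOT the actual platform backend. It shows how
# the squared 2-norm described in ``clip_grad_norm_``'s docstring can be
# derived from each parameter's DTensor spec. ``_sketch_grad_sq_sum`` is a
# hypothetical helper added for illustration; it assumes torch >= 2.4 with an
# initialized process group, and handles only ``Shard``/``Replicate``
# placements (no ``Partial``) and only the 2-norm.
def _sketch_grad_sq_sum(params):
    import torch.distributed as dist
    from torch.distributed.tensor import DTensor, Shard

    total = None
    for p in params:
        g = getattr(p, "grad", None)
        if g is None:
            continue
        if isinstance(g, DTensor):
            # Each rank squares and sums its own local shard ...
            sq = g.to_local().float().pow(2).sum()
            # ... then all-reduces over every mesh dimension on which the
            # parameter is Shard-placed; Replicate dims need no communication.
            for dim, placement in enumerate(g.placements):
                if isinstance(placement, Shard):
                    dist.all_reduce(sq, group=g.device_mesh.get_group(dim))
        else:
            # Plain (non-DTensor) gradients are treated as replicated.
            sq = g.float().pow(2).sum()
        total = sq if total is None else total + sq
    return total  # the global 2-norm would be ``total.sqrt()``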