Coverage for /home/jenkins/.local/lib/python3.10/site-packages/hyper_parallel/core/utils/clip_grad.py: 67%
6 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-11 07:26 +0800
1# Copyright 2026 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Distributed-aware gradient clipping utilities."""
16from typing import Optional
18from hyper_parallel.platform import get_platform
20__all__: list[str] = ["clip_grad_norm_"]
def clip_grad_norm_(
    parameters,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
):
    r"""Clip gradient norm with awareness of distributed sharding.

    Behaves like the standard ``clip_grad_norm_`` but computes the norm
    correctly for sharded parameters: the required communication is derived
    from each parameter's DTensor spec (``device_mesh`` + ``placements``),
    so FSDP, HSDP, TP+FSDP, and any DTensor-expressible layout are handled.
    Plain (non-DTensor) parameters are treated as replicated and need no
    communication.

    Args:
        parameters: An ``nn.Module``, a single ``Tensor``, or an iterable
            of ``Tensor`` s whose gradients to clip. When an ``nn.Module``
            is given, ``module.parameters()`` is used.
        max_norm (float): max norm of the gradients.
        norm_type (float): type of the used p-norm. Can be ``'inf'``
            for infinity norm. Default: 2.0.
        error_if_nonfinite (bool): if ``True``, an error is thrown if
            the total norm is ``nan``, ``inf``, or ``-inf``.
            Default: ``False``.
        foreach (bool, optional): Unused, accepted for API compatibility
            with the standard ``clip_grad_norm_``.

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    # Resolve the active backend once; it owns the actual clipping logic.
    backend = get_platform()
    return backend.clip_grad_norm_(
        parameters,
        max_norm,
        norm_type,
        error_if_nonfinite=error_if_nonfinite,
        foreach=foreach,
    )