"""Initialize MoEMonitorCallback from trainer config."""
super().__init__(trainer)
moe_cfg = getattr(trainer.args, 'moe_monitor', None)
self.enabled = getattr(moe_cfg, 'enabled', False) if moe_cfg else False
self._impl = None
if self.enabled:
from hyper_parallel.core.moe_utils import ( # pylint: disable=C0415
MoEMonitorCallback as _CoreMoEMonitorCallback,
)
from hyper_parallel.core.fully_shard.hsdp_utils import ( # pylint: disable=C0415
GroupInfo,
)
lr = getattr(moe_cfg, 'lr', 1e-3)
num_recomputations = getattr(moe_cfg, 'num_recomputations', 1)
# Resolve DP/TP/CP groups from trainer's device mesh.
dp_group = getattr(self.trainer, '_dp_group_info', None)
tp_group = None
cp_group = None
mesh = getattr(self.trainer, 'mesh', None)
if mesh is not None:
for name, attr_name in [("tp", "tp_group"), ("cp", "cp_group")]:
try:
raw_group = mesh.get_group(name)
group_info = GroupInfo(
group_name=name, group=raw_group,
rank_size=raw_group.size(),
)
if attr_name == "tp_group":
tp_group = group_info
else:
cp_group = group_info
except (KeyError, ValueError, AttributeError):
pass
self._impl = _CoreMoEMonitorCallback(
model=self.trainer.model,
lr=lr,
dp_group=dp_group,
tp_group=tp_group,