from typing import Any

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

__all__ = ['GeminiAdamOptimizer']


class GeminiAdamOptimizer(ZeroOptimizer):
    """``ZeroOptimizer`` whose inner optimizer is a ``HybridAdam``.

    Convenience wrapper: callers pass the model plus one flat set of keyword
    arguments instead of constructing the inner optimizer themselves.
    """

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        """Create a ``HybridAdam`` over ``model``'s parameters and wrap it.

        Args:
            model: Module whose ``parameters()`` are handed to ``HybridAdam``;
                the module itself is also passed to ``ZeroOptimizer.__init__``.
            **defaults: Keyword arguments forwarded unchanged to BOTH the
                ``HybridAdam`` constructor and ``ZeroOptimizer.__init__``.
                NOTE(review): this presumably relies on each constructor
                tolerating keywords intended for the other (e.g. through its
                own ``**kwargs``). Confirm against both signatures.
        """
        # The inner optimizer is built first (argument evaluation), then the
        # ZeroOptimizer base is initialised with it, the model, and the very
        # same keyword set.
        super().__init__(
            HybridAdam(model.parameters(), **defaults),
            model,
            **defaults,
        )