from typing import Any

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer

__all__ = ['GeminiAdamOptimizer']


class GeminiAdamOptimizer(ZeroOptimizer):
    """Convenience optimizer that builds a HybridAdam over the model's parameters
    and wraps it in ZeroOptimizer so its states are managed together with Gemini/ZeRO."""

    def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
        # The keyword arguments are forwarded both to HybridAdam (e.g. lr, betas,
        # weight_decay) and to the ZeroOptimizer wrapper.
        optimizer = HybridAdam(model.parameters(), **defaults)
        super().__init__(optimizer, model, **defaults)
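

# Usage sketch (illustrative, not part of the original file): the names and
# hyperparameters below are assumptions. ZeroOptimizer expects `model` to already
# be a Gemini/ZeRO-wrapped module; that wrapping step is version-dependent and is
# omitted here. Keyword arguments are forwarded to HybridAdam, so the usual Adam
# hyperparameters apply.
#
#   optimizer = GeminiAdamOptimizer(gemini_model, lr=1e-3, weight_decay=1e-2)
#   loss = criterion(gemini_model(inputs), targets)
#   optimizer.backward(loss)   # ZeroOptimizer's backward handles gradient bookkeeping
#   optimizer.step()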