mirror of https://github.com/InternLM/InternLM
47 lines
1.1 KiB
Python
47 lines
1.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import torch
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
class BaseOptimizer(Optimizer):
|
|
"""
|
|
Base Optimizer.
|
|
"""
|
|
|
|
def __init__(self, optim: Optimizer): # pylint: disable=W0231
|
|
self.optim = optim
|
|
|
|
@property
|
|
def param_groups(self):
|
|
return self.optim.param_groups
|
|
|
|
@property
|
|
def defaults(self):
|
|
return self.optim.defaults
|
|
|
|
def add_param_group(self, *args, **kwargs):
|
|
return self.optim.add_param_group(*args, **kwargs)
|
|
|
|
def step(self, *args, **kwargs):
|
|
return self.optim.step(*args, **kwargs)
|
|
|
|
def zero_grad(self, *args, **kwargs):
|
|
self.optim.zero_grad(*args, **kwargs)
|
|
|
|
def load_state_dict(self, *args, **kwargs):
|
|
self.optim.load_state_dict(*args, **kwargs)
|
|
|
|
def state_dict(self):
|
|
return self.optim.state_dict()
|
|
|
|
def backward(self, loss):
|
|
loss.backward()
|
|
|
|
def backward_by_grad(self, tensor, grad):
|
|
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
|
|
|
|
def clip_grad_norm(self):
|
|
pass
|