Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

26 lines
725 B

import torch
import torch.nn as nn
from torch.optim import Optimizer
__all__ = ['Precision']


class Precision:
    """Precision policy applied to a model, its optimizer and its training loss.

    Stores the target dtype and the gradient-clipping configuration. The
    ``setup_model``, ``setup_optimizer`` and ``scale_loss`` hooks are still
    placeholders: each hands its argument back unchanged. The stored settings
    are not applied yet.

    Args:
        precision_type: Target dtype of the precision policy.
        grad_clipping_type: Name of the gradient-clipping strategy. The set of
            accepted values is not defined here; TODO confirm against callers.
        grad_clipping_value: Threshold passed to the gradient-clipping strategy.
    """

    def __init__(
        self,
        precision_type: torch.dtype,
        grad_clipping_type: str,
        grad_clipping_value: float,
    ) -> None:
        self.precision_type = precision_type
        self.grad_clipping_type = grad_clipping_type
        self.grad_clipping_value = grad_clipping_value

    def setup_model(self, model: nn.Module) -> nn.Module:
        """Prepare ``model`` for this precision policy and return it.

        Currently a passthrough: the same ``model`` object is returned
        unmodified. Callers that use the return value therefore get a valid
        module, not ``None``.
        """
        # TODO: implement this method
        return model

    def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
        """Prepare ``optimizer`` for this precision policy and return it.

        Currently a passthrough: the same ``optimizer`` object is returned
        unmodified. No gradient clipping or loss unscaling is applied yet.
        """
        # TODO: implement this method
        # inject grad clipping and unscale loss
        return optimizer

    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Return the loss to back-propagate.

        Currently the identity: ``loss`` is returned as-is, with no scaling.
        """
        return loss