Making large AI models cheaper, faster and more accessible
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

58 lines
1.5 KiB

from typing import Callable, Dict, Optional, Union
import torch
from torch.nn import Module
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
def run_fwd(
model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None
) -> torch.Tensor:
"""run_fwd
run fwd for the model
Args:
model (torch.nn.Module): a PyTorch model
data (torch.Tensor): input data
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
Returns:
torch.Tensor: loss of fwd
"""
outputs = model(**data)
outputs = output_transform_fn(outputs)
if criterion:
loss = criterion(outputs)
else:
loss = next(iter(outputs.values())).sum()
return loss
def run_fwd_bwd(
model: Module,
data: Dict,
output_transform_fn: Callable,
criterion: Optional[Callable] = None,
optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None,
) -> torch.Tensor:
"""run_fwd_bwd
run fwd and bwd for the model
Args:
model (torch.nn.Module): a PyTorch model
data (torch.Tensor): input data
label (torch.Tensor): label
criterion (Optional[Callable]): a function of criterion
Returns:
torch.Tensor: loss of fwd
"""
loss = run_fwd(model, data, output_transform_fn, criterion)
if optimizer:
optimizer.backward(loss)
else:
loss.backward()
return loss