mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.5 KiB
59 lines
1.5 KiB
1 year ago
|
from typing import Callable, Dict, Optional, Union
|
||
|
|
||
2 years ago
|
import torch
|
||
1 year ago
|
from torch.nn import Module
|
||
|
from torch.optim import Optimizer
|
||
|
|
||
|
from colossalai.interface import OptimizerWrapper
|
||
2 years ago
|
|
||
|
|
||
1 year ago
|
def run_fwd(
    model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None
) -> torch.Tensor:
    """run_fwd

    Run a forward pass of ``model`` and reduce its outputs to a loss.

    Args:
        model (torch.nn.Module): the model to run.
        data (Dict): keyword arguments unpacked into ``model(**data)``.
        output_transform_fn (Callable): applied to the raw model outputs before
            the loss is computed.
        criterion (Optional[Callable]): called with the transformed outputs to
            produce the loss. If ``None``, the loss is ``.sum()`` of the first
            value of the transformed outputs, which must then be a mapping
            exposing ``.values()`` (e.g. a dict).

    Returns:
        torch.Tensor: the loss of the forward pass.

    Raises:
        ValueError: if ``criterion`` is ``None`` and the transformed outputs
            mapping is empty.
    """
    outputs = output_transform_fn(model(**data))
    # Compare against None explicitly: a callable object that defines
    # __len__/__bool__ (e.g. an empty nn.Sequential) is falsy and would
    # otherwise be silently skipped.
    if criterion is not None:
        return criterion(outputs)
    # No criterion given: reduce the first output entry to a scalar loss.
    try:
        first = next(iter(outputs.values()))
    except StopIteration:
        # A bare StopIteration escaping a helper is confusing (and becomes a
        # RuntimeError if this is ever called from a generator).
        raise ValueError("run_fwd: output_transform_fn returned no outputs to build a loss from") from None
    return first.sum()
|
||
|
|
||
|
|
||
1 year ago
|
def run_fwd_bwd(
    model: Module,
    data: Dict,
    output_transform_fn: Callable,
    criterion: Optional[Callable] = None,
    optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None,
) -> torch.Tensor:
    """run_fwd_bwd

    Run a forward pass via ``run_fwd`` and back-propagate the resulting loss.

    Args:
        model (torch.nn.Module): the model to run.
        data (Dict): keyword arguments unpacked into ``model(**data)``.
        output_transform_fn (Callable): applied to the raw model outputs before
            the loss is computed.
        criterion (Optional[Callable]): called with the transformed outputs to
            produce the loss; when ``None``, ``run_fwd`` falls back to summing
            the first value of the transformed outputs.
        optimizer (Optional[Union[Optimizer, OptimizerWrapper]]): if it exposes a
            ``backward`` method, the backward pass is delegated to it as
            ``optimizer.backward(loss)``; otherwise (including ``None`` or a
            plain ``torch.optim.Optimizer``) ``loss.backward()`` is called.

    Returns:
        torch.Tensor: the loss of the forward pass.
    """
    loss = run_fwd(model, data, output_transform_fn, criterion)
    # A plain torch.optim.Optimizer has no ``backward`` method, so calling it
    # unconditionally would raise AttributeError. Only delegate when the
    # optimizer provides one (as the OptimizerWrapper path here expects).
    backward = getattr(optimizer, "backward", None)
    if backward is not None:
        backward(loss)
    else:
        loss.backward()
    return loss
|