You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/tests/components_to_test/utils/executor.py

47 lines
1.1 KiB

import torch
def run_fwd(model, data, label, criterion) -> torch.Tensor:
    """run_fwd
    run fwd for the model and return the loss cast with ``.float()``

    Args:
        model (torch.nn.Module): a PyTorch model. It is called as
            ``model(data)`` when ``criterion`` is given, otherwise as
            ``model(data, label)`` and its return value is used as the loss.
        data (torch.Tensor): input data
        label (torch.Tensor): label
        criterion (Optional[Callable]): a function of criterion, called as
            ``criterion(output, label)``; pass ``None`` when the model
            computes the loss itself.

    Returns:
        torch.Tensor: loss of fwd
    """
    # Test against None explicitly: a callable criterion that happens to be
    # falsy (e.g. it defines __len__ returning 0) must still be used.
    if criterion is not None:
        # Cast the model output before the criterion sees it, so the loss is
        # computed from a float tensor.
        output = model(data).float()
        loss = criterion(output, label)
    else:
        loss = model(data, label)
    return loss.float()
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
    """run_fwd_bwd
    run fwd and bwd for the model

    Args:
        model (torch.nn.Module): a PyTorch model
        data (torch.Tensor): input data
        label (torch.Tensor): label
        criterion (Optional[Callable]): a function of criterion
        optimizer (Optional[object]): when given, the backward pass is
            delegated to ``optimizer.backward(loss)``; when ``None`` (the
            default), ``loss.backward()`` is called directly.

    Returns:
        torch.Tensor: loss of fwd
    """
    loss = run_fwd(model, data, label, criterion)
    # Test against None explicitly so the choice of backward path does not
    # depend on the truthiness of the optimizer object.
    if optimizer is not None:
        optimizer.backward(loss)
    else:
        loss.backward()
    return loss