2022-11-24 08:51:45 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
2023-04-26 08:32:40 +00:00
|
|
|
def run_fwd(model, data, label, criterion) -> torch.Tensor:
    """Run a forward pass and return the loss as a float32 tensor.

    Two calling conventions are supported:

    * ``criterion`` is given: the model is called with ``data`` only, its
      output is cast with ``.float()`` and passed to ``criterion`` together
      with ``label``.
    * ``criterion`` is ``None``: the model is called with ``(data, label)``
      and its return value is used as the loss directly.

    Args:
        model (torch.nn.Module): a PyTorch model
        data (torch.Tensor): input data
        label (torch.Tensor): label
        criterion (Optional[Callable]): loss function called as
            ``criterion(output, label)``; ``None`` if the model computes
            the loss itself

    Returns:
        torch.Tensor: loss of the forward pass, cast with ``.float()``
    """
    # Compare against None explicitly: a callable may define __len__ or
    # __bool__ and evaluate as falsy even though it is a valid criterion.
    if criterion is not None:
        y = model(data)
        # Cast the output before the criterion sees it.
        y = y.float()
        loss = criterion(y, label)
    else:
        loss = model(data, label)

    # Both paths return the loss cast with .float().
    loss = loss.float()
    return loss
|
|
|
|
|
|
|
|
|
|
|
|
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
    """Run a forward pass followed by a backward pass.

    The loss is obtained from ``run_fwd``. The backward pass is then driven
    through ``optimizer.backward(loss)`` when an optimizer is supplied, and
    through ``loss.backward()`` otherwise.

    Args:
        model (torch.nn.Module): a PyTorch model
        data (torch.Tensor): input data
        label (torch.Tensor): label
        criterion (Optional[Callable]): loss function forwarded to
            ``run_fwd``; ``None`` if the model computes the loss itself
        optimizer (optional): object exposing a ``backward(loss)`` method.
            Defaults to ``None``, in which case ``loss.backward()`` is used.

    Returns:
        torch.Tensor: loss of the forward pass
    """
    loss = run_fwd(model, data, label, criterion)

    # Compare against None explicitly so that an optimizer object which
    # happens to evaluate as falsy is not silently bypassed.
    if optimizer is not None:
        # Delegate backward to the optimizer rather than calling autograd
        # on the loss directly.
        optimizer.backward(loss)
    else:
        loss.backward()
    return loss
|