mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
16 lines
418 B
16 lines
418 B
2 years ago
|
import torch
|
||
|
|
||
|
|
||
|
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
    """Run one forward pass and one backward pass for ``model``.

    The forward pass executes inside ``torch.cuda.amp.autocast``, which is
    switched on or off by ``enable_autocast``. The resulting loss is cast
    with ``.float()`` before the backward pass is started.

    Args:
        model: Callable invoked as ``model(data)`` when ``criterion`` is
            truthy, otherwise as ``model(data, label)`` and expected to
            return the loss itself. Must also expose ``backward(loss)``
            when ``use_init_ctx`` is truthy.
        data: Input forwarded to ``model``.
        label: Target passed to ``criterion`` (or directly to ``model``).
        criterion: Optional callable ``criterion(output, label)`` returning
            the loss. Any falsy value selects the ``model(data, label)`` path.
        enable_autocast: Forwarded as ``enabled=`` to the autocast context.
        use_init_ctx: When truthy, backward is driven by
            ``model.backward(loss)`` instead of ``loss.backward()``.
            NOTE(review): presumably set for models built under a special
            init context that own their backward pass -- confirm with callers.

    Returns:
        None. The function is called for its side effect of running backward.
    """
    amp_ctx = torch.cuda.amp.autocast(enabled=enable_autocast)
    with amp_ctx:
        if not criterion:
            # No separate criterion: the model computes the loss itself.
            raw_loss = model(data, label)
        else:
            raw_loss = criterion(model(data), label)
        fp32_loss = raw_loss.float()

    if not use_init_ctx:
        fp32_loss.backward()
        return
    # The model object drives its own backward pass.
    model.backward(fp32_loss)
|