mirror of https://github.com/hpcaitech/ColossalAI
adapting bert unitest interface
parent
7977422aeb
commit
4d94cd513e
|
@ -48,9 +48,21 @@ def get_training_components():
|
|||
num_hidden_layers=num_layer,
|
||||
)
|
||||
print('building BertForSequenceClassification model')
|
||||
model = BertForSequenceClassification(config)
|
||||
|
||||
# adapting huggingface BertForSequenceClassification for single unitest calling interface
|
||||
class ModelAaptor(BertForSequenceClassification):
|
||||
|
||||
def forward(self, input_ids, labels):
|
||||
"""
|
||||
inputs: data, label
|
||||
outputs: loss
|
||||
"""
|
||||
return super().forward(input_ids=input_ids, labels=labels)[0]
|
||||
|
||||
model = ModelAaptor(config)
|
||||
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
return model
|
||||
|
||||
trainloader = get_bert_data_loader(batch_size=2,
|
||||
|
|
|
@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
|
|||
loss.backward()
|
||||
|
||||
|
||||
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
|
||||
# with no criterion
|
||||
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
output = model(input_ids=data, labels=label)
|
||||
loss = output[0]
|
||||
loss = model(data, label)
|
||||
if isinstance(model, ShardedModelV2):
|
||||
model.backward(loss)
|
||||
else:
|
||||
|
@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
if model_name == 'bert':
|
||||
data, label = data.cuda(), label.cuda()
|
||||
run_bert_fwd_bwd(model, data, label, False)
|
||||
run_bert_fwd_bwd(zero_model, data, label, False)
|
||||
run_fwd_bwd_no_criterion(model, data, label, False)
|
||||
run_fwd_bwd_no_criterion(zero_model, data, label, False)
|
||||
else:
|
||||
data, label = data.half().cuda(), label.cuda()
|
||||
run_fwd_bwd(model, data, label, criterion, False)
|
||||
|
|
Loading…
Reference in New Issue