[Tensor] add optimizer to bert test (#933)

* add optimizer to bert test

* polish
Ziyue Jiang 2022-05-13 11:37:23 +08:00 committed by GitHub
parent 7edb38193a
commit 830d3bca26
1 changed file with 40 additions and 9 deletions


@@ -95,6 +95,15 @@ def run_1d_hybrid_tp(model_name):
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
        # Make two models have the same init params
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)
    if 'bert' == model_name:
        parallel_action_list_row = [
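The block added in this hunk builds a plain PyTorch reference model on rank 0, wraps it in its own optimizer, and copies the tensor-parallel model's initial weights into it so the two training runs can be compared step by step. A minimal sketch of that sync pattern in plain PyTorch (no ColossalAI; make_model, model_ref and optimizer_ref are hypothetical names):

import torch
import torch.nn as nn

def make_model():
    # stand-in for model_builder(checkpoint=True)
    return nn.Linear(16, 4)

torch.manual_seed(1)
model = make_model()        # plays the role of the ColoInitContext model
model_ref = make_model()    # single-process reference model (model_torch above)
with torch.no_grad():
    for p1, p2 in zip(model.parameters(), model_ref.parameters()):
        p2.data.copy_(p1.data)    # both models now start from identical weights
optimizer_ref = torch.optim.SGD(model_ref.parameters(), lr=0.1)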
@@ -176,14 +185,15 @@ def run_1d_hybrid_tp(model_name):
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            p.set_spec(spec_classifier_col)
    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
    model = model.cuda()
    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        model.eval()
        colo_optimizer.zero_grad()
        if rank == 0:
            model_torch.eval()
            colo_optimizer_torch.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())
@@ -210,12 +220,33 @@ def run_1d_hybrid_tp(model_name):
        if rank == 0:
            # print(loss.torch_tensor().item())
            # print('loss torch', loss_torch.item())
            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
            with torch.no_grad():
                assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
        loss.backward()
        colo_optimizer.step()
        if rank == 0:
            loss_torch.backward()
            colo_optimizer_torch.step()
            with torch.no_grad():
                # check param
                for p1, p2 in zip(model.parameters(), model_torch.parameters()):
                    if p1.size() == p2.size():
                        assert torch.allclose(p1, p2)
                    else:
                        # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
                        if p1.size(-1) < p2.size(-1):    # col
                            world_size = p2.size(-1) // p1.size(-1)
                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
                        elif p1.size(0) < p2.size(0):    # row
                            world_size = p2.size(0) // p1.size(0)
                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]
                        assert torch.allclose(p1, split_p2)
        if i > 5:
            break
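The parameter check added in this hunk only handles 1D row/column splits (hence the TODO about the new DistSpec) and compares rank 0's local shard against the first chunk of the full reference parameter. A minimal sketch of that comparison in plain PyTorch with toy shapes (check_param, full, shard_col and shard_row are hypothetical names):

import torch

def check_param(p1, p2):
    # p1: local shard on rank 0, p2: full reference parameter
    if p1.size() == p2.size():
        assert torch.allclose(p1, p2)
    elif p1.size(-1) < p2.size(-1):    # column split
        ws = p2.size(-1) // p1.size(-1)
        assert torch.allclose(p1, torch.chunk(p2, ws, dim=-1)[0])
    elif p1.size(0) < p2.size(0):      # row split
        ws = p2.size(0) // p1.size(0)
        assert torch.allclose(p1, torch.chunk(p2, ws, dim=0)[0])

world_size = 4
full = torch.randn(8, 16)                              # full parameter on the torch model
shard_col = torch.chunk(full, world_size, dim=-1)[0]   # rank 0's shard after a column split
shard_row = torch.chunk(full, world_size, dim=0)[0]    # rank 0's shard after a row split
check_param(shard_col, full)
check_param(shard_row, full)
check_param(full.clone(), full)                        # unsharded parameter path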
@@ -428,5 +459,5 @@ def _test_pretrain_load(world_size):
if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    # test_model()
    _test_pretrain_load(4)
    test_model(4)
    # _test_pretrain_load(4)