mirror of https://github.com/hpcaitech/ColossalAI
parent 7edb38193a
commit 830d3bca26
@@ -95,6 +95,15 @@ def run_1d_hybrid_tp(model_name):
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)

        # Make two models have the same init params
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)

    if 'bert' == model_name:
        parallel_action_list_row = [
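In the hunk above, rank 0 builds a plain single-GPU copy of the model with the same seed and then overwrites its parameters with the parallel model's initial values, so later losses and weights can be compared against a non-parallel reference. A minimal sketch of that setup in plain PyTorch (the `model_builder` callable and `rank` value are assumptions standing in for the test's own fixtures):

import torch

def make_reference_copy(model_builder, model, rank):
    # Only rank 0 keeps a single-device reference model; other ranks skip it.
    if rank != 0:
        return None
    model_ref = model_builder().cuda()
    # Sync the reference with the parallel model's initial weights.
    with torch.no_grad():
        for p, p_ref in zip(model.parameters(), model_ref.parameters()):
            p_ref.copy_(p.detach().to(p_ref.device))
    return model_ref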
@@ -176,14 +185,15 @@ def run_1d_hybrid_tp(model_name):
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            p.set_spec(spec_classifier_col)

    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    model = model.cuda()

    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        model.eval()
        colo_optimizer.zero_grad()
        if rank == 0:
            model_torch.eval()
            colo_optimizer_torch.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())
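Each iteration above drives the two models in lockstep: zero the gradients on both, run the forward pass, and (later in the loop) compare losses before both optimizers step. A compact sketch of that lockstep-comparison pattern, using placeholder model/optimizer/criterion names rather than the test's real objects:

import torch

def compare_one_step(model, optimizer, model_ref, optimizer_ref, criterion, data, label, rank):
    # Step the parallel model.
    optimizer.zero_grad()
    loss = criterion(model(data), label)
    loss.backward()
    optimizer.step()

    # On rank 0, repeat the identical step on the single-device reference
    # and check that the losses agree within a loose tolerance.
    if rank == 0:
        optimizer_ref.zero_grad()
        loss_ref = criterion(model_ref(data), label)
        loss_ref.backward()
        optimizer_ref.step()
        assert torch.allclose(loss.detach(), loss_ref.detach(), rtol=1e-2)
    return loss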
@@ -210,12 +220,33 @@ def run_1d_hybrid_tp(model_name):
        if rank == 0:
            # print(loss.torch_tensor().item())
            # print('loss torch', loss_torch.item())
            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)
            with torch.no_grad():
                assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            colo_optimizer_torch.step()

            with torch.no_grad():
                # check param
                for p1, p2 in zip(model.parameters(), model_torch.parameters()):
                    if p1.size() == p2.size():
                        assert torch.allclose(p1, p2)
                    else:
                        # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
                        if p1.size(-1) < p2.size(-1):    # col
                            world_size = p2.size(-1) // p1.size(-1)
                            split_p2 = torch.chunk(p2, world_size, dim=-1)[0]

                        elif p1.size(0) < p2.size(0):    # row
                            world_size = p2.size(0) // p1.size(0)
                            split_p2 = torch.chunk(p2, world_size, dim=0)[0]

                        assert torch.allclose(p1, split_p2)

        if i > 5:
            break
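The parameter check above handles 1D tensor parallelism by chunking the full reference tensor along whichever dimension was sharded (last dim for column parallel, first dim for row parallel) and comparing against rank 0's chunk. A hedged, slightly generalized sketch of the same check (the helper name is illustrative; the test itself only compares the first chunk on rank 0):

import torch

def assert_shard_matches(p_shard, p_full, rank=0):
    # Identical shapes: the parameter was replicated, compare directly.
    if p_shard.size() == p_full.size():
        assert torch.allclose(p_shard, p_full)
        return
    # Otherwise find the split dimension and compare this rank's chunk.
    for dim in range(p_full.dim()):
        if p_shard.size(dim) < p_full.size(dim):
            world_size = p_full.size(dim) // p_shard.size(dim)
            expected = torch.chunk(p_full, world_size, dim=dim)[rank]
            assert torch.allclose(p_shard, expected)
            return
    raise AssertionError('shard shape does not match any 1D split of the full tensor')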
@@ -428,5 +459,5 @@ def _test_pretrain_load(world_size):
if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    # test_model()
    _test_pretrain_load(4)
    test_model(4)
    # _test_pretrain_load(4)
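The `_test_*` helpers at the bottom are multi-process entry points; tests like this one are usually driven by spawning one worker process per rank. A rough sketch of such a launcher with torch.multiprocessing (the worker body and port are illustrative, not the test file's actual harness):

import torch.multiprocessing as mp

def _worker(rank, world_size, port):
    # A real worker would first initialize the distributed context for this rank
    # (e.g. the process group / ColossalAI launch) before running the test body.
    run_1d_hybrid_tp('bert')    # placeholder call into the per-rank test logic

def spawn_test(world_size=4, port=29500):
    # mp.spawn calls _worker(rank, world_size, port) once per process.
    mp.spawn(_worker, args=(world_size, port), nprocs=world_size)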