|
|
|
import os, shutil
|
|
|
|
import torch
|
|
|
|
import pytest
|
|
|
|
from copy import deepcopy
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
|
|
from torch.optim.lr_scheduler import MultiplicativeLR
|
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use
|
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.utils import free_port
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
|
|
|
|
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
|
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
|
|
|
|
|
|
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
|
|
|
|
|
|
|
|
|
|
|
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    """Attach ``pg`` to ``weight`` and give it a TP1D spec sharded on dim ``-1``.

    Args:
        weight: tensor that is re-configured in place.
        pg: process group; ``pg.tp_world_size()`` is used as the number of shards.
    """
    # Build both spec objects up front, before the tensor is touched.
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)

    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
|
|
|
|
|
|
|
|
|
|
|
|
def init_1d_col_linear(weight, pg):
    """Attach ``pg`` to ``weight`` and give it a TP1D spec sharded on dim ``0``.

    Args:
        weight: tensor that is re-configured in place.
        pg: process group; ``pg.tp_world_size()`` is used as the number of shards.
    """
    # Build both spec objects up front, before the tensor is touched.
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)

    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
|
|
|
|
|
|
|
|
|
|
|
|
def init_1d_row_embedding(weight, pg):
    """Attach ``pg`` to an embedding ``weight`` and give it a TP1D spec sharded on dim ``0``.

    Args:
        weight: tensor that is re-configured in place.
        pg: process group; ``pg.tp_world_size()`` is used as the number of shards.
    """
    # Build both spec objects up front, before the tensor is touched.
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)

    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
|
|
|
|
|
|
|
|
|
|
|
|
def init_1d_col_embedding(weight, pg):
    """Attach ``pg`` to an embedding ``weight`` and give it a TP1D spec sharded on dim ``-1``.

    Args:
        weight: tensor that is re-configured in place.
        pg: process group; ``pg.tp_world_size()`` is used as the number of shards.
    """
    # Build both spec objects up front, before the tensor is touched.
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)

    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
|
|
|
|
|
|
|
|
|
|
|
|
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    """Apply 1D tensor-parallel specs to the ``ColoTensor`` parameters of ``model``.

    The spec is chosen from substrings of the parameter name:

    * ``'embed'`` + ``'weight'``                  -> ``init_1d_col_embedding``
    * ``'proj1'`` + ``'weight'`` or ``'bias'``    -> ``init_1d_col_linear``
    * ``'proj2'`` + ``'weight'``                  -> ``init_1d_row_linear``
    * ``'classifier'`` + ``'weight'``/``'bias'``  -> ``init_1d_col_linear``

    Parameters that are not ``ColoTensor`` instances are left untouched.

    Args:
        model: module whose ``named_parameters()`` are inspected.
        pg: process group forwarded to the per-parameter init helpers.
    """
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue

        is_weight = 'weight' in name
        is_weight_or_bias = is_weight or 'bias' in name

        # The rules are deliberately independent ``if`` statements (not an
        # ``elif`` chain): a name matching several rules gets each one applied.
        if 'embed' in name and is_weight:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and is_weight_or_bias:
            init_1d_col_linear(p, pg)
        if 'proj2' in name and is_weight:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and is_weight_or_bias:
            init_1d_col_linear(p, pg)
|
|
|
|
|
|
|
|
|
|
|
|
def check_param_equal(model, torch_model):
    """Assert that two models hold element-wise identical parameters.

    Parameters are paired positionally, in ``named_parameters()`` order.

    Args:
        model: module whose parameters are checked.
        torch_model: module providing the reference parameters.

    Raises:
        AssertionError: if the models expose a different number of parameters,
            or if any pair of parameters differs in at least one element.
    """
    named_params = list(model.named_parameters())
    ref_named_params = list(torch_model.named_parameters())

    # zip() stops at the shorter sequence, so without this check any trailing
    # parameters of the longer model would be skipped silently.
    assert len(named_params) == len(ref_named_params), \
        "parameter count mismatch: {} vs {}".format(len(named_params), len(ref_named_params))

    for (n, p), (_, tp) in zip(named_params, ref_named_params):
        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
|
|
|
|
|
|
|
|
|
|
|
|
def remove(path):
    """Delete ``path`` from disk; ``path`` may be relative or absolute.

    Regular files and symbolic links are unlinked, directories are removed
    together with their contents.

    Raises:
        ValueError: if ``path`` is neither a file, a link nor a directory.
    """
    # Links are tested before directories, so a link is unlinked rather than followed.
    unlinkable = os.path.isfile(path) or os.path.islink(path)
    if unlinkable:
        os.remove(path)
        return

    if os.path.isdir(path):
        shutil.rmtree(path)
        return

    raise ValueError("file {} is not a file or dir.".format(path))
|
|
|
|
|
|
|
|
|
|
|
|
def compare_optims(optim1, optim2):
    """Assert that the ``ColoTensor`` state shared by two optimizers is identical.

    Only the ``'state'`` section of each ``state_dict()`` is inspected. A
    parameter key or state field that exists in ``optim1`` but not in
    ``optim2`` is skipped, and values of ``optim1`` that are not
    ``ColoTensor`` instances are not compared.

    Raises:
        AssertionError: if a ``ColoTensor`` value of ``optim1`` has a
            counterpart that is not a ``ColoTensor`` or is not exactly equal.
    """
    lhs_states = optim1.state_dict()['state']
    rhs_states = optim2.state_dict()['state']

    for key in lhs_states:
        if key not in rhs_states:
            continue
        lhs_fields = lhs_states[key]
        rhs_fields = rhs_states[key]

        for field, lhs_value in lhs_fields.items():
            if field not in rhs_fields or not isinstance(lhs_value, ColoTensor):
                continue
            rhs_value = rhs_fields[field]
            assert isinstance(rhs_value, ColoTensor)
            # Zero tolerances: the tensors must match exactly.
            assert torch.allclose(lhs_value, rhs_value, rtol=0, atol=0)
|
|
|
|
|
|
|
|
|
|
|
|
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    """Train two identical models briefly, round-trip one through a checkpoint and compare.

    A model is built under ``ColoInitContext``, optionally given tensor-parallel
    specs, and deep-copied. Both copies are trained in lock-step with Adam on
    the same batches. The first model is then saved to ``./checkpoint`` and the
    checkpoint is loaded into the copy, after which parameters and optimizer
    states of the two are asserted to be equal.

    Must run inside an initialised ``torch.distributed`` environment; every
    rank of the default group has to call it because it uses barriers.

    Args:
        model_name: key passed to ``non_distributed_component_funcs.get_callable``.
            The sharding branches below handle ``'bert'`` and ``'simple_net'``.
        init_spec_func: called as ``init_spec_func(model, pg)`` for
            ``'simple_net'`` when ``use_mp_reload`` is true.
        use_ddp: accepted for signature compatibility; not read here.
        use_mp_reload: when true, parameters get tensor-parallel specs before
            the model is copied.
        test_scheduler: accepted for signature compatibility; not read here
            (``None`` is passed as the scheduler to save/load).
        pg: process group used for sharding and for broadcasting the batches.

    Raises:
        AssertionError: if parameters or optimizer states differ after reload.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # Only the builder, the training data and the criterion are needed here.
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    rank = torch.distributed.get_rank()

    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.set_process_group(pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)

    # Copy before any training so both models start from the same weights.
    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()

    model_reload = model_reload.cuda()
    model_reload.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.Adam(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(torch.optim.Adam(model_reload.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        colo_optimizer_reload.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast the batch from the first rank of the tensor-parallel group
        # so that every rank in the group works on the same data.
        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())

        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            # Without a criterion the model itself returns the loss.
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        colo_optimizer.step()
        colo_optimizer_reload.step()

        # Stop once the batch with index 3 has been processed (at most 4 steps).
        if i > 2:
            break

    # Only rank 0 creates the directory; the barrier keeps the other ranks from
    # saving before it exists.
    if rank == 0 and not os.path.isdir('./checkpoint'):
        os.mkdir('./checkpoint')
    dist.barrier()

    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    # Wait until every rank has finished loading and comparing before rank 0
    # deletes the checkpoint directory from under them.
    dist.barrier()
    if rank == 0:
        remove('./checkpoint')
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    """Per-process entry point: launch ColossalAI, then run the checkpoint check.

    Args:
        rank: rank of this process.
        world_size: total number of processes; also used as the tensor-parallel degree.
        port: port passed to ``colossalai.launch`` (host is ``'localhost'``).
        use_ddp: forwarded to ``_run_checkpoint``.
        use_mp_reload: forwarded to ``_run_checkpoint``.
        test_scheduler: forwarded to ``_run_checkpoint``.
    """
    launch_options = dict(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(**launch_options)

    tp_group = ProcessGroup(tp_degree=world_size)

    # the data loader of BERT is in DDP mode, so the input data is not replicated in the TP context
    models_under_test = ('bert',)
    for current_model in models_under_test:
        _run_checkpoint(model_name=current_model,
                        init_spec_func=init_1d_row_for_linear_weight_spec,
                        use_ddp=use_ddp,
                        use_mp_reload=use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=tp_group)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# Scheduler parametrization is currently disabled:
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    """Spawn ``world_size`` processes that each execute ``run_dist``.

    Args:
        world_size: number of processes to spawn.
        use_ddp: forwarded to ``run_dist``.
        use_mp_reload: forwarded to ``run_dist``.
        test_scheduler: forwarded to ``run_dist``; defaults to ``None``.
    """
    # Everything except the rank is bound here; mp.spawn supplies the process
    # index as the first positional argument of the worker.
    worker_kwargs = dict(
        world_size=world_size,
        port=free_port(),
        use_ddp=use_ddp,
        use_mp_reload=use_mp_reload,
        test_scheduler=test_scheduler,
    )
    worker = partial(run_dist, **worker_kwargs)
    mp.spawn(worker, nprocs=world_size)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual run without pytest: two processes, with tensor-parallel specs applied before reload.
    test_checkpoint(
        2,
        use_ddp=False,
        use_mp_reload=True,
        test_scheduler='torch_cosine',
    )
|