mirror of https://github.com/hpcaitech/ColossalAI
[checkpoint] add test for bert and hotfix save bugs (#1297)
parent
bd71e2a88b
commit
3ef3791a3b
|
@ -28,7 +28,8 @@ def save_checkpoint(dire: str,
|
|||
if isinstance(v, ColoTensor):
|
||||
mapping[k] = (v.dist_spec, v.compute_spec)
|
||||
new_dict[k] = v.to_replicate().detach()
|
||||
|
||||
else:
|
||||
new_dict[k] = v
|
||||
if dist.get_rank() == 0:
|
||||
for k, v in new_dict.items():
|
||||
if isinstance(v, ColoTensor):
|
||||
|
@ -60,7 +61,7 @@ def load_checkpoint(dire,
|
|||
"""
|
||||
|
||||
mapping = dict()
|
||||
for k, v in model.named_parameters():
|
||||
for k, v in model.state_dict().items():
|
||||
if isinstance(v, ColoTensor):
|
||||
mapping[k] = (v.dist_spec, v.compute_spec)
|
||||
v.to_replicate_()
|
||||
|
@ -70,6 +71,6 @@ def load_checkpoint(dire,
|
|||
|
||||
# reset tensors to original dist spec.
|
||||
with DistSpecManager.no_grad():
|
||||
for k, v in model.named_parameters():
|
||||
for k, v in model.state_dict().items():
|
||||
if isinstance(v, ColoTensor):
|
||||
v.set_tensor_spec(*mapping[k])
|
||||
|
|
|
@ -1,91 +1,65 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import os, shutil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from functools import partial
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.optim.lr_scheduler import MultiplicativeLR
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import ColoOptimizer
|
||||
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
|
||||
class DummyDataGenerator(ABC):
|
||||
|
||||
def __init__(self, length=10):
|
||||
self.length = length
|
||||
|
||||
@abstractmethod
|
||||
def generate(self):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
self.step = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.step < self.length:
|
||||
self.step += 1
|
||||
return self.generate()
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
class DummyDataLoader(DummyDataGenerator):
|
||||
|
||||
def __init__(self, batch_size, category, feature_size, length=10):
|
||||
super().__init__(length)
|
||||
self.batch_size = batch_size
|
||||
self.category = category
|
||||
self.feature_size = feature_size
|
||||
|
||||
def generate(self):
|
||||
image_dict = {}
|
||||
image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
|
||||
image_dict['label'] = torch.randint(self.category, (self.batch_size,),
|
||||
dtype=torch.int64,
|
||||
device=get_current_device())
|
||||
return image_dict
|
||||
def init_1d_col_linear(weight, pg):
|
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def init_1d_row_embedding(weight, pg):
|
||||
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
def __init__(self, in_features, out_features, hidden_features=None):
|
||||
super().__init__()
|
||||
if hidden_features is None:
|
||||
hidden_features = out_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
def init_1d_col_embedding(weight, pg):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
weight.set_process_group(pg)
|
||||
weight.set_tensor_spec(*spec)
|
||||
|
||||
|
||||
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
|
||||
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n:
|
||||
p.set_process_group(pg)
|
||||
p.set_tensor_spec(*spec)
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
if 'embed' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p, pg)
|
||||
if 'proj1' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p, pg)
|
||||
if 'proj2' in name and 'weight' in name:
|
||||
init_1d_row_linear(p, pg)
|
||||
if 'classifier' in name and ('weight' in name or 'bias' in name):
|
||||
init_1d_col_linear(p, pg)
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
|
@ -103,56 +77,75 @@ def remove(path):
|
|||
raise ValueError("file {} is not a file or dir.".format(path))
|
||||
|
||||
|
||||
def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
||||
num_epoch = 5
|
||||
warmup_epoch = 2
|
||||
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
batch = 3
|
||||
feature = 32
|
||||
category = 16
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
|
||||
# set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = MLP(feature, category)
|
||||
model = model_builder(checkpoint=True)
|
||||
model_reload = model_builder(checkpoint=True)
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model_reload = MLP(feature, category)
|
||||
if use_mp_reload:
|
||||
if 'bert' == model_name:
|
||||
for name, p in model.named_parameters():
|
||||
if not isinstance(p, ColoTensor):
|
||||
continue
|
||||
# num_class = type_vocab_size = 2 | (8, 2)
|
||||
if 'classifier' in name and 'weight' in name:
|
||||
init_1d_row_linear(p, pg)
|
||||
# num_class = vocab_size = 30524 | (30524, 8)
|
||||
elif 'word_embeddings' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p, pg)
|
||||
# num_class = seq_len = 512 | (512, 8)
|
||||
elif 'position_embeddings' in name and 'weight' in name:
|
||||
init_1d_row_embedding(p, pg)
|
||||
# num_class = type_vocab_size = 2 | (2, 8)
|
||||
elif 'token_type_embeddings' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p, pg)
|
||||
elif p.process_group.tp_world_size() == 1:
|
||||
p.redistribute(ReplicaSpec(), pg)
|
||||
elif "simple_net" == model_name:
|
||||
init_spec_func(model, pg)
|
||||
|
||||
model = model.cuda()
|
||||
model.train()
|
||||
|
||||
model_reload = model_reload.cuda()
|
||||
if use_ddp:
|
||||
model = ColoDDP(model, pg)
|
||||
model_reload = ColoDDP(model_reload, pg)
|
||||
model_reload.train()
|
||||
|
||||
init_spec_func(model, pg)
|
||||
if use_mp_reload:
|
||||
init_spec_func(model_reload, pg)
|
||||
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
|
||||
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
|
||||
lr=0.001,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-08,
|
||||
weight_decay=0)
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
|
||||
lr_scheduler = None
|
||||
if test_scheduler == 'colossalai_cosine_warmup':
|
||||
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
|
||||
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
|
||||
total_steps=num_epoch,
|
||||
warmup_steps=warmup_epoch)
|
||||
elif test_scheduler == 'torch_cosine':
|
||||
lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
|
||||
lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
|
||||
elif test_scheduler == 'torch_lambda':
|
||||
lr_lambda = lambda epoch: 0.95
|
||||
lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
|
||||
lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
|
||||
else:
|
||||
raise TypeError(f"{test_scheduler} is invalid")
|
||||
# Zero grad
|
||||
colo_optimizer.zero_grad()
|
||||
|
||||
save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
# Bcast rank0 data to all processes
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
loss.backward()
|
||||
colo_optimizer.step()
|
||||
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
if not os.path.isdir('./checkpoint') and rank == 0:
|
||||
os.mkdir('./checkpoint')
|
||||
save_checkpoint('./checkpoint', 0, model, None, None)
|
||||
dist.barrier()
|
||||
load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)
|
||||
load_checkpoint('./checkpoint', 0, model_reload, None, None)
|
||||
|
||||
# Since model is sharded, we merge them before param checking.
|
||||
for p in model.parameters():
|
||||
|
@ -163,26 +156,29 @@ def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
|||
|
||||
check_param_equal(model, model_reload)
|
||||
|
||||
if rank == 0:
|
||||
remove('./checkpoint')
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
||||
if use_ddp and world_size == 1:
|
||||
return
|
||||
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)
|
||||
for model_name in ['bert', 'simple_net']:
|
||||
_run_checkpoint(model_name,
|
||||
init_1d_row_for_linear_weight_spec,
|
||||
use_ddp,
|
||||
use_mp_reload,
|
||||
test_scheduler=test_scheduler,
|
||||
pg=pg)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@pytest.mark.parametrize('use_ddp', [True, False])
|
||||
@pytest.mark.parametrize('use_ddp', [False])
|
||||
@pytest.mark.parametrize('use_mp_reload', [True, False])
|
||||
@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
|
||||
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
|
||||
if not os.path.isdir('./checkpoint'):
|
||||
os.mkdir('./checkpoint')
|
||||
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
|
||||
run_func = partial(run_dist,
|
||||
world_size=world_size,
|
||||
port=free_port(),
|
||||
|
@ -190,8 +186,7 @@ def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
|
|||
use_mp_reload=use_mp_reload,
|
||||
test_scheduler=test_scheduler)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
remove('./checkpoint')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(2, True, False, "torch_cosine")
|
||||
test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")
|
||||
|
|
Loading…
Reference in New Issue