[Gemini] add unit tests to check Gemini correctness (#2015)

pull/2021/merge
Jiarui Fang 2022-11-24 16:51:45 +08:00 committed by GitHub
parent 0b0d8f9e17
commit 2e9cbfca12
13 changed files with 135 additions and 54 deletions

View File

@@ -1 +1,2 @@
 from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
+from .utils import run_fwd_bwd

View File

@@ -1,10 +1,12 @@
 import torch
 import torch.nn as nn
-from .registry import non_distributed_component_funcs
 from transformers import GPT2Config, GPT2LMHeadModel
-from .utils.dummy_data_generator import DummyDataGenerator
+
 from colossalai.utils.cuda import get_current_device
+
+from .registry import non_distributed_component_funcs
+from .utils.dummy_data_generator import DummyDataGenerator
 
 
 class DummyDataLoader(DummyDataGenerator):
     vocab_size = 128
@@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator):
         input_ids = torch.randint(0,
                                   DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
                                   device=get_current_device())
-        attention_mask = torch.ones_like(input_ids)
-        return input_ids, attention_mask
+        return input_ids, input_ids
 
 
 class GPTLMModel(nn.Module):
@@ -43,8 +44,9 @@ class GPTLMModel(nn.Module):
         if checkpoint:
             self.model.gradient_checkpointing_enable()
 
-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids):
         # Only return lm_logits
+        attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
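
Taken together, these two hunks change the GPT component's calling convention: the dummy loader now yields (input_ids, input_ids) so every component exposes a uniform (data, label) pair, and the all-ones attention mask is rebuilt inside forward. A minimal sketch of the resulting contract, using a hypothetical toy module (not code from this commit):

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    # stand-in for GPTLMModel: one tensor in, logits out

    def __init__(self, vocab_size=128, hidden=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        # the mask callers previously had to supply is now built internally
        attention_mask = torch.ones_like(input_ids)
        hidden = self.emb(input_ids) * attention_mask.unsqueeze(-1)
        return self.head(hidden)


input_ids = torch.randint(0, 128, (2, 8))
logits = TinyLM()(input_ids)    # single-argument forward, as in GPTLMModel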

View File

@@ -38,7 +38,7 @@ class DummyDataLoader(DummyDataGenerator):
         return data, label
 
 
-@non_distributed_component_funcs.register(name='inline_op_module')
+@non_distributed_component_funcs.register(name='inline_op_model')
 def get_training_components():
 
     def model_builder(checkpoint=True):

View File

@@ -1 +1,2 @@
 from .dummy_data_generator import DummyDataGenerator
+from .executor import run_fwd_bwd

View File

@@ -0,0 +1,15 @@
+import torch
+
+
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if use_init_ctx:
+        model.backward(loss)
+    else:
+        loss.backward()
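
A usage sketch for this shared helper (the toy model and data below are assumptions for illustration, and a CUDA device is assumed since autocast here targets the GPU). Note that despite its name, use_init_ctx selects the backward style: model.backward(loss) for wrappers such as ZeroDDP, plain loss.backward() for ordinary modules.

import torch
import torch.nn as nn

model = nn.Linear(4, 2).cuda()
data = torch.randn(8, 4, device='cuda')
label = torch.randint(0, 2, (8,), device='cuda')
criterion = nn.CrossEntropyLoss()

# ordinary torch module: gradients land in p.grad via loss.backward()
run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False)
assert all(p.grad is not None for p in model.parameters())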

View File

@@ -0,0 +1,67 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test import run_fwd_bwd
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
+    PLACEMENT_POLICY = 'cuda'
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, _, _, criterion = get_components_func()
+
+    # build torch model
+    model_torch = model_builder(checkpoint=False).cuda()
+    for i, (data, label) in enumerate(train_dataloader):
+        if i >= iter_num:
+            break
+        run_fwd_bwd(model_torch, data.cuda(), label.cuda(), criterion, False, use_init_ctx=False)
+
+    # build CAI model
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder(checkpoint=False)
+
+    from colossalai.gemini import ChunkManager, GeminiManager, search_chunk_configuration
+    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    chunk_manager = ChunkManager(config_dict, init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
+    gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
+    model.train()
+
+    for i, (data, label) in enumerate(train_dataloader):
+        if i >= iter_num:
+            break
+        run_fwd_bwd(model, data.cuda(), label.cuda(), criterion, False, use_init_ctx=True)
+
+    for p1, p2 in zip(model.parameters(), model_torch.parameters()):
+        torch.allclose(p1.to(torch.float), p2.to(torch.float))
+
+    print(f'pass test {model_name}')
+
+
+@pytest.mark.parametrize("model_name", ['bert'])
+@rerun_if_address_is_in_use()
+def test_gemini_train(model_name, iter_num=2):
+    run_func = partial(run_gemini_fwd_bwd, world_size=1, port=free_port(), model_name=model_name, iter_num=iter_num)
+    mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == '__main__':
+    # for model_name in ["bert", "resnet18", "inline_op_model"]:
+    # bert, gpt, inline_op_model, nested_model, no_leaf_module,
+    # repeated_computed_layer, resnet, simple_net
+    for model_name in ["nested_model", "no_leaf_module"]:
+        test_gemini_train(model_name=model_name, iter_num=4)
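
One caveat in the parameter check above: torch.allclose returns a bool that the loop discards, so a divergence between the Gemini and torch parameters cannot actually fail the test. A stricter variant would assert on the result; the tolerances below are assumptions, not from this commit:

for p1, p2 in zip(model.parameters(), model_torch.parameters()):
    assert torch.allclose(p1.to(torch.float), p2.to(torch.float), rtol=1e-3, atol=1e-2), \
        "max diff {}".format(torch.max(torch.abs(p1.float() - p2.float())).item())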

View File

@@ -8,20 +8,10 @@ import colossalai
 from colossalai.gemini.memory_tracer import MemtracerWrapper
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 
-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    model.backward(loss)
-
-
 def run_tracer(rank, world_size, port, use_grad_check=True):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert']
@@ -43,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
         data = data.cuda()
         label = label.cuda()
 
-        run_fwd_bwd(model, data, label, criterion, False)
+        run_fwd_bwd(model, data, label, criterion, False, use_init_ctx=False)
         model._ophook_list[0].print_non_model_data()
@@ -58,4 +48,4 @@ def test_tracer(world_size, use_grad_check):
 
 if __name__ == '__main__':
-    test_tracer(1)
+    test_tracer(1, True)

View File

@@ -50,7 +50,7 @@ def run_model(model, inputs, label, criterion, use_param_hook=False):
 
 def test_base_param_hook():
-    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_module']
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_model']
     # test_models = ['bert']
 
     for model_name in test_models:

View File

@@ -30,9 +30,9 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())
 
 
-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)
@@ -71,16 +71,16 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather):
     torch_model.eval()
 
     set_seed(pg.dp_local_rank())
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 0:
             break
 
-        logits = model(input_ids, attn_mask)
+        logits = model(input_ids)
         logits = logits.float()
         loss = criterion(logits, input_ids)
         model.backward(loss)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
             torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
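
Since rtol=0 here, the comparison is purely absolute: torch.allclose(a, b, rtol=0) reduces to checking |a - b| <= atol elementwise with the default atol=1e-8, so the ZeroDDP forward is required to match the torch forward almost exactly. A self-contained illustration:

import torch

a = torch.tensor([1.0, 2.0], dtype=torch.float64)
assert torch.allclose(a, a + 5e-9, rtol=0)        # |diff| = 5e-9 <= atol (1e-8)
assert not torch.allclose(a, a + 1e-6, rtol=0)    # |diff| = 1e-6 > atol, fails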

View File

@@ -37,9 +37,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
 
 
-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)
@@ -83,12 +83,12 @@ def exam_gpt_fwd_bwd(placement_policy):
     torch_model.eval()
 
     set_seed(dist.get_rank() * 3 + 128)
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
 
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         # debug_print([0], zero_logits, torch_logits)
@@ -127,12 +127,12 @@ def exam_tiny_example(placement_policy):
     torch_model.eval()
 
    set_seed(dist.get_rank() * 3 + 128)
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
 
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         # debug_print([0], zero_logits, torch_logits)

View File

@@ -50,11 +50,11 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     set_seed(dist.get_rank() * 3 + 128)
 
     model.train()
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 0:
             break
 
         optim.zero_grad()
-        logits = model(input_ids, attn_mask)
+        logits = model(input_ids)
         logits = logits.float()
         loss = criterion(logits, input_ids)
         optim.backward(loss)

View File

@@ -1,21 +1,26 @@
-import pytest
 from functools import partial
-from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
+
+import pytest
 import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
 import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils import free_port
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print
+from tests.test_tensor.common_utils import (
+    debug_print,
+    set_seed,
+    split_param_col_tp1d,
+    split_param_row_tp1d,
+    tensor_equal,
+    tensor_shard_equal,
+)
 
 
 def init_1d_row_spec(model, pg: ProcessGroup):
@@ -107,10 +112,10 @@ def run_gpt(init_spec_func, use_ddp):
     torch_model.eval()
 
     set_seed(pg.dp_local_rank())
     torch.distributed.barrier()
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        logits = model(colo_input, attn_mask)
-        torch_logits = torch_model(input_ids, attn_mask)
+        logits = model(colo_input)
+        torch_logits = torch_model(input_ids)
         assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
         loss = criterion(logits, input_ids)
         torch_loss = criterion(torch_logits, input_ids)

View File

@@ -36,9 +36,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
             "parameter '{}' has problem.".format(key)
 
 
-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)
@@ -117,12 +117,12 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     torch_model.eval()
 
     set_seed(pg.dp_local_rank())
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break
 
         input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
 
         zero_optim.step()