[Gemini] more tests for Gemini (#2038)

* [Gemini] more tests for Gemini

* polish code
pull/2039/head
Jiarui Fang 2022-11-29 17:13:10 +08:00 committed by GitHub
parent 537e181705
commit eb7742a4bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 26 deletions

View File

@ -40,7 +40,7 @@ def get_training_components():
num_layer = 2
vocab_size = 32
def bert_model_builder(checkpoint):
def bert_model_builder(checkpoint: bool = False):
config = BertConfig(vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,

View File

@ -18,8 +18,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed, tensor_equal, tensor_shard_equal
from tests.test_tensor.common_utils import set_seed
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
@ -37,19 +38,16 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)
def run_fwd_bwd(model, criterion, optimizer, input_ids):
optimizer.zero_grad()
logits = model(input_ids)
logits = logits.float()
loss = criterion(logits, input_ids)
optimizer.backward(loss)
return logits
# 'gpt2', 'bert',
TEST_MODELS = ['gpt2', 'bert']
# TEST_MODELS = ['simple_net']
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
@parameterize('model_name', TEST_MODELS)
def exam_model_step(placement_policy, model_name: str):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
@ -87,9 +85,13 @@ def exam_gpt_fwd_bwd(placement_policy):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False)
loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True)
assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}"
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
@ -99,9 +101,10 @@ def exam_gpt_fwd_bwd(placement_policy):
@parameterize('placement_policy', ['cuda', 'cpu'])
def exam_tiny_example(placement_policy):
@parameterize('model_name', TEST_MODELS)
def exam_tiny_example(placement_policy, model_name: str):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
@ -131,9 +134,13 @@ def exam_tiny_example(placement_policy):
if i > 2:
break
zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=False)
loss = run_fwd_bwd(model, input_ids.cuda(), label.cuda(), criterion, use_init_ctx=True)
assert torch.allclose(torch_loss, loss, rtol=1e-3, atol=1e-2), f"{torch_loss} vs {loss}"
# debug_print([0], zero_logits, torch_logits)
zero_optim.step()
@ -145,17 +152,17 @@ def exam_tiny_example(placement_policy):
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_gpt_fwd_bwd()
exam_model_step()
exam_tiny_example()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
def test_optim(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_gpt(2)
test_optim(2)

View File

@ -19,9 +19,10 @@ from tests.test_tensor.common_utils import debug_print, set_seed
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_state_dict(placement_policy, keep_gathered):
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, keep_gathered, model_name: str):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):
@ -53,9 +54,10 @@ def exam_state_dict(placement_policy, keep_gathered):
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
def exam_load_state_dict(placement_policy, keep_gathered):
@parameterize('model_name', ['gpt2', 'bert'])
def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
with ColoInitContext(device=get_current_device()):