mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] tests for 3d parallel (#4493)
parent
59e252ecdb
commit
e04436a82a
|
@ -245,7 +245,6 @@ def check_grad(org_model: Module,
|
|||
org_grad = getattr_(org_model, suffix).weight.grad
|
||||
shard_grad = getattr_(sharded_model, suffix).weight.grad
|
||||
shard_weight = getattr_(sharded_model, suffix).weight
|
||||
|
||||
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
|
||||
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
|
||||
dist.all_gather(shard_grad_list, shard_grad, tp_group)
|
||||
|
|
|
@ -120,12 +120,40 @@ def run_bert_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_bert_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_bert(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_bert_test()
|
||||
|
||||
|
||||
def check_bert_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_bert_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -133,5 +161,13 @@ def test_bert():
|
|||
spawn(check_bert, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bert_3d():
|
||||
spawn(check_bert_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bert()
|
||||
test_bert_3d()
|
||||
|
|
|
@ -3,6 +3,7 @@ import torch
|
|||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer.utils import Randomizer
|
||||
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
@ -118,6 +119,29 @@ def run_bloom_test(test_config):
|
|||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_bloom_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
@ -127,6 +151,12 @@ def check_bloom(rank, world_size, port):
|
|||
run_bloom_test()
|
||||
|
||||
|
||||
def check_bloom_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_bloom_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -134,5 +164,13 @@ def test_bloom():
|
|||
spawn(check_bloom, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bloom_3d():
|
||||
spawn(check_bloom_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bloom()
|
||||
test_bloom_3d()
|
||||
|
|
|
@ -145,12 +145,39 @@ def run_chatglm_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_chatglm_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_chatglm(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_chatglm_test()
|
||||
|
||||
|
||||
def check_chatglm_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_chatglm_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -158,5 +185,13 @@ def test_chatglm():
|
|||
spawn(check_chatglm, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_chatglm_3d():
|
||||
spawn(check_chatglm_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chatglm()
|
||||
test_chatglm_3d()
|
||||
|
|
|
@ -141,12 +141,40 @@ def run_gpt2_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_gpt2(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt2_test()
|
||||
|
||||
|
||||
def check_gpt2_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_gpt2_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -154,5 +182,13 @@ def test_gpt2():
|
|||
spawn(check_gpt2, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_gpt2_3d():
|
||||
spawn(check_gpt2_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_gpt2()
|
||||
test_gpt2_3d()
|
||||
|
|
|
@ -56,7 +56,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
# unwrap model
|
||||
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
|
||||
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
|
||||
|
||||
# check grad
|
||||
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
|
||||
col_layer_for_check = ['layers[0].self_attn.o_proj']
|
||||
|
@ -156,12 +155,40 @@ def run_llama_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_llama_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
Randomizer.reset_index()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_llama(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_test()
|
||||
|
||||
|
||||
def check_llama_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_llama_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -169,5 +196,13 @@ def test_llama():
|
|||
spawn(check_llama, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_llama_3d():
|
||||
spawn(check_llama_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama()
|
||||
test_llama_3d()
|
||||
|
|
|
@ -146,12 +146,39 @@ def run_opt_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_opt_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_OPTModel(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_opt_test()
|
||||
|
||||
|
||||
def check_opt_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_opt_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -159,5 +186,13 @@ def test_OPTModel():
|
|||
spawn(check_OPTModel, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_opt_3d():
|
||||
spawn(check_opt_3d, 8)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_OPTModel()
|
||||
test_opt_3d()
|
||||
|
|
|
@ -137,12 +137,39 @@ def run_t5_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_t5_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_t5(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_t5_test()
|
||||
|
||||
|
||||
def check_t5_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_t5_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -150,5 +177,13 @@ def test_t5():
|
|||
spawn(check_t5, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_t5_3d():
|
||||
spawn(check_t5_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_t5()
|
||||
test_t5_3d()
|
||||
|
|
|
@ -146,12 +146,39 @@ def run_vit_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_vit_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_vit(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_vit_test()
|
||||
|
||||
|
||||
def check_vit_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_vit_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -159,5 +186,13 @@ def test_vit():
|
|||
spawn(check_vit, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_vit_3d():
|
||||
spawn(check_vit_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_vit()
|
||||
test_vit_3d()
|
||||
|
|
|
@ -82,8 +82,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
||||
check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
|
||||
check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
|
||||
|
||||
# check weights after optimizer.step()
|
||||
org_optimizer.step()
|
||||
|
@ -99,7 +99,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
tp_group,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
dim=0,
|
||||
dim=1,
|
||||
verbose=False)
|
||||
check_weight(whisper,
|
||||
sharded_whisper,
|
||||
|
@ -155,12 +155,39 @@ def run_whisper_test(test_config):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('test_config', [
|
||||
{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
'num_microbatches': 4,
|
||||
'enable_all_optimization': False,
|
||||
'use_lazy_init': False,
|
||||
'precision': 'fp32',
|
||||
'initial_scale': 1,
|
||||
},
|
||||
])
|
||||
def run_whisper_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
||||
clear_layout_converter()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_whisper(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_whisper_test()
|
||||
|
||||
|
||||
def check_whisper_3d(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_whisper_3d_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
@ -168,5 +195,13 @@ def test_whisper():
|
|||
spawn(check_whisper, 4)
|
||||
|
||||
|
||||
@pytest.mark.largedist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_whisper_3d():
|
||||
spawn(check_whisper_3d, 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_whisper()
|
||||
test_whisper_3d()
|
||||
|
|
Loading…
Reference in New Issue