diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py index 02f877c6a..b7f82db83 100644 --- a/tests/components_to_test/__init__.py +++ b/tests/components_to_test/__init__.py @@ -1 +1,2 @@ from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net +from .utils import run_fwd_bwd diff --git a/tests/components_to_test/gpt.py b/tests/components_to_test/gpt.py index 3123211ad..fe25b4923 100644 --- a/tests/components_to_test/gpt.py +++ b/tests/components_to_test/gpt.py @@ -1,10 +1,12 @@ import torch import torch.nn as nn -from .registry import non_distributed_component_funcs from transformers import GPT2Config, GPT2LMHeadModel -from .utils.dummy_data_generator import DummyDataGenerator + from colossalai.utils.cuda import get_current_device +from .registry import non_distributed_component_funcs +from .utils.dummy_data_generator import DummyDataGenerator + class DummyDataLoader(DummyDataGenerator): vocab_size = 128 @@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator): input_ids = torch.randint(0, DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), device=get_current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask + return input_ids, input_ids class GPTLMModel(nn.Module): @@ -43,8 +44,9 @@ class GPTLMModel(nn.Module): if checkpoint: self.model.gradient_checkpointing_enable() - def forward(self, input_ids, attention_mask): + def forward(self, input_ids): # Only return lm_logits + attention_mask = torch.ones_like(input_ids) return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py index 4fb7e55b2..a8d47d6af 100644 --- a/tests/components_to_test/inline_op_model.py +++ b/tests/components_to_test/inline_op_model.py @@ -38,7 +38,7 @@ class DummyDataLoader(DummyDataGenerator): return data, label -@non_distributed_component_funcs.register(name='inline_op_module') +@non_distributed_component_funcs.register(name='inline_op_model') def get_training_components(): def model_builder(checkpoint=True): diff --git a/tests/components_to_test/utils/__init__.py b/tests/components_to_test/utils/__init__.py index fc6321214..f223f7d32 100644 --- a/tests/components_to_test/utils/__init__.py +++ b/tests/components_to_test/utils/__init__.py @@ -1 +1,2 @@ -from .dummy_data_generator import DummyDataGenerator +from .dummy_data_generator import DummyDataGenerator +from .executor import run_fwd_bwd diff --git a/tests/components_to_test/utils/executor.py b/tests/components_to_test/utils/executor.py new file mode 100644 index 000000000..acb6a2134 --- /dev/null +++ b/tests/components_to_test/utils/executor.py @@ -0,0 +1,15 @@ +import torch + + +def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False): + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + if use_init_ctx: + model.backward(loss) + else: + loss.backward() diff --git a/tests/test_gemini/test_gemini_train.py b/tests/test_gemini/test_gemini_train.py new file mode 100644 index 000000000..1a8821bdd --- /dev/null +++ b/tests/test_gemini/test_gemini_train.py @@ -0,0 +1,67 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.parallel import ZeroDDP +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.components_to_test import run_fwd_bwd +from tests.components_to_test.registry import non_distributed_component_funcs + + +def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2): + PLACEMENT_POLICY = 'cuda' + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + # build torch model + model_torch = model_builder(checkpoint=False).cuda() + + for i, (data, label) in enumerate(train_dataloader): + if i >= iter_num: + break + run_fwd_bwd(model_torch, data.cuda(), label.cuda(), criterion, False, use_init_ctx=False) + + # build CAI model + with ColoInitContext(device=get_current_device()): + model = model_builder(checkpoint=False) + + from colossalai.gemini import ChunkManager, GeminiManager, search_chunk_configuration + config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict, init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) + model = ZeroDDP(model, gemini_manager) + + model.train() + + for i, (data, label) in enumerate(train_dataloader): + if i >= iter_num: + break + run_fwd_bwd(model, data.cuda(), label.cuda(), criterion, False, use_init_ctx=True) + + for p1, p2 in zip(model.parameters(), model_torch.parameters()): + torch.allclose(p1.to(torch.float), p2.to(torch.float)) + print(f'pass test {model_name}') + + +@pytest.mark.parametrize("model_name", ['bert']) +@rerun_if_address_is_in_use() +def test_gemini_train(model_name, iter_num=2): + run_func = partial(run_gemini_fwd_bwd, world_size=1, port=free_port(), model_name=model_name, iter_num=iter_num) + mp.spawn(run_func, nprocs=1) + + +if __name__ == '__main__': + # for model_name in ["bert", "resnet18", "inline_op_model"]: + # bert, gpt, inline_op_model, nested_model, no_leaf_module, + # repeated_computed_layer, resnet, simple_net + for model_name in ["nested_model", "no_leaf_module"]: + test_gemini_train(model_name=model_name, iter_num=4) diff --git a/tests/test_gemini/test_mem_tracer.py b/tests/test_gemini/test_mem_tracer.py index 7e524765b..5672f0439 100644 --- a/tests/test_gemini/test_mem_tracer.py +++ b/tests/test_gemini/test_mem_tracer.py @@ -8,20 +8,10 @@ import colossalai from colossalai.gemini.memory_tracer import MemtracerWrapper from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port +from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - model.backward(loss) - - def run_tracer(rank, world_size, port, use_grad_check=True): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert'] @@ -43,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True): data = data.cuda() label = label.cuda() - run_fwd_bwd(model, data, label, criterion, False) + run_fwd_bwd(model, data, label, criterion, False, use_init_ctx=False) model._ophook_list[0].print_non_model_data() @@ -58,4 +48,4 @@ def test_tracer(world_size, use_grad_check): if __name__ == '__main__': - test_tracer(1) + test_tracer(1, True) diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py index f8f7c34d0..60a0833cf 100644 --- a/tests/test_gemini/test_param_op.py +++ b/tests/test_gemini/test_param_op.py @@ -50,7 +50,7 @@ def run_model(model, inputs, label, criterion, use_param_hook=False): def test_base_param_hook(): - test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_module'] + test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_model'] # test_models = ['bert'] for model_name in test_models: diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index 0a2db2a17..7391ffc7d 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -30,9 +30,9 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item()) -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): +def run_fwd_bwd(model, criterion, optimizer, input_ids): optimizer.zero_grad() - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) optimizer.backward(loss) @@ -71,16 +71,16 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather): torch_model.eval() set_seed(pg.dp_local_rank()) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 0: break - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) model.backward(loss) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format( torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index 008813698..eec1db6e7 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -37,9 +37,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module): assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): +def run_fwd_bwd(model, criterion, optimizer, input_ids): optimizer.zero_grad() - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) optimizer.backward(loss) @@ -83,12 +83,12 @@ def exam_gpt_fwd_bwd(placement_policy): torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) # debug_print([0], zero_logits, torch_logits) @@ -127,12 +127,12 @@ def exam_tiny_example(placement_policy): torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) # debug_print([0], zero_logits, torch_logits) diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index 68885e543..7f53415bf 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -50,11 +50,11 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered): set_seed(dist.get_rank() * 3 + 128) model.train() - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 0: break optim.zero_grad() - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) optim.backward(loss) diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index 6f2ef9fa8..ad8ac87b2 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -1,21 +1,26 @@ -import pytest - from functools import partial -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed +import pytest import torch -from torch.nn.parallel import DistributedDataParallel as DDP import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print +from tests.test_tensor.common_utils import ( + debug_print, + set_seed, + split_param_col_tp1d, + split_param_row_tp1d, + tensor_equal, + tensor_shard_equal, +) def init_1d_row_spec(model, pg: ProcessGroup): @@ -107,10 +112,10 @@ def run_gpt(init_spec_func, use_ddp): torch_model.eval() set_seed(pg.dp_local_rank()) torch.distributed.barrier() - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) - logits = model(colo_input, attn_mask) - torch_logits = torch_model(input_ids, attn_mask) + logits = model(colo_input) + torch_logits = torch_model(input_ids) assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}" loss = criterion(logits, input_ids) torch_loss = criterion(torch_logits, input_ids) diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index b87802191..33db676cb 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -36,9 +36,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup): "parameter '{}' has problem.".format(key) -def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): +def run_fwd_bwd(model, criterion, optimizer, input_ids): optimizer.zero_grad() - logits = model(input_ids, attn_mask) + logits = model(input_ids) logits = logits.float() loss = criterion(logits, input_ids) optimizer.backward(loss) @@ -117,12 +117,12 @@ def run_gpt(placement_policy, tp_init_spec_func=None): torch_model.eval() set_seed(pg.dp_local_rank()) - for i, (input_ids, attn_mask) in enumerate(train_dataloader): + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg)) - zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask) - torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids) assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) zero_optim.step()