From b6a108cb916417abf339fc41a794d77222135ced Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 20 Feb 2023 15:22:36 +0800
Subject: [PATCH] [chatgpt] add test checkpoint (#2797)

* [chatgpt] add test checkpoint

* [chatgpt] test checkpoint use smaller model
---
 applications/ChatGPT/tests/test_checkpoint.py | 98 +++++++++++++++++++
 applications/ChatGPT/tests/test_data.py       |  7 +-
 2 files changed, 103 insertions(+), 2 deletions(-)
 create mode 100644 applications/ChatGPT/tests/test_checkpoint.py

diff --git a/applications/ChatGPT/tests/test_checkpoint.py b/applications/ChatGPT/tests/test_checkpoint.py
new file mode 100644
index 000000000..6cbe51569
--- /dev/null
+++ b/applications/ChatGPT/tests/test_checkpoint.py
@@ -0,0 +1,98 @@
+import os
+import tempfile
+from contextlib import nullcontext
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from chatgpt.nn import GPTActor
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
+
+
+def get_data(batch_size: int, seq_len: int = 10) -> dict:
+    input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+    attention_mask = torch.ones_like(input_ids)
+    return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def run_test_checkpoint(strategy):
+    BATCH_SIZE = 2
+
+    if strategy == 'ddp':
+        strategy = DDPStrategy()
+    elif strategy == 'colossalai_gemini':
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+    elif strategy == 'colossalai_zero2':
+        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    else:
+        raise ValueError(f'Unsupported strategy "{strategy}"')
+
+    with strategy.model_init_context():
+        actor = GPTActor(config=GPT_CONFIG).cuda()
+
+    actor_optim = HybridAdam(actor.parameters())
+
+    actor, actor_optim = strategy.prepare((actor, actor_optim))
+
+    def run_step():
+        data = get_data(BATCH_SIZE)
+        action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
+        action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
+        loss = action_log_probs.sum()
+        strategy.backward(loss, actor, actor_optim)
+        strategy.optimizer_step(actor_optim)
+
+    run_step()
+
+    ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
+
+    with ctx as dirname:
+        rank0_dirname = [dirname]
+        dist.broadcast_object_list(rank0_dirname)
+        rank0_dirname = rank0_dirname[0]
+
+        model_path = os.path.join(rank0_dirname, 'model.pt')
+        optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
+
+        strategy.save_model(actor, model_path, only_rank0=True)
+        strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
+
+        dist.barrier()
+
+        strategy.load_model(actor, model_path, strict=False)
+        strategy.load_optimizer(actor_optim, optim_path)
+
+        dist.barrier()
+
+        run_step()
+
+
+def run_dist(rank, world_size, port, strategy):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = str(port)
+    run_test_checkpoint(strategy)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
+@rerun_if_address_is_in_use()
+def test_checkpoint(world_size, strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_checkpoint(2, 'colossalai_zero2')
diff --git a/applications/ChatGPT/tests/test_data.py b/applications/ChatGPT/tests/test_data.py
index b0a9433c2..9571c2843 100644
--- a/applications/ChatGPT/tests/test_data.py
+++ b/applications/ChatGPT/tests/test_data.py
@@ -10,10 +10,13 @@ from chatgpt.experience_maker import NaiveExperienceMaker
 from chatgpt.nn import GPTActor, GPTCritic, RewardModel
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 
+GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
+
 
 def get_data(batch_size: int, seq_len: int = 10) -> dict:
     input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
@@ -42,8 +45,8 @@ def run_test_data(strategy):
     else:
         raise ValueError(f'Unsupported strategy "{strategy}"')
 
-    actor = GPTActor().cuda()
-    critic = GPTCritic().cuda()
+    actor = GPTActor(config=GPT_CONFIG).cuda()
+    critic = GPTCritic(config=GPT_CONFIG).cuda()
 
     initial_model = deepcopy(actor)
     reward_model = RewardModel(deepcopy(critic.model)).cuda()
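
Editor's note: the checkpoint round-trip exercised above reduces to a handful of strategy calls. Below is a minimal single-process sketch, not part of the patch: it reuses only the interface the patch itself calls (model_init_context, prepare, save_model/load_model, save_optimizer/load_optimizer), while NaiveStrategy, the plain Adam optimizer, and the file names are illustrative assumptions rather than anything this commit touches.

    # Hedged sketch: single process, so no torch.distributed setup is needed.
    # Assumption: chatgpt.trainer.strategies.NaiveStrategy implements the same
    # save/load interface that test_checkpoint.py invokes on the DDP/ColossalAI
    # strategies; under a multi-process launcher you would swap in DDPStrategy
    # or ColossalAIStrategy as the parametrized test does.
    import torch
    from chatgpt.nn import GPTActor
    from chatgpt.trainer.strategies import NaiveStrategy
    from transformers.models.gpt2.configuration_gpt2 import GPT2Config

    strategy = NaiveStrategy()
    with strategy.model_init_context():
        actor = GPTActor(config=GPT2Config(n_embd=128, n_layer=4, n_head=4)).cuda()
    actor_optim = torch.optim.Adam(actor.parameters())
    actor, actor_optim = strategy.prepare((actor, actor_optim))

    # Save, then restore into the live objects. A faithful test also runs a
    # training step before saving and after loading, as run_step() does above,
    # to confirm the restored state is still trainable.
    strategy.save_model(actor, 'actor.pt', only_rank0=True)
    strategy.save_optimizer(actor_optim, 'actor-optim.pt', only_rank0=False)
    strategy.load_model(actor, 'actor.pt', strict=False)
    strategy.load_optimizer(actor_optim, 'actor-optim.pt')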