[chatgpt] add test checkpoint (#2797)

* [chatgpt] add test checkpoint

* [chatgpt] test checkpoint use smaller model
ver217 2023-02-20 15:22:36 +08:00 committed by GitHub
parent c008d4ad0c
commit b6a108cb91
2 changed files with 103 additions and 2 deletions


@@ -0,0 +1,98 @@
import os
import tempfile
from contextlib import nullcontext
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from chatgpt.nn import GPTActor
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port

# Small GPT-2 config so the test fits easily in GPU memory.
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)


def get_data(batch_size: int, seq_len: int = 10) -> dict:
    input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def run_test_checkpoint(strategy):
    BATCH_SIZE = 2

    if strategy == 'ddp':
        strategy = DDPStrategy()
    elif strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')

    with strategy.model_init_context():
        actor = GPTActor(config=GPT_CONFIG).cuda()

    actor_optim = HybridAdam(actor.parameters())
    actor, actor_optim = strategy.prepare((actor, actor_optim))

    def run_step():
        data = get_data(BATCH_SIZE)
        action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
        action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
        loss = action_log_probs.sum()
        strategy.backward(loss, actor, actor_optim)
        strategy.optimizer_step(actor_optim)

    # Take one step so the optimizer has real state to checkpoint.
    run_step()

    # Rank 0 creates the temporary directory and broadcasts its path,
    # so every rank saves into and loads from the same location.
    ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
    with ctx as dirname:
        rank0_dirname = [dirname]
        dist.broadcast_object_list(rank0_dirname)
        rank0_dirname = rank0_dirname[0]

        model_path = os.path.join(rank0_dirname, 'model.pt')
        optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
        # Model weights are saved once on rank 0; optimizer states are
        # saved per rank, since they may be sharded (e.g. under ZeRO).
        strategy.save_model(actor, model_path, only_rank0=True)
        strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
        dist.barrier()

        strategy.load_model(actor, model_path, strict=False)
        strategy.load_optimizer(actor_optim, optim_path)
        dist.barrier()

        # Run another step to verify the restored states are usable.
        run_step()


def run_dist(rank, world_size, port, strategy):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    run_test_checkpoint(strategy)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_checkpoint(2, 'colossalai_zero2')
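
The test above exercises the full two-process path. As a minimal single-process sketch of the same save/load round trip, assuming the package's NaiveStrategy (not touched by this diff) exposes the same Strategy save/load interface, one could write something like the following; the paths and optimizer choice are illustrative only:

# Minimal single-GPU sketch of the checkpoint round trip tested above.
# Assumption: chatgpt.trainer.strategies.NaiveStrategy implements the same
# save_model/save_optimizer/load_model/load_optimizer interface.
import torch
from chatgpt.nn import GPTActor
from chatgpt.trainer.strategies import NaiveStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

strategy = NaiveStrategy()
with strategy.model_init_context():
    actor = GPTActor(config=GPT2Config(n_embd=128, n_layer=4, n_head=4)).cuda()
actor_optim = torch.optim.Adam(actor.parameters())
actor, actor_optim = strategy.prepare((actor, actor_optim))

strategy.save_model(actor, 'model.pt', only_rank0=True)             # weights
strategy.save_optimizer(actor_optim, 'optim.pt', only_rank0=False)  # optimizer state
strategy.load_model(actor, 'model.pt', strict=False)
strategy.load_optimizer(actor_optim, 'optim.pt')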


@@ -10,10 +10,13 @@ from chatgpt.experience_maker import NaiveExperienceMaker
 from chatgpt.nn import GPTActor, GPTCritic, RewardModel
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config

 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port

+GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)

 def get_data(batch_size: int, seq_len: int = 10) -> dict:
     input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
@@ -42,8 +45,8 @@ def run_test_data(strategy):
     else:
         raise ValueError(f'Unsupported strategy "{strategy}"')

-    actor = GPTActor().cuda()
-    critic = GPTCritic().cuda()
+    actor = GPTActor(config=GPT_CONFIG).cuda()
+    critic = GPTCritic(config=GPT_CONFIG).cuda()

     initial_model = deepcopy(actor)
     reward_model = RewardModel(deepcopy(critic.model)).cuda()
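
For context on the "use smaller model" part of this commit: a default GPT2Config builds the ~124M-parameter gpt2-small, while the test config is roughly 7M parameters, most of them in the tied token embedding. A quick sketch to check this (my estimate, not part of the diff; GPT2LMHeadModel stands in for the GPT wrapper used by the tests):

# Rough parameter-count comparison motivating the smaller test config.
# Runs on CPU; the printed counts are approximate.
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

def n_params(config: GPT2Config) -> int:
    # Tied embedding/LM-head weights are counted once by parameters().
    return sum(p.numel() for p in GPT2LMHeadModel(config).parameters())

print(n_params(GPT2Config()))                                 # ~124M (gpt2-small defaults)
print(n_params(GPT2Config(n_embd=128, n_layer=4, n_head=4)))  # ~7M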