mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
4.5 KiB
119 lines
4.5 KiB
import os
|
|
from copy import deepcopy
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
from coati.experience_buffer import NaiveExperienceBuffer
|
|
from coati.experience_maker import NaiveExperienceMaker
|
|
from coati.models.base import RewardModel
|
|
from coati.models.gpt import GPTActor, GPTCritic
|
|
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
|
|
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
|
|
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
|
|
# Deliberately small GPT-2 configuration (128-dim embeddings, 4 layers, 4 heads)
# used to build the actor and critic models in this test module.
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
|
|
|
|
|
def get_data(batch_size: int, seq_len: int = 10, *, vocab_size: int = 50257, device: str = "cuda") -> dict:
    """Create a random batch of token ids with an all-ones attention mask.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound for the sampled token ids. The
            default keeps the bound this helper has always used.
        device: Device the tensors are created on. Defaults to ``"cuda"``;
            pass ``"cpu"`` to build a batch on a machine without a GPU.

    Returns:
        A dict with ``input_ids`` of shape ``(batch_size, seq_len)`` holding
        integers in ``[0, vocab_size)`` and ``attention_mask``, a tensor of
        ones with the same shape, dtype and device.
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
def gather_and_equal(tensor: torch.Tensor) -> bool:
    """Tell whether ``tensor`` holds the same values on every rank.

    The tensor is all-gathered across the default process group and each
    gathered copy is compared with the first one.

    Args:
        tensor: The local tensor to compare across ranks.

    Returns:
        ``True`` if all gathered copies are equal, ``False`` otherwise.
    """
    gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor.contiguous())
    reference, *others = gathered
    return all(torch.equal(reference, other) for other in others)
|
|
|
|
|
|
# Experience fields that every rank is expected to agree on.
_REPLICATED_FIELDS = (
    "sequences",
    "action_log_probs",
    "values",
    "reward",
    "advantages",
    "action_mask",
    "attention_mask",
)

# Fields that must differ between ranks once the dataloader hands out batches.
# action_mask and attention_mask are left out because they may be the same.
_PER_RANK_FIELDS = ("sequences", "action_log_probs", "values", "reward", "advantages")


def _build_strategy(name):
    """Instantiate the strategy selected by ``name``; raise ValueError if unknown."""
    if name == "ddp":
        return DDPStrategy()
    if name == "colossalai-zero2":
        return LowLevelZeroStrategy()
    if name == "colossalai-gemini":
        return GeminiStrategy(placement_policy="cuda")
    raise ValueError(f'Unsupported strategy "{name}"')


def _check_fields(holder, fields, *, expect_equal, label):
    """Assert each named attribute of ``holder`` is (or is not) equal on all ranks."""
    for field in fields:
        same = gather_and_equal(getattr(holder, field))
        if expect_equal:
            assert same, f"{label}.{field} differs across ranks"
        else:
            assert not same, f"{label}.{field} is identical on all ranks"


def make_and_consume_experience(strategy):
    """Check experience creation, buffering and loading across ranks.

    Two experiences are generated and appended to a buffer, and the buffer is
    then wrapped in a dataloader by the chosen strategy. Along the way the
    function asserts that:

    * the prompts and the resulting experience are identical on every rank;
    * the buffer has the same length and contents on every rank;
    * the dataloader has the same length on every rank but yields different
      batches per rank.

    Every check is a collective operation, so all ranks must call this
    function with the same ``strategy``.

    Args:
        strategy: One of ``"ddp"``, ``"colossalai-zero2"`` or
            ``"colossalai-gemini"``.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
        AssertionError: If any cross-rank consistency check fails; the message
            names the offending field.
    """
    experience_batch_size = 4
    sample_batch_size = 2

    # Build the strategy first, before any model is moved to the GPU.
    strategy = _build_strategy(strategy)

    actor = GPTActor(config=GPT_CONFIG).cuda()
    critic = GPTCritic(config=GPT_CONFIG).cuda()

    initial_model = deepcopy(actor)
    reward_model = RewardModel(deepcopy(critic.model)).cuda()

    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
    data_buffer = NaiveExperienceBuffer(sample_batch_size, cpu_offload=False)

    # experience of all ranks should be the same
    for _ in range(2):
        data = get_data(experience_batch_size)
        for key in ("input_ids", "attention_mask"):
            assert gather_and_equal(data[key]), f"prompt {key} differs across ranks"
        experience = experience_maker.make_experience(
            **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256
        )
        _check_fields(experience, _REPLICATED_FIELDS, expect_equal=True, label="experience")
        data_buffer.append(experience)

    # data buffer's data should be the same
    buffer_size = torch.tensor([len(data_buffer)], device="cuda")
    assert gather_and_equal(buffer_size), "data buffer length differs across ranks"
    for item in data_buffer.items:
        _check_fields(item, _REPLICATED_FIELDS, expect_equal=True, label="buffer item")

    # dataloader of each rank should have the same size and different batch
    dataloader = strategy.setup_dataloader(data_buffer)
    dataloader_size = torch.tensor([len(dataloader)], device="cuda")
    assert gather_and_equal(dataloader_size), "dataloader length differs across ranks"
    for batch in dataloader:
        _check_fields(batch, _PER_RANK_FIELDS, expect_equal=False, label="dataloader batch")
|
|
|
|
|
|
def run_dist(rank, world_size, port, strategy):
    """Per-process entry point: set the rendezvous environment, then run the check.

    Args:
        rank: Rank of this process; exported as both ``RANK`` and ``LOCAL_RANK``.
        world_size: Total number of processes, exported as ``WORLD_SIZE``.
        port: Port exported as ``MASTER_PORT`` (``MASTER_ADDR`` is ``localhost``).
        strategy: Strategy name forwarded to ``make_and_consume_experience``.
    """
    os.environ.update(
        {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(port),
        }
    )
    make_and_consume_experience(strategy)
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
    """Launch ``run_dist`` through ``spawn`` for the given world size and strategy name."""
    worker_kwargs = {"strategy": strategy}
    spawn(run_dist, world_size, **worker_kwargs)
|
|
|
|
|
|
if __name__ == "__main__":
    # The strategy must be a name make_and_consume_experience accepts ("ddp",
    # "colossalai-zero2" or "colossalai-gemini"); a bare "colossalai" is
    # rejected with ValueError before any check runs.
    test_experience(2, "colossalai-zero2")
|