"""Distributed test: experience generation, buffering and sharded loading.

Each spawned process builds the same small GPT actor/critic, generates
experience, and uses ``all_gather`` to compare tensors across ranks.
"""
import os
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy
from coati.trainer.strategies.colossalai import LowLevelZeroStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from colossalai.testing import rerun_if_address_is_in_use, spawn

GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)

# Experience attributes the test requires to be identical on every rank,
# both right after generation and once stored in the buffer.
_REPLICATED_FIELDS = (
    'sequences',
    'action_log_probs',
    'values',
    'reward',
    'advantages',
    'action_mask',
    'attention_mask',
)

# Experience attributes the test requires to differ between ranks for every
# batch produced by the strategy's dataloader. The masks are deliberately
# left out: action mask and attention mask may be same.
_SHARDED_FIELDS = (
    'sequences',
    'action_log_probs',
    'values',
    'reward',
    'advantages',
)


def get_data(batch_size: int, seq_len: int = 10) -> dict:
    """Build a random prompt batch on the CUDA device.

    Args:
        batch_size: Number of rows in the batch.
        seq_len: Number of token ids per row.

    Returns:
        A dict with ``input_ids`` (random integers in ``[0, 50257)`` of shape
        ``(batch_size, seq_len)``) and ``attention_mask`` (all ones, same
        shape and dtype as ``input_ids``).
    """
    input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def gather_and_equal(tensor: torch.Tensor) -> bool:
    """Return True when ``tensor`` is equal on every rank.

    This is a collective call: every rank in the default process group must
    call it, and in the same order, because it issues ``dist.all_gather``.

    Args:
        tensor: The local tensor to compare across ranks.

    Returns:
        True if every gathered tensor equals the one gathered from rank 0.
    """
    world_size = dist.get_world_size()
    outputs = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(outputs, tensor.contiguous())
    reference = outputs[0]
    return all(torch.equal(reference, other) for other in outputs[1:])


def _assert_same_on_all_ranks(holder, field_names) -> None:
    """Assert that each named tensor attribute of ``holder`` matches across ranks."""
    for name in field_names:
        assert gather_and_equal(getattr(holder, name)), \
            f'"{name}" is expected to be identical on all ranks'


def _assert_differs_across_ranks(holder, field_names) -> None:
    """Assert that each named tensor attribute of ``holder`` differs on some rank."""
    for name in field_names:
        assert not gather_and_equal(getattr(holder, name)), \
            f'"{name}" is expected to differ between ranks'


def make_and_consume_experience(strategy):
    """Generate experience, buffer it, and check cross-rank consistency.

    Runs inside one spawned worker process. The checks are, in order:

    1. Input data and generated experience are identical on every rank.
    2. The experience buffer has the same length and content on every rank.
    3. The strategy's dataloader has the same length on every rank, but each
       batch it yields differs between ranks.

    Args:
        strategy: One of ``'ddp'``, ``'colossalai-zero2'`` or
            ``'colossalai-gemini'``.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
        AssertionError: If any of the cross-rank checks fails.
    """
    EXPERIENCE_BATCH_SIZE = 4
    SAMPLE_BATCH_SIZE = 2

    # NOTE(review): the strategy is built before any model or random data;
    # presumably its constructor sets up the process group and seeding that
    # the equality checks below rely on -- confirm against coati.
    if strategy == 'ddp':
        strategy = DDPStrategy()
    elif strategy == 'colossalai-zero2':
        strategy = LowLevelZeroStrategy()
    elif strategy == 'colossalai-gemini':
        strategy = GeminiStrategy(placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')

    actor = GPTActor(config=GPT_CONFIG).cuda()
    critic = GPTCritic(config=GPT_CONFIG).cuda()

    # The reference policy and the reward model start as copies of the
    # actor and of the critic's underlying model respectively.
    initial_model = deepcopy(actor)
    reward_model = RewardModel(deepcopy(critic.model)).cuda()

    experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
    data_buffer = NaiveExperienceBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)

    # experience of all ranks should be the same
    for _ in range(2):
        data = get_data(EXPERIENCE_BATCH_SIZE)
        assert gather_and_equal(data['input_ids'])
        assert gather_and_equal(data['attention_mask'])
        experience = experience_maker.make_experience(**data,
                                                      do_sample=True,
                                                      max_length=16,
                                                      eos_token_id=50256,
                                                      pad_token_id=50256)
        _assert_same_on_all_ranks(experience, _REPLICATED_FIELDS)
        data_buffer.append(experience)

    # data buffer's data should be the same
    buffer_size = torch.tensor([len(data_buffer)], device='cuda')
    assert gather_and_equal(buffer_size)
    for item in data_buffer.items:
        _assert_same_on_all_ranks(item, _REPLICATED_FIELDS)

    # dataloader of each rank should have the same size and different batch
    dataloader = strategy.setup_dataloader(data_buffer)
    dataloader_size = torch.tensor([len(dataloader)], device='cuda')
    assert gather_and_equal(dataloader_size)
    for experience in dataloader:
        _assert_differs_across_ranks(experience, _SHARDED_FIELDS)


def run_dist(rank, world_size, port, strategy):
    """Worker entry point: export the rendezvous environment, then run the checks.

    Args:
        rank: Index of this process; exported as both ``RANK`` and
            ``LOCAL_RANK``.
        world_size: Total number of processes; exported as ``WORLD_SIZE``.
        port: Rendezvous port; exported as ``MASTER_PORT`` (the master
            address is always ``localhost``).
        strategy: Strategy name forwarded to ``make_and_consume_experience``.
    """
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    make_and_consume_experience(strategy)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini'])
@rerun_if_address_is_in_use()
def test_experience(world_size, strategy):
    """Spawn ``world_size`` workers and run the experience checks for ``strategy``."""
    spawn(run_dist, world_size, strategy=strategy)


if __name__ == '__main__':
    # Must be one of the names accepted by make_and_consume_experience;
    # the former 'colossalai' value always raised ValueError.
    test_experience(2, 'colossalai-zero2')