2023-03-28 12:25:36 +00:00
|
|
|
import os
|
|
|
|
import tempfile
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from coati.models.gpt import GPTActor
|
2023-06-13 05:31:56 +00:00
|
|
|
from coati.models.utils import calc_action_log_probs
|
2023-08-02 02:17:36 +00:00
|
|
|
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
|
2023-03-28 12:25:36 +00:00
|
|
|
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
|
|
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
2023-04-06 06:51:35 +00:00
|
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
# Deliberately small GPT-2 configuration (128-dim embeddings, 4 layers, 4 heads)
# used to build the test actor; every GPT2Config field not set here keeps its
# library default.
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(
    batch_size: int,
    seq_len: int = 10,
    vocab_size: int = 50257,
    device: str = "cuda",
) -> dict:
    """Build a random batch of token ids together with an all-ones attention mask.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound for the sampled token ids. Defaults to
            50257, the value previously hard-coded here (GPT2Config's default
            vocabulary size).
        device: Device the tensors are allocated on. Defaults to ``"cuda"`` as
            before; pass ``"cpu"`` to build a batch without a GPU.

    Returns:
        A dict with ``input_ids`` (random integer tensor of shape
        ``(batch_size, seq_len)`` with values in ``[0, vocab_size)``) and
        ``attention_mask`` (all-ones tensor with the same shape, dtype and device).
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # Every position is attended to: the batch contains no padding.
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8):
    """Run a single backward + optimizer step of ``actor`` on a random batch.

    The summed action log-probabilities serve as a stand-in loss; the step is
    driven through ``strategy`` so its backward/optimizer hooks are exercised.
    """
    batch = get_data(batch_size)
    # Treat every token position as an action.
    num_actions = torch.ones_like(batch["attention_mask"], dtype=torch.bool).size(1)

    output = actor(batch["input_ids"], batch["attention_mask"])
    log_probs = calc_action_log_probs(output, batch["input_ids"], num_actions)

    strategy.backward(log_probs.sum(), actor, actor_optim)
    strategy.optimizer_step(actor_optim)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
2023-08-02 02:17:36 +00:00
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def _build_strategy(strategy_name: str) -> Strategy:
    """Construct the training strategy named by ``strategy_name``.

    Raises:
        ValueError: if ``strategy_name`` is not one of the supported names.
    """
    if strategy_name == "ddp":
        return DDPStrategy()
    if strategy_name == "colossalai_gemini":
        return GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
    if strategy_name == "colossalai_zero2":
        return LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    raise ValueError(f"Unsupported strategy '{strategy_name}'")


def run_test_checkpoint(strategy_name: str, shard: bool):
    """Train, save and reload an actor/optimizer checkpoint, then train again.

    Must run inside an initialised ``torch.distributed`` process group (it calls
    ``dist.get_rank``, ``dist.broadcast_object_list`` and ``dist.barrier``) on a
    machine with CUDA (the actor is moved with ``.cuda()``).

    Args:
        strategy_name: One of ``"ddp"``, ``"colossalai_gemini"`` or
            ``"colossalai_zero2"``.
        shard: If True, the checkpoint is written to the ``model`` / ``optim``
            paths with ``only_rank0=False``; otherwise it is written to
            ``model.pt`` / ``optim.pt`` with ``only_rank0=True``.

    Raises:
        ValueError: if ``strategy_name`` is not supported.
    """
    strategy = _build_strategy(strategy_name)

    with strategy.model_init_context():
        actor = GPTActor(config=GPT_CONFIG).cuda()
    actor_optim = HybridAdam(actor.parameters())
    actor, actor_optim = strategy.prepare((actor, actor_optim))

    train_step(strategy, actor, actor_optim)

    # Only rank 0 creates a real temporary directory; the other ranks get its
    # path via the broadcast below, so all ranks save to / load from one place.
    ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()

    with ctx as dirname:
        rank0_dirname = [dirname]
        dist.broadcast_object_list(rank0_dirname)
        rank0_dirname = rank0_dirname[0]

        model_path = os.path.join(rank0_dirname, "model" if shard else "model.pt")
        strategy.save_model(actor, model_path, only_rank0=not shard)
        optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
        strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
        # All ranks must finish writing before any rank starts reading.
        dist.barrier()

        # NOTE(review): strict=False tolerates state-dict key mismatches, which
        # could hide an incomplete checkpoint; confirm it is still required for
        # every strategy. The test also does not compare weights after reload.
        strategy.load_model(actor, model_path, strict=False)
        strategy.load_optimizer(actor_optim, optim_path)
        # Keep rank 0 from leaving the `with` (which deletes the temporary
        # directory) while other ranks may still be loading from it.
        dist.barrier()

    # Training must still work on the reloaded model/optimizer.
    train_step(strategy, actor, actor_optim)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool):
    """Per-process worker: export the torch.distributed env vars, then run the test.

    Presumably invoked once per process by ``spawn`` (see ``test_checkpoint``).
    ``LOCAL_RANK`` is set equal to ``RANK`` and the master address is
    ``localhost``, i.e. all processes are assumed to live on this host.
    """
    os.environ.update(
        {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(port),
        }
    )
    run_test_checkpoint(strategy_name, shard)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"])
@pytest.mark.parametrize("shard", [False, True])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size: int, strategy_name: str, shard: bool):
    """Launch ``world_size`` worker processes that run the checkpoint round trip."""
    worker_kwargs = {"strategy_name": strategy_name, "shard": shard}
    spawn(run_dist, world_size, **worker_kwargs)
|
2023-03-28 12:25:36 +00:00
|
|
|
|
|
|
|
|
2023-08-02 02:17:36 +00:00
|
|
|
if __name__ == "__main__":
    # Manual entry point: one 2-process Gemini run with a non-sharded checkpoint.
    test_checkpoint(world_size=2, strategy_name="colossalai_gemini", shard=False)
|