2023-07-31 06:49:55 +00:00
|
|
|
import copy
|
|
|
|
from functools import partial
|
|
|
|
from types import MethodType
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2023-12-22 02:44:00 +00:00
|
|
|
import torch.distributed as dist
|
2023-07-31 06:49:55 +00:00
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.cluster import ProcessGroupMesh
|
|
|
|
from colossalai.interface import OptimizerWrapper
|
|
|
|
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
|
|
|
from colossalai.testing.random import seed_all
|
|
|
|
|
2023-12-22 02:44:00 +00:00
|
|
|
# Width of every Linear layer in MlpModel (in_features == out_features == DIM).
DIM: int = 8
# Total number of Linear layers; test_pp requires it to divide evenly by the
# world size so every rank owns the same number of layers.
NUM_LAYER: int = 8
|
|
|
|
|
2023-07-31 06:49:55 +00:00
|
|
|
|
|
|
|
class MlpModel(nn.Module):
    """A stack of ``NUM_LAYER`` square ``nn.Linear(DIM, DIM)`` layers applied in order.

    No activation sits between the layers. The layers live in a ``ModuleList``
    so the pipeline test can shard them across stages by index.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(DIM, DIM) for _ in range(NUM_LAYER))

    def forward(self, x):
        """Feed ``x`` through every layer in sequence and return the last activation."""
        out = x
        for linear in self.layers:
            out = linear(out)
        return out
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def pp_linear_fwd(
    forward,
    data: torch.Tensor = None,
    input_obj: torch.Tensor = None,
    stage_mgr: PipelineStageManager = None,
):
    """Stage-aware wrapper around a stage's local forward function.

    Args:
        forward: callable running this stage's layers on one tensor.
        data: the micro-batch; only consumed on the first stage.
        input_obj: activation received from the previous stage; consumed on
            every stage except the first.
        stage_mgr: pipeline stage manager used to query the stage position.
            It must be provided, since ``None`` has no ``is_first_stage``.

    Returns:
        ``{"input_obj": activation}`` on the first and intermediate stages,
        so the next stage receives it as ``input_obj``. On the last stage
        (when it is not also the first), the raw output tensor.
    """
    if stage_mgr.is_first_stage():
        return {"input_obj": forward(data)}

    # Not the first stage: the input comes from the previous stage.
    at_tail = stage_mgr.is_last_stage()
    hidden = forward(input_obj)
    if at_tail:
        return hidden
    return {"input_obj": hidden}
|
2023-07-31 06:49:55 +00:00
|
|
|
|
|
|
|
|
2023-12-22 02:44:00 +00:00
|
|
|
def examine_pp(num_microbatch: int, batch_size: int):
    """Compare the 1F1B pipeline schedule against a plain PyTorch run.

    Every rank builds the same ``MlpModel`` from the same seed. It then keeps
    the contiguous slice of ``NUM_LAYER // world_size`` layers that belongs to
    its pipeline stage and drives it with ``OneForwardOneBackwardSchedule``.
    A full, unsharded copy of the model is run locally on the same batch as
    the reference.

    Checks performed:

    * the loss, on the last stage (where ``pp_linear_fwd`` returns the raw
      output that the criterion consumes);
    * gradients of the locally owned layers after one forward/backward pass;
    * parameters of the locally owned layers after one SGD step (lr=1);
    * a forward-only pass under ``torch.no_grad()``: the loss must match and
      no gradients may be produced.

    Fixed assumptions: one pipeline stage per rank, ``NUM_LAYER`` divisible
    by the world size (asserted in ``test_pp``), seed 1453, and a mean-square
    criterion.

    Args:
        num_microbatch: number of micro-batches handed to the schedule.
        batch_size: number of samples in the single global batch.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    seed_all(1453)

    # create models; deepcopy gives the pipeline copy identical initial weights
    torch_model = MlpModel().cuda()
    pp_model = copy.deepcopy(torch_model).cuda()

    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)

    # This rank owns the contiguous layers [rank * n, (rank + 1) * n).
    num_local_layer = NUM_LAYER // world_size
    local_slice = slice(rank * num_local_layer, (rank + 1) * num_local_layer)
    sharded_model = nn.ModuleList(layer.cuda() for layer in pp_model.layers[local_slice])
    assert len(sharded_model) == num_local_layer

    def custom_fwd(self, x):
        # Plain sequential forward over whatever layers this stage owns.
        for layer in self._modules.values():
            x = layer(x)
        return x

    sharded_model._forward = MethodType(custom_fwd, sharded_model)
    # MethodType binds sharded_model._forward as pp_linear_fwd's first
    # argument (``forward``), so sharded_model(...) goes through the
    # stage-aware wrapper.
    sharded_model.forward = MethodType(
        partial(pp_linear_fwd, stage_mgr=stage_manager),
        sharded_model._forward,
    )

    # create optimizers
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))

    # create the input batch; all_reduce leaves an identical tensor on every rank
    seed_all(1453)
    input_list = [torch.rand(batch_size, DIM).cuda()]
    dist.all_reduce(input_list[0])

    def criterion(x, *args, **kwargs):
        # Any extra arguments passed by the schedule are ignored.
        return (x * x).mean()

    # forward and backward
    torch_output = torch_model(input_list[0])
    torch_loss = criterion(torch_output)
    torch_loss.backward()
    pp_ret = schedule.forward_backward_step(
        sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
    )

    # check loss
    if stage_manager.is_last_stage():
        assert torch.allclose(torch_loss, pp_ret["loss"])

    # check gradients of the layers owned by this stage
    local_torch_layers = torch_model.layers[local_slice]
    for torch_layer, pp_layer in zip(local_torch_layers, sharded_model):
        assert torch.allclose(torch_layer.weight.grad, pp_layer.weight.grad)
        assert torch.allclose(torch_layer.bias.grad, pp_layer.bias.grad)

    # step
    torch_optimizer.step()
    pp_optimizer.step()
    pp_optimizer.zero_grad()

    # check updated params
    for torch_layer, pp_layer in zip(local_torch_layers, sharded_model):
        assert torch.allclose(torch_layer.weight, pp_layer.weight)
        assert torch.allclose(torch_layer.bias, pp_layer.bias)

    # forward only
    # NOTE(review): relies on the schedule running without a backward pass
    # when called with grad disabled; the gradient check below verifies that.
    with torch.no_grad():
        torch_output = torch_model(input_list[0])
        torch_loss = criterion(torch_output)

        pp_ret = schedule.forward_backward_step(
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
        )
        if stage_manager.is_last_stage():
            assert torch.allclose(torch_loss, pp_ret["loss"])

    # After zero_grad, grads are either None or all zeros; a forward-only
    # pass must leave them that way.
    for layer in sharded_model:
        if layer.weight.grad is None:
            assert layer.bias.grad is None
        else:
            assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
            assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
|
|
|
|
|
|
|
|
|
|
|
|
def run_dist(
    rank: int,
    world_size: int,
    port: int,
    num_microbatch: int,
    batch_size: int,
):
    """Per-process worker: call ``colossalai.launch`` for this rank, then run the check.

    The launch uses an empty config, host ``localhost`` and the given port.
    Afterwards ``examine_pp`` runs with the requested micro-batch count and
    batch size.
    """
    launch_kwargs = {
        "config": {},
        "rank": rank,
        "world_size": world_size,
        "port": port,
        "host": "localhost",
    }
    colossalai.launch(**launch_kwargs)
    examine_pp(num_microbatch, batch_size)
|
2023-07-31 06:49:55 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 6])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, world_size: int):
    """Spawn ``world_size`` workers running ``run_dist`` for one configuration.

    ``NUM_LAYER`` must split evenly so that every rank (pipeline stage)
    owns the same number of layers.
    """
    assert NUM_LAYER % world_size == 0
    worker_kwargs = {"num_microbatch": num_microbatch, "batch_size": batch_size}
    spawn(run_dist, world_size, **worker_kwargs)
|
2023-07-31 06:49:55 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
if __name__ == "__main__":
    # Manual entry point: 4 processes, a batch of 4 samples, 4 micro-batches.
    test_pp(world_size=4, num_microbatch=4, batch_size=4)
|