2023-08-16 11:29:03 +00:00
|
|
|
import copy
|
|
|
|
from functools import partial
|
|
|
|
from types import MethodType
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.cluster import ProcessGroupMesh
|
|
|
|
from colossalai.interface import OptimizerWrapper
|
|
|
|
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|
|
|
from colossalai.testing.random import seed_all
|
|
|
|
|
|
|
|
|
|
|
|
class MlpModel(nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super(MlpModel, self).__init__()
|
|
|
|
self.linear1 = nn.Linear(4, 8)
|
|
|
|
self.linear2 = nn.Linear(8, 8)
|
|
|
|
self.linear3 = nn.Linear(8, 8)
|
|
|
|
self.linear4 = nn.Linear(8, 8)
|
|
|
|
self.linear5 = nn.Linear(8, 8)
|
|
|
|
self.linear6 = nn.Linear(8, 8)
|
|
|
|
self.linear7 = nn.Linear(8, 8)
|
|
|
|
self.linear8 = nn.Linear(8, 4)
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = self.linear1(x)
|
|
|
|
x = self.linear2(x)
|
|
|
|
x = self.linear3(x)
|
|
|
|
x = self.linear4(x)
|
|
|
|
x = self.linear5(x)
|
|
|
|
x = self.linear6(x)
|
|
|
|
x = self.linear7(x)
|
|
|
|
x = self.linear8(x)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def pp_linear_fwd(
|
|
|
|
forward,
|
|
|
|
data: torch.Tensor = None,
|
|
|
|
input_obj: torch.Tensor = None,
|
|
|
|
stage_mgr: PipelineStageManager = None,
|
|
|
|
num_chunks: int = None,
|
|
|
|
model_chunk_id: int = None,
|
|
|
|
):
|
2023-08-16 11:29:03 +00:00
|
|
|
if stage_mgr.is_first_stage() and model_chunk_id == 0:
|
2023-09-19 06:20:26 +00:00
|
|
|
return {"input_obj": forward(data)}
|
2023-08-16 11:29:03 +00:00
|
|
|
elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
|
|
|
|
return forward(input_obj)
|
|
|
|
else:
|
2023-09-19 06:20:26 +00:00
|
|
|
return {"input_obj": forward(input_obj)}
|
2023-08-16 11:29:03 +00:00
|
|
|
|
|
|
|
|
|
|
|
@parameterize("num_micro_batches", [4, 8, 12])
|
|
|
|
def examine_pp(num_micro_batches):
|
|
|
|
"""
|
|
|
|
This test is to examine the correctness of interleaved 1F1B, compared with torch.
|
|
|
|
Be aware it contains some hardcodes.
|
|
|
|
"""
|
|
|
|
world_size = torch.distributed.get_world_size()
|
|
|
|
local_rank = torch.distributed.get_rank()
|
|
|
|
seed_all(1453)
|
|
|
|
|
|
|
|
NUM_MICRO_BATCHS = num_micro_batches
|
|
|
|
BATCH_SIZE = num_micro_batches
|
|
|
|
NUM_CHUNKS = 2
|
|
|
|
|
|
|
|
# create model
|
|
|
|
torch_model = MlpModel().cuda()
|
|
|
|
|
|
|
|
pp_model = copy.deepcopy(torch_model).cuda()
|
|
|
|
|
|
|
|
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
|
|
|
|
pg_mesh = ProcessGroupMesh(1, world_size, 1)
|
|
|
|
stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
|
|
|
|
schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
|
|
|
|
|
|
|
|
sharded_model = torch.nn.ModuleList()
|
|
|
|
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
|
|
|
|
if idx % (world_size) == local_rank:
|
|
|
|
sub_model._forward = sub_model.forward
|
|
|
|
sub_model.forward = MethodType(
|
2023-09-19 06:20:26 +00:00
|
|
|
partial(
|
|
|
|
pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model)
|
|
|
|
),
|
|
|
|
sub_model._forward,
|
|
|
|
)
|
2023-08-16 11:29:03 +00:00
|
|
|
sharded_model.append(sub_model.cuda())
|
|
|
|
|
|
|
|
# create optimizer
|
|
|
|
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
|
|
|
pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
|
|
|
|
|
|
|
|
# create
|
|
|
|
seed_all(1453)
|
|
|
|
if local_rank == 0:
|
|
|
|
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
|
|
|
|
else:
|
|
|
|
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
|
|
|
|
torch.distributed.all_reduce(input_list[0])
|
|
|
|
|
|
|
|
criterion = lambda x, y: torch.mean(x)
|
|
|
|
|
|
|
|
# forward and backward
|
|
|
|
torch_output = torch_model(input_list[0])
|
|
|
|
torch_loss = criterion(torch_output, _)
|
|
|
|
torch_loss.backward()
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
pp_ret = schedule.forward_backward_step(
|
|
|
|
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
|
|
|
|
)
|
2023-08-16 11:29:03 +00:00
|
|
|
|
|
|
|
# check loss
|
|
|
|
if stage_manager.is_last_stage():
|
2023-09-19 06:20:26 +00:00
|
|
|
assert torch.allclose(torch_loss, pp_ret["loss"])
|
2023-08-16 11:29:03 +00:00
|
|
|
|
|
|
|
# check gradients
|
|
|
|
torch_grad = []
|
|
|
|
for torch_p in torch_model.parameters():
|
|
|
|
torch_grad.append(torch_p.grad.data)
|
|
|
|
|
|
|
|
for idx, pp_p in enumerate(sharded_model.parameters()):
|
|
|
|
if idx < 2:
|
|
|
|
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
|
|
|
|
else:
|
|
|
|
assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
|
|
|
|
|
|
|
|
# step
|
|
|
|
torch_optimizer.step()
|
|
|
|
pp_optimizer.step()
|
|
|
|
|
|
|
|
# check updated param
|
|
|
|
torch_param = []
|
|
|
|
for torch_p in torch_model.parameters():
|
|
|
|
torch_param.append(torch_p.data)
|
|
|
|
for idx, pp_p in enumerate(sharded_model.parameters()):
|
|
|
|
if idx < 2:
|
|
|
|
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
|
|
|
|
else:
|
|
|
|
assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
|
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port):
|
2023-09-19 06:20:26 +00:00
|
|
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
|
2023-08-16 11:29:03 +00:00
|
|
|
examine_pp()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
|
|
|
|
@rerun_if_address_is_in_use()
|
|
|
|
def test_pp():
|
|
|
|
spawn(run_dist, 4)
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
if __name__ == "__main__":
|
2023-08-16 11:29:03 +00:00
|
|
|
test_pp()
|