From d3c6cd66f338e5c866b997d4add9cf9b1a8be351 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Mon, 31 Jul 2023 14:49:55 +0800
Subject: [PATCH] [pipeline] add unit test for 1f1b (#4303)

* add unit test for 1f1b

* polish code

* polish code and update ut version

* fix
---
 .../test_schedule/test_oneF_oneB.py           | 134 ++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 tests/test_pipeline/test_schedule/test_oneF_oneB.py

diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
new file mode 100644
index 000000000..542116a1d
--- /dev/null
+++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py
@@ -0,0 +1,134 @@
+import copy
+from functools import partial
+from types import MethodType
+
+import pytest
+import torch
+import torch.nn as nn
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import OptimizerWrapper
+from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+
+
+class MlpModel(nn.Module):
+
+    def __init__(self):
+        super(MlpModel, self).__init__()
+        self.linear1 = nn.Linear(4, 8)
+        self.linear2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+def pp_linear_fwd(forward,
+                  data: torch.Tensor = None,
+                  input_obj: torch.Tensor = None,
+                  stage_mgr: PipelineStageManager = None):
+
+    if stage_mgr.is_first_stage():
+        return {'input_obj': forward(data)}
+    elif stage_mgr.is_last_stage():
+        return forward(input_obj)
+    else:
+        return {'input_obj': forward(input_obj)}
+
+
+def examine_pp():
+    """
+    This test examines the correctness of the 1F1B schedule, compared with a plain torch run.
+    Be aware that it contains some hardcoded values.
+    """
+    world_size = torch.distributed.get_world_size()
+    local_rank = torch.distributed.get_rank()
+    seed_all(1453)
+
+    NUM_MICRO_BATCHES = 4
+    BATCH_SIZE = 4
+
+    # create models
+    torch_model = MlpModel().cuda()
+
+    pp_model = copy.deepcopy(torch_model).cuda()
+
+    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
+    pg_mesh = ProcessGroupMesh(1, world_size, 1)
+    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
+    schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHES, stage_manager)
+
+    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
+        if idx % world_size == local_rank:
+            sharded_model = sub_model.cuda()
+
+    sharded_model._forward = sharded_model.forward
+    sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
+
+    # create optimizers
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
+
+    # create input
+    seed_all(1453)
+    if stage_manager.is_first_stage():
+        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
+    else:
+        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
+    torch.distributed.all_reduce(input_list[0])
+
+    criterion = lambda x, y: torch.mean(x)
+
+    # forward and backward
+    torch_output = torch_model(input_list[0])
+    torch_loss = criterion(torch_output, None)    # criterion ignores its second argument
+    torch_loss.backward()
+
+    pp_ret = schedule.forward_backward_step(sharded_model,
+                                            pp_optimizer,
+                                            iter(input_list),
+                                            criterion,
+                                            return_loss=True,
+                                            return_outputs=True)
+
+    # check loss
+    if stage_manager.is_last_stage():
+        assert torch.allclose(torch_loss, pp_ret['loss'])
+
+    # check gradients
+    torch_grad = []
+    for torch_p in torch_model.parameters():
+        torch_grad.append(torch_p.grad.data)
+    for idx, pp_p in enumerate(sharded_model.parameters()):
+        assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
+
+    # step
+    torch_optimizer.step()
+    pp_optimizer.step()
+
+    # check updated param
+    torch_param = []
+    for torch_p in torch_model.parameters():
+        torch_param.append(torch_p.data)
+    for idx, pp_p in enumerate(sharded_model.parameters()):
+        assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    examine_pp()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_pp():
+    spawn(run_dist, 2)
+
+
+if __name__ == '__main__':
+    test_pp()