diff --git a/colossalai/pipeline/rpc/PipelineBase.py b/colossalai/pipeline/rpc/PipelineBase.py
index 6c3d0afe5..9bb548ff6 100644
--- a/colossalai/pipeline/rpc/PipelineBase.py
+++ b/colossalai/pipeline/rpc/PipelineBase.py
@@ -9,6 +9,7 @@ import torch.distributed.rpc as rpc
 from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef
 from torch import autograd
+from torch import optim
 from tqdm import tqdm
 from colorama import Back, Style
 
@@ -43,8 +44,7 @@ def tensor_shape_list(tensors):
 class Phase(Enum):
     FORWARD = 0
     BACKWARD = 1
-    ACCUM_GRAD = 2
-    SYNC = 3
+    UPDATE = 2
 
 
 class UniqueKey:
@@ -440,8 +440,6 @@ class Worker:
 
                 if isinstance(input_node, torch.Tensor):
                     consume_result.append(input_node.grad)
-        elif phase == Phase.SYNC:
-            pass
         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
 
@@ -478,6 +476,18 @@ class Worker:
                         'work loop', 'green')
             work_item.output.set_result(consume_result)
 
+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        self.optimizer: optim.Optimizer = optimizer_class(self.module_partition.parameters(), **kwargs)
+
+    def step(self):
+        assert hasattr(self, "optimizer"), "call initialize_optimizer first before you call step!"
+        self.work_list.clear()
+        self.output_list.clear()
+        self.microbatch_id_to_backward_cache.clear()
+
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+
 
 # TODO
 # 1. chunk
@@ -617,6 +627,24 @@ class PipelineEngineBase(ABC, nn.Module):
             first_stage_worker.rpc_sync().get_output_by_key(key)
         return forward_result
 
+    def initialize_optimizer(self, optimizer_class: type, **kwargs):
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            worker_rref.rpc_sync().initialize_optimizer(optimizer_class, **kwargs)
+
+    def step(self):
+        step_futs: List[Future] = []
+        actual_stage_num = self._get_actual_stage_num()
+        for pp_rank in range(actual_stage_num):
+            worker_rref = self.pp_rank_to_worker_rref[pp_rank]
+            fut = worker_rref.rpc_async().step()
+            step_futs.append(fut)
+
+        # wait for all optimizers
+        for fut in step_futs:
+            fut.wait()
+
 
 class FillDrainPipelineEngine(PipelineEngineBase):
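Note: the two new methods above make optimization purely master-driven. PipelineEngineBase.initialize_optimizer asks every worker to build an optimizer over its own module_partition, and PipelineEngineBase.step fans out asynchronous step() calls and waits on the returned futures, while each Worker.step clears its work/output queues before updating and zeroing gradients. A minimal usage sketch from the master process (illustrative only; engine construction follows the tests below):

    from torch.optim import Adam

    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                    stage_num=stage_num,
                                    num_microbatches=num_microbatches,
                                    device='cuda',
                                    chunk=1,
                                    use_interleave=False,
                                    checkpoint=False)

    engine.initialize_optimizer(Adam, lr=1e-3)              # one Adam instance created remotely per pipeline stage
    forward_result = engine.forward_backward(input_sample)  # forward + backward populate grads on every stage
    engine.step()                                           # each stage runs optimizer.step() and zero_grad()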
diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py
new file mode 100644
index 000000000..e7caea0bd
--- /dev/null
+++ b/tests/test_pipeline/rpc_test_utils.py
@@ -0,0 +1,84 @@
+import os
+import argparse
+
+import torch
+from torch import nn
+import torch.multiprocessing as mp
+import torch.distributed.rpc as rpc
+from torch import autograd
+from torch.optim import SGD, Adam, RMSprop, Optimizer
+from colorama import Back, Style
+
+from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.testing import assert_close
+
+
+def color_debug(text, prefix=' ', color='blue'):
+    color = color.upper()
+    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
+
+
+class RpcTestModel(nn.Module):
+
+    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
+        super().__init__()
+        self.rank = stage_id
+        self.is_last_rank = stage_id == actual_stage_num - 1
+        self.linear_name = f'linear_{stage_id}'
+        if stage_id == 0:
+            setattr(self, self.linear_name, nn.Linear(feat_num, h))
+        elif stage_id == actual_stage_num - 1:
+            setattr(self, self.linear_name, nn.Linear(h, 1))
+        else:
+            setattr(self, self.linear_name, nn.Linear(h, h))
+
+    def forward(self, x) -> torch.Tensor:
+        linear: nn.Module = getattr(self, self.linear_name)
+        out: torch.Tensor = linear(x)
+
+        if self.is_last_rank:
+            out = out.sum()
+        return out
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--world_size', type=int, default=2)
+    parser.add_argument('--num_microbatches', type=int, default=2)
+    parser.add_argument('--chunk', type=int, default=1)
+    parser.add_argument('--use_checkpoint', action='store_true')
+    parser.add_argument('--use_interleave', action='store_true')
+    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD')
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--master_addr', type=str, default='localhost')
+    parser.add_argument('--master_port', type=str, default='29020')
+    parser.add_argument('--num_worker_threads', type=int, default=128)
+    return parser.parse_args()
+
+
+def run_worker(rank, args, master_func):
+    os.environ['MASTER_ADDR'] = args.master_addr
+    os.environ['MASTER_PORT'] = args.master_port
+
+    # config rpc
+    # if cuda is used, set_device_map must be configured,
+    # because cuda is not supported in torch rpc by default
+    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
+
+    world_size = args.world_size
+    for rank_idx in range(world_size):
+        options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
+
+    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
+
+    # in rpc mode, only rank 0 needs to run the master logic
+    if rank == 0:
+        master_func(args)
+    # rpc.shutdown() blocks until every rank is done, acting as a barrier
+    rpc.shutdown()
+
+
+def rpc_run(args, master_func):
+    world_size = args.world_size
+    assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!"
+    mp.spawn(run_worker, args=(args, master_func), nprocs=world_size)
diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py
new file mode 100644
index 000000000..12db694fa
--- /dev/null
+++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py
@@ -0,0 +1,85 @@
+import os
+import argparse
+
+import torch
+from torch import nn
+import torch.multiprocessing as mp
+import torch.distributed.rpc as rpc
+from torch import autograd
+from torch.optim import SGD, Adam, RMSprop, Optimizer
+from colorama import Back, Style
+
+from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.testing import assert_close
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
+
+
+def run_master(args):
+    torch.manual_seed(100)
+
+    device = args.device
+    stage_num = args.world_size
+    chunk = args.chunk
+    actual_stage_num = stage_num * chunk
+    use_interleave = args.use_interleave
+    use_checkpoint = args.use_checkpoint
+    num_microbatches = args.num_microbatches
+    optimizer_class = globals()[args.optimizer]
+
+    lr = 1e-3
+
+    sample_num = 1024
+    feat_num = 100
+    h = 100
+    batch_size = 1024
+
+    assert sample_num % batch_size == 0
+    batch_num = sample_num // batch_size
+
+    input_sample = torch.randn((sample_num, feat_num), device=device)
+
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+
+    engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
+                                    stage_num=stage_num,
+                                    num_microbatches=num_microbatches,
+                                    device=device,
+                                    chunk=chunk,
+                                    use_interleave=use_interleave,
+                                    checkpoint=use_checkpoint)
+
+    engine.initialize_optimizer(optimizer_class, lr=lr)
+
+    _ = engine.forward_backward(input_sample)
+    engine.step()
+
+    cuda_rpc_result = []
+    single_result = []
+    actual_stage_num = engine._get_actual_stage_num()
+
+    # collect parameters after the remote update in cuda rpc
+    parameters = engine.remote_parameters()
+    for stage_id in range(actual_stage_num):
+        for p in parameters[stage_id]:
+            cuda_rpc_result.append(p)
+
+    # run the same forward/backward and optimizer update on a single local model in rank_0
+    test_model = nn.Sequential(*module_partitions).to(device)
+    optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
+    input_sample = input_sample.requires_grad_()
+    out_val = test_model(input_sample).sum()
+    autograd.backward(out_val)
+    optimizer.step()
+    optimizer.zero_grad()
+
+    for p in test_model.parameters():
+        single_result.append(p)
+
+    assert len(cuda_rpc_result) == len(single_result)
+    for r_c, r_s in zip(cuda_rpc_result, single_result):
+        assert_close(r_c, r_s, 0.001, 0.001)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    rpc_run(args, run_master)
diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py
index 6608a5c5a..9dc19f13d 100644
--- a/tests/test_pipeline/test_cuda_rpc_pipeline.py
+++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py
@@ -7,32 +7,10 @@ import torch.multiprocessing as mp
 import torch.distributed.rpc as rpc
 
 from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
 
 
-class TestModel(nn.Module):
-
-    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
-        super().__init__()
-        self.rank = stage_id
-        self.is_last_rank = stage_id == actual_stage_num - 1
-        self.linear_name = f'linear_{stage_id}'
-        if stage_id == 0:
-            setattr(self, self.linear_name, nn.Linear(feat_num, h))
-        elif stage_id == actual_stage_num - 1:
-            setattr(self, self.linear_name, nn.Linear(h, 1))
-        else:
-            setattr(self, self.linear_name, nn.Linear(h, h))
-
-    def forward(self, x) -> torch.Tensor:
-        linear: nn.Module = getattr(self, self.linear_name)
-        out: torch.Tensor = linear(x)
-
-        if self.is_last_rank:
-            out = out.sum()
-        return out
-
-
-def run_main(args):
+def run_master(args):
     torch.manual_seed(100)
 
     device = args.device
@@ -53,7 +31,7 @@ def run_main(args):
 
     input_sample = torch.randn((sample_num, feat_num), device=device)
 
-    module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
 
     engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                     stage_num=stage_num,
@@ -66,44 +44,6 @@ def run_main(args):
 
     _ = engine.forward_backward(input_sample)
 
-
-def run_worker(rank, args):
-    os.environ['MASTER_ADDR'] = args.master_addr
-    os.environ['MASTER_PORT'] = args.master_port
-
-    # config rpc
-    # if cuda is used, set_device_map is a must is configured
-    # for cuda is not supported in torch rpc by default
-    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
-
-    world_size = args.world_size
-    for rank_idx in range(world_size):
-        options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
-
-    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
-
-    # in rpc mode, only rank 0 is needed to be coded
-    if rank == 0:
-        run_main(args)
-    # barrier here
-    rpc.shutdown()
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=2)
-    parser.add_argument('--num_microbatches', type=int, default=2)
-    parser.add_argument('--device', type=str, default='cuda')
-    parser.add_argument('--chunk', type=int, default=1)
-    parser.add_argument('--use_checkpoint', action='store_true')
-    parser.add_argument('--use_interleave', action='store_true')
-    parser.add_argument('--master_addr', type=str, default='localhost')
-    parser.add_argument('--master_port', type=str, default='29020')
-    parser.add_argument('--num_worker_threads', type=str, default=128)
-    return parser.parse_args()
-
-
 if __name__ == "__main__":
     args = parse_args()
-    world_size = args.world_size
-    assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!"
-    mp.spawn(run_worker, args=(args,), nprocs=world_size)
+    rpc_run(args, run_master)
diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py
index 0c5f75a12..c7a439c37 100644
--- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py
+++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py
@@ -9,37 +9,11 @@ from torch import autograd
 from colorama import Back, Style
 
 from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine
+from colossalai.testing import assert_close
+from rpc_test_utils import rpc_run, parse_args, RpcTestModel
 
 
-def color_debug(text, prefix=' ', color='blue'):
-    color = color.upper()
-    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
-
-
-class TestModel(nn.Module):
-
-    def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
-        super().__init__()
-        self.rank = stage_id
-        self.is_last_rank = stage_id == actual_stage_num - 1
-        self.linear_name = f'linear_{stage_id}'
-        if stage_id == 0:
-            setattr(self, self.linear_name, nn.Linear(feat_num, h))
-        elif stage_id == actual_stage_num - 1:
-            setattr(self, self.linear_name, nn.Linear(h, 1))
-        else:
-            setattr(self, self.linear_name, nn.Linear(h, h))
-
-    def forward(self, x) -> torch.Tensor:
-        linear: nn.Module = getattr(self, self.linear_name)
-        out: torch.Tensor = linear(x)
-
-        if self.is_last_rank:
-            out = out.sum()
-        return out
-
-
-def run_main(args):
+def run_master(args):
     torch.manual_seed(100)
 
     device = args.device
@@ -48,6 +22,7 @@ def run_main(args):
     actual_stage_num = stage_num * chunk
     use_interleave = args.use_interleave
     use_checkpoint = args.use_checkpoint
+    num_microbatches = args.num_microbatches
 
     sample_num = 1024
     feat_num = 100
@@ -57,11 +32,9 @@ def run_main(args):
     assert sample_num % batch_size == 0
     batch_num = sample_num // batch_size
 
-    num_microbatches = stage_num * 1
-
     input_sample = torch.randn((sample_num, feat_num), device=device)
 
-    module_partitions = [TestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
+    module_partitions = [RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)]
 
     engine = OneFOneBPipelineEngine(module_partitions=module_partitions,
                                     stage_num=stage_num,
@@ -77,74 +50,27 @@ def run_main(args):
     single_result = []
     actual_stage_num = engine._get_actual_stage_num()
 
-    # color_debug('cuda rpc forward', 'Test')
-    # print(sum(forward_result[0]))
-    cuda_rpc_result.append(sum(forward_result[0]).item())
-    # color_debug('cuda rpc backward', 'Test')
+    # compute forward result and backward grad of parameters in cuda rpc
+    cuda_rpc_result.append(sum(forward_result[0]))
 
     grad = engine.remote_grad()
     for stage_id in range(actual_stage_num):
        for p in grad[stage_id]:
-            # print(p.sum())
-            cuda_rpc_result.append(p.sum().item())
+            cuda_rpc_result.append(p)
 
+    # compute forward result and backward grad of parameters just in rank_0
     test_model = nn.Sequential(*module_partitions).to(device)
     input_sample = input_sample.requires_grad_()
     out_val = test_model(input_sample).sum()
     autograd.backward(out_val)
-    # color_debug('single forward', 'Test')
-    # print(out_val)
-    single_result.append(out_val.item())
-    # color_debug('single backward', 'Test')
+    single_result.append(out_val)
 
     for p in test_model.parameters():
-        # print(p.grad.sum())
-        single_result.append(p.grad.sum().item())
+        single_result.append(p.grad)
 
-    cuda_rpc_result = torch.tensor(cuda_rpc_result)
-    single_result = torch.tensor(single_result)
-    distance = (cuda_rpc_result - single_result).abs().sum().item()
-    kappa = round(distance / actual_stage_num, 5)
-    assert kappa < 0.01, f"kappa({kappa}) is too big, PP result may be incorrect!"
-
-
-def run_worker(rank, args):
-    os.environ['MASTER_ADDR'] = args.master_addr
-    os.environ['MASTER_PORT'] = args.master_port
-
-    # config rpc
-    # if cuda is used, set_device_map is a must is configured
-    # for cuda is not supported in torch rpc by default
-    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads)
-
-    world_size = args.world_size
-    for rank_idx in range(world_size):
-        options.set_device_map(f'work{rank_idx}', {rank: rank_idx})
-
-    rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
-
-    # in rpc mode, only rank 0 is needed to be coded
-    if rank == 0:
-        run_main(args)
-    # barrier here
-    rpc.shutdown()
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=2)
-    parser.add_argument('--num_microbatches', type=int, default=2)
-    parser.add_argument('--chunk', type=int, default=1)
-    parser.add_argument('--use_checkpoint', action='store_true')
-    parser.add_argument('--use_interleave', action='store_true')
-    parser.add_argument('--device', type=str, default='cuda')
-    parser.add_argument('--master_addr', type=str, default='localhost')
-    parser.add_argument('--master_port', type=str, default='29020')
-    parser.add_argument('--num_worker_threads', type=str, default=128)
-    return parser.parse_args()
+    assert len(cuda_rpc_result) == len(single_result)
+    for r_c, r_s in zip(cuda_rpc_result, single_result):
+        assert_close(r_c, r_s, 0.001, 0.001)
 
 
 if __name__ == "__main__":
     args = parse_args()
-    world_size = args.world_size
-    assert args.num_microbatches >= args.world_size, "num_microbatches cannot be fewer than world_size!"
-    assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!"
-    mp.spawn(run_worker, args=(args,), nprocs=world_size)
+    rpc_run(args, run_master)
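As a usage note, the refactored tests are launched directly as scripts and all share the flags defined in parse_args of rpc_test_utils.py. An example invocation (illustrative only):

    python tests/test_pipeline/test_cuda_rpc_optimizer.py --world_size 2 --num_microbatches 4 --optimizer Adam --device cuda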