"""Compare a 1F1B pipeline-RPC run against a single-process reference.

The same stage modules are handed to ``OneFOneBPipelineEngine`` and, afterwards,
chained locally with ``nn.Sequential``.  The parameters reported by the engine
are then checked against the locally updated ones.
"""

import torch
from torch import nn
from torch import autograd
# NOTE: SGD / Adam / RMSprop look unused but are resolved by name through
# ``globals()[args.optimizer]`` in ``run_master`` -- do not remove them.
from torch.optim import SGD, Adam, RMSprop, Optimizer

from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.testing import assert_close

from rpc_test_utils import rpc_run, parse_args, RpcTestModel


def run_master(args):
    """Run one pipeline step and one local step, then compare the parameters.

    Args:
        args: Parsed options.  The attributes read here are ``device``,
            ``world_size``, ``chunk``, ``use_checkpoint``,
            ``num_microbatches`` and ``optimizer`` (the name of an optimizer
            class available as a module-level global of this file).

    Raises:
        KeyError: If ``args.optimizer`` does not name a module-level global.
        AssertionError: If the number of parameters differs between the two
            runs (and presumably also from ``assert_close`` when values differ
            -- confirm against ``colossalai.testing``).
    """
    torch.manual_seed(100)

    device = args.device
    stage_num = args.world_size
    chunk = args.chunk
    # Each of the `stage_num` workers hosts `chunk` model partitions.
    actual_stage_num = stage_num * chunk
    use_checkpoint = args.use_checkpoint
    num_microbatches = args.num_microbatches
    # Optimizer class is selected by name from this module's globals.
    optimizer_class = globals()[args.optimizer]

    lr = 1e-3
    sample_num = 1024
    feat_num = 100
    h = 100
    batch_size = 1024

    # The whole sample set must split evenly into batches.
    assert sample_num % batch_size == 0

    # Keep this before the module construction: both may draw from the seeded RNG.
    input_sample = torch.randn((sample_num, feat_num), device=device)

    module_partitions = [
        RpcTestModel(pp_rank, actual_stage_num, feat_num, h) for pp_rank in range(actual_stage_num)
    ]

    engine = OneFOneBPipelineEngine(
        module_partitions=module_partitions,
        stage_num=stage_num,
        num_microbatches=num_microbatches,
        device=device,
        chunk=chunk,
        checkpoint=use_checkpoint,
    )
    engine.initialize_optimizer(optimizer_class, lr=lr)

    # Return value is not needed; only the resulting remote parameters are.
    # NOTE(review): presumably this also applies the optimizer step remotely -- confirm.
    engine.forward_backward(input_sample)

    actual_stage_num = engine._get_actual_stage_num()

    # Parameters after updating in cuda rpc, flattened in stage order.
    parameters = engine.remote_parameters()
    cuda_rpc_result = [p for stage_id in range(actual_stage_num) for p in parameters[stage_id]]

    # Forward, backward and optimizer step of the same partitions, just in rank_0.
    test_model = nn.Sequential(*module_partitions).to(device)
    optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr)
    input_sample = input_sample.requires_grad_()
    out_val = test_model(input_sample).sum()
    autograd.backward(out_val)
    optimizer.step()
    optimizer.zero_grad()

    single_result = list(test_model.parameters())

    assert len(cuda_rpc_result) == len(single_result)
    for r_c, r_s in zip(cuda_rpc_result, single_result):
        # Positional tolerances kept as in the original call
        # (presumably rtol, atol -- confirm against assert_close's signature).
        assert_close(r_c, r_s, 0.001, 0.001)


if __name__ == "__main__":
    args = parse_args()
    rpc_run(args, run_master)