From 5ab0dc8dc27dd106106335b8041ec1c4b9d18129 Mon Sep 17 00:00:00 2001
From: li126com
Date: Wed, 27 Sep 2023 21:19:05 +0800
Subject: [PATCH] pp test

---
 new_test.py | 99 ++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 72 insertions(+), 27 deletions(-)

diff --git a/new_test.py b/new_test.py
index bf8ae16..f0cee2c 100644
--- a/new_test.py
+++ b/new_test.py
@@ -45,35 +45,46 @@ import torch.distributed as dist


 class MlpModel(nn.Module):
-    def __init__(self, start, end):
+    def __init__(self, start, end, type=None):
         super().__init__()
         self.part = [start , end]
         self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end -start)])
+        self.type = type
+        if gpc.is_first_rank(ParallelMode.PIPELINE):
+            print(f'{gpc.get_global_rank()}: self.part={self.part}', flush=True)

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
-        print(gpc.get_global_rank(), 'hidden_states:', hidden_states, flush=True)
-        if self.part[0] != 0:
+        # print(gpc.get_global_rank(), 'hidden_states:', hidden_states, flush=True)
+        if self.type != 'torch' and not gpc.is_first_rank(ParallelMode.PIPELINE):
             input_ids = hidden_states
-        print(f'pp stage: {gpc.get_local_rank(ParallelMode.PIPELINE)} MLP {self.part} fwd:', input_ids.shape, flush=True)
-        print(gpc.get_global_rank(), 'len_blocsk:', len(self.blocks), flush=True)
-        current_device = torch.cuda.current_device()
-        print(gpc.get_global_rank(), 'current_device:', current_device, flush=True)
-        input_ids = input_ids.to(current_device)
-        print(gpc.get_global_rank(), 'mlp_input_data:', input_ids, input_ids.shape, type(input_ids), flush=True)
-        x = self.blocks[0](input_ids) + self.blocks[1](input_ids)
-        print(gpc.get_global_rank(), 'mlp_output_data:', x, x.shape, flush=True)
-        return x
+        # print(f'pp stage: {gpc.get_local_rank(ParallelMode.PIPELINE)} MLP {self.part} fwd:', input_ids.shape, flush=True)
+        # print(gpc.get_global_rank(), 'len_blocsk:', len(self.blocks), flush=True)
+        # current_device = torch.cuda.current_device()
+        # print(gpc.get_global_rank(), 'current_device:', current_device, flush=True)
+        # input_ids = input_ids.to(current_device)
+        # print(gpc.get_global_rank(), 'mlp_input_data:', input_ids, input_ids.shape, type(input_ids), flush=True)
+        for i in range(self.part[1] - self.part[0]):
+            input_ids = self.blocks[i](input_ids)
+        return input_ids
+        # x = self.blocks[0](input_ids)
+        # x = self.blocks[0](x)
+        # print(gpc.get_global_rank(), 'mlp_output_data:', x, x.shape, flush=True)
+        # return x


 config = Config(
     dict(
+        HIDDEN_SIZE=8,
+        SEQ_LEN=8,
         gradient_handler=[dict(type="PipelineSharedModuleGradientHandler")],
-        HIDDEN_SIZE=4,
-        parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1),
+        parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=True), sequence_parallel=False, tensor=1),
         model_type="INTERNLM",
         data=dict(seq_len=8, micro_num=16, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
         model=dict(
             dtype=torch.bfloat16,
+            num_chunks=2,
+            hidden_size=8,
+            use_flash_attn=True,
         ),
         resume_tb_folder="",
         tensorboard_folder="",
@@ -220,11 +231,12 @@ def exam_pipeline_parallel(args):
     seed_all(1024)

     dtype=gpc.config.model["dtype"]

-    # torch_model = MlpModel().to(device)
+    # pp_model = copy.deepcopy(torch_model).to(dtype)
-    pp_model = _build_generic_model_1d(num_layers=16, num_chunks=1)
+    pp_model = _build_generic_model_1d(num_layers=16, num_chunks=gpc.config.model.num_chunks)
     pp_model = pp_model.to(dtype)
-    print(pp_model, flush=True)
+    print(gpc.get_global_rank(), 'pp_model', pp_model)
+

     scheduler_hooks = [
         SchedulerMetricHook(
@@ -235,14 +247,26 @@ def exam_pipeline_parallel(args):
     micro_num = gpc.config.data.micro_num
     seq_len = gpc.config.data.seq_len
     gpc.config.NUM_MICRO_BATCHES = micro_num
-    scheduler = PipelineScheduler(
-        data_process_func=None,
+
+    communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
+    print(f'communication_overlap={communication_overlap}')
+    scheduler = InterleavedPipelineScheduler(
         num_microbatches=micro_num,
+        num_chunks=gpc.config.model.num_chunks,
         dtype=gpc.config.model["dtype"],
-        tensor_shape=None,
+        tensor_shape=get_tensor_shape(),
         scatter_gather_tensors=False,
         scheduler_hooks=scheduler_hooks,
+        communication_overlap=communication_overlap,
     )
+    # scheduler = PipelineScheduler(
+    #     data_process_func=None,
+    #     num_microbatches=micro_num,
+    #     dtype=dtype,
+    #     tensor_shape=None,
+    #     scatter_gather_tensors=False,
+    #     scheduler_hooks=scheduler_hooks,
+    # )

     print(f"gpc.config.hybrid_zero_optimizer: {gpc.config.hybrid_zero_optimizer}", flush=True)

     # optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
@@ -256,7 +280,6 @@ def exam_pipeline_parallel(args):
     #     eps=1e-8,
     # ))
     optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
-    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=0)

     engine = Engine(
         model=pp_model,
@@ -277,10 +300,12 @@ def exam_pipeline_parallel(args):
     for _ in range(micro_num):
         x_list.append([i for i in range(seq_len)])
         y_list.append([i for i in range(seq_len)])
+    torch_xs = torch.tensor(x_list).to(device).to(torch.float32)
+    torch_ys = torch.tensor(y_list).to(device).to(torch.float32)
     xs = torch.tensor(x_list).to(device).to(dtype)
     yx = torch.tensor(y_list).to(device).to(dtype)
-    xs.requires_grad_()
-    yx.requires_grad_()
+    # xs.requires_grad_()
+    # yx.requires_grad_()
     print(xs.shape, yx.shape, flush=True)

     input_list = [{'input_ids':xs}, yx]
@@ -293,15 +318,35 @@ def exam_pipeline_parallel(args):

     # output = torch_model(input)
     # print(output)
     print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'start schedule', flush=True)
-    _, _, loss = scheduler.forward_backward_step(engine, input_list, forward_only=False, return_loss=True, return_output_label=False)
+    output, label, loss = scheduler.forward_backward_step(engine, input_list, forward_only=False, return_loss=True, return_output_label=True)
     print('local_rank:', gpc.get_local_rank(ParallelMode.PIPELINE), 'end schedule', flush=True)
-    dist.barrier()
+    #dist.barrier()
     torch.cuda.synchronize()
     engine.step()
     torch.cuda.synchronize()
-    # torch_output = torch_model(input_ids=torch_input)
-    # torch_loss = criterion(torch_output, torch_label).unsqueeze(0)
+
+    if gpc.is_last_rank(ParallelMode.PIPELINE):
+        print('torch begin')
+        torch_model = MlpModel(0, 16, 'torch').to(device)
+        # torch_model = DDP(torch_model, static_graph=True)
+        print(gpc.get_global_rank(), 'torch_model', torch_model)
+        torch_optimizer = torch.optim.AdamW(
+            params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}],
+            lr=config.adam.lr,
+            betas=(config.adam.adam_beta1, config.adam.adam_beta2),
+            eps=config.adam.adam_eps,
+        )
+        torch_output = torch_model(input_ids=torch_xs)
+        criterion = MyLoss().to(torch.float32)
+        torch_loss = criterion(torch_output, torch_ys) / micro_num
+        torch_loss.backward()
+        torch_optimizer.step()
+        print(gpc.get_global_rank(), 'test_torch:', 'torch_output:', torch_output, 'torch_loss:', torch_loss)
+        print(gpc.get_global_rank(), 'test_pp:', 'output:', output, 'label:', label, 'loss:', loss)
+        loose_close(torch_output, output, dtype=dtype)
+        loose_close(torch_loss, loss[0], dtype=dtype)
+        print(gpc.get_global_rank(), 'assert_ok')

     # if rank == 0:
     # print('loss:', loss)
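Note on what the new assertions check (this note is not part of the patch): with
num_layers=16, pipeline size 8, and num_chunks=2, _build_generic_model_1d gives each
stage two one-layer chunks (presumably stage i holds layers i and i + 8 in the
interleaved layout), and the interleaved schedule still applies the layers in global
order, so the pipeline output must match the 'torch' MlpModel(0, 16) baseline that runs
all 16 blocks sequentially. The single-process sketch below reproduces that equivalence
without any distributed machinery; the loose_close here is a hypothetical stand-in for
the real test helper, written as a plain tolerance check.

    import torch
    import torch.nn as nn

    def loose_close(a, b, rtol=1.5e-2, atol=1.5e-2):
        # Hypothetical stand-in for the test utility: a tolerance-based
        # comparison, loose enough to absorb bf16 accumulation differences.
        assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol)

    torch.manual_seed(1024)
    blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for _ in range(16)])
    x = torch.arange(8, dtype=torch.float32)

    # Baseline: the 'torch' MlpModel(0, 16) path, all 16 blocks in order.
    baseline = x
    for blk in blocks:
        baseline = blk(baseline)

    # Interleaved pipeline view: 8 stages x 2 chunks, one layer per chunk.
    # Stage i holds layers i and i + 8; a microbatch traverses chunk 0 of
    # every stage before chunk 1, i.e. layers 0..7 then 8..15, so the
    # computation is unchanged.
    stage_chunks = [[blocks[i], blocks[i + 8]] for i in range(8)]
    piped = x
    for chunk_idx in range(2):
        for stage in range(8):
            piped = stage_chunks[stage][chunk_idx](piped)

    loose_close(baseline, piped)

In the test itself the same comparison only runs on the last pipeline rank, which is the
only rank where both the final pipeline output and the reduced loss are available.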