diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index 11f98e1..6857b35 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -47,7 +47,19 @@ class MlpModel(nn.Module): self.linear1 = nn.Linear(4, 8) self.linear2 = nn.Linear(8, 8) self.linear3 = nn.Linear(8, 8) - self.linear4 = nn.Linear(8, 4) + self.linear4 = nn.Linear(8, 8) + self.linear5 = nn.Linear(8, 8) + self.linear6 = nn.Linear(8, 8) + self.linear7 = nn.Linear(8, 8) + self.linear8 = nn.Linear(8, 8) + self.linear9 = nn.Linear(8, 8) + self.linear10 = nn.Linear(8, 8) + self.linear11 = nn.Linear(8, 8) + self.linear12 = nn.Linear(8, 8) + self.linear13 = nn.Linear(8, 8) + self.linear14 = nn.Linear(8, 8) + self.linear15 = nn.Linear(8, 8) + self.linear16 = nn.Linear(8, 4) def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): print('MLP:', input_ids, input_ids.dtype, flush=True) @@ -55,13 +67,26 @@ class MlpModel(nn.Module): input_ids = self.linear2(input_ids) input_ids = self.linear3(input_ids) input_ids = self.linear4(input_ids) + input_ids = self.linear5(input_ids) + input_ids = self.linear6(input_ids) + input_ids = self.linear7(input_ids) + input_ids = self.linear8(input_ids) + input_ids = self.linear9(input_ids) + input_ids = self.linear10(input_ids) + input_ids = self.linear11(input_ids) + input_ids = self.linear12(input_ids) + input_ids = self.linear13(input_ids) + input_ids = self.linear14(input_ids) + input_ids = self.linear15(input_ids) + input_ids = self.linear16(input_ids) return input_ids config = Config( dict( - parallel=dict(zero1=1, pipeline=dict(size=2, interleaved_overlap=False), sequence_parallel=False, tensor=1), + HIDDEN_SIZE=4, + parallel=dict(zero1=1, pipeline=dict(size=8, interleaved_overlap=False), sequence_parallel=False, tensor=1), model_type="INTERNLM", - data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999), + data=dict(seq_len=4, micro_num=4, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999), model=dict( dtype=torch.bfloat16, ), @@ -117,7 +142,7 @@ def build_environment(rank, world_size): os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "12345" + os.environ["MASTER_PORT"] = "44444" torch.cuda.empty_cache() # launcher="torch" internlm.launch_from_torch(config=config, seed=1024) @@ -155,17 +180,28 @@ def seed_all(seed, cuda_deterministic=False): def exam_pipeline_parallel(args): + import os rank, world_size = args - dtype = torch.bfloat16 + dtype = torch.float32 build_environment(rank, world_size) + local_rank = int(os.environ["LOCAL_RANK"]) + print('rank_com:', rank, local_rank) + device = torch.device(f"cuda:{local_rank}") + # print('device_id:', device) + # torch.cuda.set_device(device) seed_all(1024) - torch_model = MlpModel().cuda() + torch_model = MlpModel().to(device) pp_model = copy.deepcopy(torch_model).to(dtype) tensor_shape = get_tensor_shape() + tensor_shape = ( + 4, + 4, + ) + # print('tensor_shape:', tensor_shape) scatter_gather = gpc.is_initialized(ParallelMode.TENSOR) @@ -179,7 +215,8 @@ def exam_pipeline_parallel(args): skip=False ), ] - + + gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num scheduler = PipelineScheduler( data_process_func=None, num_microbatches=gpc.config.data.micro_num, @@ -198,7 +235,7 @@ def exam_pipeline_parallel(args): lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, criterion=criterion, - gradient_handlers=[], + gradient_handlers= [dict(type="PipelineSharedModuleGradientHandler")], clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0), ) @@ -206,10 +243,10 @@ def exam_pipeline_parallel(args): engine.train() engine.zero_grad() - input_list = [{'input_ids':torch.tensor([[0,1,2,3]]).cuda().to(dtype)}, - torch.tensor([[1]]).cuda().to(torch.int64)] - torch_input = torch.tensor([[0,1,2,3]]).cuda().to(torch.float32) - torch_label = torch.tensor([[1]]).cuda().to(torch.int64) + input_list = [{'input_ids':torch.tensor([[0,1,2,3],[0,1,2,3],[0,1,2,3],[0,1,2,3]]).to(device).to(dtype)}, + torch.tensor([[1],[1],[1],[1]]).to(device).to(torch.int64)] + torch_input = torch.tensor([[0,1,2,3]]).to(device).to(torch.float32) + torch_label = torch.tensor([[1]]).to(device).to(torch.int64) # print('label_shape:', input_list[1].shape) # input_list = [{'input_ids':torch.rand(1, 4).cuda()}, torch.rand(1, 4).cuda()] # input = input_list[0]