From 4a6987d5e748e41837fde378bab2a46b4e4dd45a Mon Sep 17 00:00:00 2001
From: jiaxingli <43110891+li126com@users.noreply.github.com>
Date: Thu, 16 Nov 2023 15:30:57 +0800
Subject: [PATCH] unitest_only_forward (#484)

---
 tests/test_core/test_pipeline.py        | 14 ++++--
 tests/test_model/test_embedding.py      | 10 ++++-
 tests/test_model/test_model_internlm.py | 59 ++++++++++++++++++-------
 tests/test_model/test_norm.py           | 10 ++++-
 4 files changed, 71 insertions(+), 22 deletions(-)

diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py
index 72bc52f..ce9dc98 100644
--- a/tests/test_core/test_pipeline.py
+++ b/tests/test_core/test_pipeline.py
@@ -274,9 +274,12 @@ def exam_pipeline_parallel(args):
     input_list = [{"input_ids": xs}, yx]
 
     # pp forward and backward
-    output, _, loss = scheduler.forward_backward_step(
-        engine, input_list, forward_only=False, return_loss=True, return_output_label=True
-    )
+    output_list = []
+    for _ in range(10):
+        output, _, loss = scheduler.forward_backward_step(
+            engine, input_list, forward_only=False, return_loss=True, return_output_label=True
+        )
+        output_list.append(output)
 
     engine.step()
 
@@ -292,6 +295,11 @@ def exam_pipeline_parallel(args):
         eps=config.adam.adam_eps,
     )
 
+    # check that repeated forward passes produce identical logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
+
     # check output
     torch_output = torch_model(input_ids=torch_xs)  # pylint: disable=E1102
     loose_close(torch_output, output, dtype=dtype)
diff --git a/tests/test_model/test_embedding.py b/tests/test_model/test_embedding.py
index 324ca2b..61203e8 100644
--- a/tests/test_model/test_embedding.py
+++ b/tests/test_model/test_embedding.py
@@ -31,7 +31,15 @@ def check_embedding(args):
 
     # create input
     input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
-    result = embedding(input_ids)
+    output_list = []
+    for _ in range(10):
+        result = embedding(input_ids)
+        output_list.append(result)
+
+    # check that repeated forward passes produce identical logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
     standard_result = torch.tensor(standard_list).to(device)
diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py
index c002a96..9b6066e 100644
--- a/tests/test_model/test_model_internlm.py
+++ b/tests/test_model/test_model_internlm.py
@@ -114,7 +114,6 @@ def check_block(args):
     # create input
     cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device)  # [0, 8, 16]
     indexes = torch.tensor([0, 1, 0, 1]).to(device)  # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
-    hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device)  # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
     hidden_states = torch.tensor(
@@ -130,19 +129,29 @@
 
     hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()
 
-    # forward
-    for _, block in enumerate(blocks):
-        block = block.to(torch.bfloat16)
-        block = block.to(device)
-        hidden_states = block(
-            hidden_states,
-            cu_seqlens=cu_seqlens,
-            indexes=indexes,
-            inference_params=None,
-            max_seqlen=max_seqlen,
-        )
+    init_hidden_states = hidden_states
+    output_list = []
+    for _ in range(10):
+        hidden_states = init_hidden_states
+        # forward
+        for _, block in enumerate(blocks):
+            block = block.to(torch.bfloat16)
+            block = block.to(device)
+            hidden_states = block(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                indexes=indexes,
+                inference_params=None,
+                max_seqlen=max_seqlen,
+            )
+        result = hidden_states
+        output_list.append(result)
+
+    # check that repeated forward passes produce identical logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
-    result = hidden_states
     standard_result = torch.tensor(
         [
             [-1.1621, 1.3111, 0.1509, 2.2697],
@@ -248,8 +257,16 @@ def check_head(args):
         requires_grad=True,
     ).to(device)
 
-    # forward
-    result = head(hidden_states)
+    output_list = []
+    for _ in range(10):
+        # forward
+        result = head(hidden_states)
+        output_list.append(result)
+
+    # check that repeated forward passes produce identical logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     # check output
     assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
@@ -334,8 +351,16 @@ def check_gather_forward(args):
         requires_grad=True,
     ).to(device)
 
-    # forward
-    result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
+    output_list = []
+    for _ in range(10):
+        # forward
+        result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
+        output_list.append(result)
+
+    # check that repeated forward passes produce identical logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     # check output
     assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
diff --git a/tests/test_model/test_norm.py b/tests/test_model/test_norm.py
index 4078ef5..10e8681 100644
--- a/tests/test_model/test_norm.py
+++ b/tests/test_model/test_norm.py
@@ -37,7 +37,15 @@ def check_norm(args):
     ).to(device)
 
     # forward
-    result = norm(hidden_states.float())
+    output_list = []
+    for _ in range(10):
+        result = norm(hidden_states.float())
+        output_list.append(result)
+
+    # check that repeated forward passes produce identical logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     standard = torch.tensor(
         [
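
Note: every hunk above inlines the same ten-run determinism check. For reference, a minimal sketch of a shared helper this pattern could be factored into; the helper name, its n_runs parameter, and the suggested call site are illustrative assumptions, not part of this patch:

    import torch

    def assert_forward_deterministic(forward_fn, n_runs=10):
        # Illustrative helper (not in this patch): run the same forward pass
        # n_runs times and assert the outputs are bitwise identical, which is
        # what each test above now checks inline.
        outputs = [forward_fn() for _ in range(n_runs)]
        first_output = outputs[0]
        for output in outputs[1:]:
            # torch.equal is an exact comparison: same shape, dtype, and values
            assert torch.equal(first_output, output)
        return first_output

With such a helper, e.g. check_norm would reduce to a single call:

    result = assert_forward_deterministic(lambda: norm(hidden_states.float()))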