unitest_only_forward (#484)

pull/507/head
jiaxingli 2023-11-16 15:30:57 +08:00 committed by GitHub
parent e8cf27b8c0
commit 4a6987d5e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 22 deletions

View File

@@ -274,9 +274,12 @@ def exam_pipeline_parallel(args):
input_list = [{"input_ids": xs}, yx]
# pp forward and backward
output, _, loss = scheduler.forward_backward_step(
engine, input_list, forward_only=False, return_loss=True, return_output_label=True
)
output_list = []
for _ in range(10):
output, _, loss = scheduler.forward_backward_step(
engine, input_list, forward_only=False, return_loss=True, return_output_label=True
)
output_list.append(output)
engine.step()
@@ -292,6 +295,11 @@ def exam_pipeline_parallel(args):
eps=config.adam.adam_eps,
)
# check only forward logits
first_output = output_list[0]
for i in range(1, 10):
assert torch.equal(first_output, output_list[i])
# check output
torch_output = torch_model(input_ids=torch_xs) # pylint: disable=E1102
loose_close(torch_output, output, dtype=dtype)

View File

@@ -31,7 +31,15 @@ def check_embedding(args):
# create input
input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
result = embedding(input_ids)
output_list = []
for _ in range(10):
result = embedding(input_ids)
output_list.append(result)
# check only forward logits
first_output = output_list[0]
for i in range(1, 10):
assert torch.equal(first_output, output_list[i])
standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
standard_result = torch.tensor(standard_list).to(device)

View File

@@ -114,7 +114,6 @@ def check_block(args):
# create input
cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device) # [0, 8, 16]
indexes = torch.tensor([0, 1, 0, 1]).to(device) # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device) # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
hidden_states = torch.tensor(
@@ -130,19 +129,29 @@ def check_block(args):
hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()
# forward
for _, block in enumerate(blocks):
block = block.to(torch.bfloat16)
block = block.to(device)
hidden_states = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=None,
max_seqlen=max_seqlen,
)
hid2 = hidden_states
output_list = []
for i in range(10):
hidden_states = hid2
# forward
for _, block in enumerate(blocks):
block = block.to(torch.bfloat16)
block = block.to(device)
hidden_states = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=None,
max_seqlen=max_seqlen,
)
result = hidden_states
output_list.append(result)
# check only forward logits
first_output = output_list[0]
for i in range(1, 10):
assert torch.equal(first_output, output_list[i])
result = hidden_states
standard_result = torch.tensor(
[
[-1.1621, 1.3111, 0.1509, 2.2697],
@@ -248,8 +257,16 @@ def check_head(args):
requires_grad=True,
).to(device)
# forward
result = head(hidden_states)
output_list = []
for _ in range(10):
# forward
result = head(hidden_states)
output_list.append(result)
# check only forward logits
first_output = output_list[0]
for i in range(1, 10):
assert torch.equal(first_output, output_list[i])
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
@@ -334,8 +351,16 @@ def check_gather_forward(args):
requires_grad=True,
).to(device)
# forward
result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
output_list = []
for _ in range(10):
# forward
result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
output_list.append(result)
# check only forward logits
first_output = output_list[0]
for i in range(1, 10):
assert torch.equal(first_output, output_list[i])
# check output
assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

View File

@@ -37,7 +37,15 @@ def check_norm(args):
).to(device)
# forward
result = norm(hidden_states.float())
output_list = []
for _ in range(10):
result = norm(hidden_states.float())
output_list.append(result)
# check only forward logits
first_output = output_list[0]
for i in range(1, 10):
assert torch.equal(first_output, output_list[i])
standard = torch.tensor(
[