mirror of https://github.com/InternLM/InternLM
unitest_only_forward (#484)
parent e8cf27b8c0
commit 4a6987d5e7
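The commit adds the same forward-only determinism check to each test: run the identical forward pass ten times and require bit-identical outputs. A minimal sketch of that pattern, assuming a hypothetical helper name (assert_forward_deterministic is not part of the InternLM test suite):

    import torch

    def assert_forward_deterministic(forward_fn, num_runs=10):
        # Repeat the same forward pass on identical inputs and demand exact
        # equality (torch.equal), not just tolerance-level closeness.
        outputs = [forward_fn() for _ in range(num_runs)]
        first = outputs[0]
        for other in outputs[1:]:
            assert torch.equal(first, other)
        return first

    # e.g. result = assert_forward_deterministic(lambda: embedding(input_ids))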
@@ -274,9 +274,12 @@ def exam_pipeline_parallel(args):
     input_list = [{"input_ids": xs}, yx]

     # pp forward and backward
-    output, _, loss = scheduler.forward_backward_step(
-        engine, input_list, forward_only=False, return_loss=True, return_output_label=True
-    )
+    output_list = []
+    for _ in range(10):
+        output, _, loss = scheduler.forward_backward_step(
+            engine, input_list, forward_only=False, return_loss=True, return_output_label=True
+        )
+        output_list.append(output)

     engine.step()

@@ -292,6 +295,11 @@ def exam_pipeline_parallel(args):
         eps=config.adam.adam_eps,
     )

+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
+
     # check output
     torch_output = torch_model(input_ids=torch_xs)  # pylint: disable=E1102
     loose_close(torch_output, output, dtype=dtype)

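Note the two comparison levels in exam_pipeline_parallel: torch.equal demands bit-identical outputs across the ten repeated pipeline forwards, while loose_close compares the pipeline output against the single-process torch_model reference with a tolerance. The exact tolerances of loose_close are not shown in this diff; a rough stand-in built on torch.allclose, with assumed dtype-dependent tolerances, could look like:

    import torch

    def loose_close_sketch(a, b, dtype=torch.bfloat16):
        # Assumed tolerances; the real loose_close helper in InternLM may differ.
        if dtype in (torch.bfloat16, torch.float16):
            rtol, atol = 1e-2, 1e-2
        else:
            rtol, atol = 1e-5, 1e-8
        assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol)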
@@ -31,7 +31,15 @@ def check_embedding(args):
     # create input
     input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
-    result = embedding(input_ids)
+    output_list = []
+    for _ in range(10):
+        result = embedding(input_ids)
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])

     standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
     standard_result = torch.tensor(standard_list).to(device)

@@ -114,7 +114,6 @@ def check_block(args):
     # create input
     cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device)  # [0, 8, 16]
     indexes = torch.tensor([0, 1, 0, 1]).to(device)  # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
-    hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device)  # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

     hidden_states = torch.tensor(

@@ -130,19 +129,29 @@ def check_block(args):
     hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()

-    # forward
-    for _, block in enumerate(blocks):
-        block = block.to(torch.bfloat16)
-        block = block.to(device)
-        hidden_states = block(
-            hidden_states,
-            cu_seqlens=cu_seqlens,
-            indexes=indexes,
-            inference_params=None,
-            max_seqlen=max_seqlen,
-        )
+    hid2 = hidden_states
+    output_list = []
+    for i in range(10):
+        hidden_states = hid2
+        # forward
+        for _, block in enumerate(blocks):
+            block = block.to(torch.bfloat16)
+            block = block.to(device)
+            hidden_states = block(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                indexes=indexes,
+                inference_params=None,
+                max_seqlen=max_seqlen,
+            )
+        result = hidden_states
+        output_list.append(result)

+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])

     result = hidden_states
     standard_result = torch.tensor(
         [
             [-1.1621, 1.3111, 0.1509, 2.2697],

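The hid2 save-and-restore is what keeps the ten runs comparable: the block loop overwrites hidden_states with the last layer's output, so each iteration has to be re-seeded with the same input tensor before torch.equal can mean anything. A reduced illustration with a stack of plain Linear layers (an assumption for illustration, not the InternLM blocks):

    import torch

    x0 = torch.randn(4, 8)               # saved input, analogous to hid2
    layers = [torch.nn.Linear(8, 8) for _ in range(3)]

    outputs = []
    for _ in range(10):
        x = x0                            # reset to the saved input every iteration
        for layer in layers:
            x = layer(x)
        outputs.append(x)

    assert all(torch.equal(outputs[0], o) for o in outputs[1:])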
@@ -248,8 +257,16 @@ def check_head(args):
         requires_grad=True,
     ).to(device)

-    # forward
-    result = head(hidden_states)
+    output_list = []
+    for _ in range(10):
+        # forward
+        result = head(hidden_states)
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])

     # check output
     assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

@@ -334,8 +351,16 @@ def check_gather_forward(args):
         requires_grad=True,
     ).to(device)

-    # forward
-    result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
+    output_list = []
+    for _ in range(10):
+        # forward
+        result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])

     # check output
     assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)

@@ -37,7 +37,15 @@ def check_norm(args):
     ).to(device)

     # forward
-    result = norm(hidden_states.float())
+    output_list = []
+    for _ in range(10):
+        result = norm(hidden_states.float())
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])

     standard = torch.tensor(
         [

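These checks only assert repeatability within a single process on fixed inputs and weights; if the same exactness were ever needed across separate runs, PyTorch's deterministic mode and a fixed seed are the usual levers (an aside, not something this diff touches):

    import torch

    torch.manual_seed(1024)
    torch.use_deterministic_algorithms(True, warn_only=True)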