mirror of https://github.com/InternLM/InternLM

unitest_only_forward (#484)
parent e8cf27b8c0
commit 4a6987d5e7
@@ -274,9 +274,12 @@ def exam_pipeline_parallel(args):
     input_list = [{"input_ids": xs}, yx]
 
     # pp forward and backward
-    output, _, loss = scheduler.forward_backward_step(
-        engine, input_list, forward_only=False, return_loss=True, return_output_label=True
-    )
+    output_list = []
+    for _ in range(10):
+        output, _, loss = scheduler.forward_backward_step(
+            engine, input_list, forward_only=False, return_loss=True, return_output_label=True
+        )
+        output_list.append(output)
 
     engine.step()
 
@@ -292,6 +295,11 @@ def exam_pipeline_parallel(args):
         eps=config.adam.adam_eps,
     )
 
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
+
     # check output
     torch_output = torch_model(input_ids=torch_xs)  # pylint: disable=E1102
     loose_close(torch_output, output, dtype=dtype)
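Every hunk in this commit applies the same recipe: run the forward pass ten times on identical input, collect the results in output_list, and assert with torch.equal that every repetition matches the first bit for bit before the existing reference checks run. A minimal, self-contained sketch of that determinism check is shown below; the helper name check_forward_determinism and the toy nn.Linear module are illustrative stand-ins, not part of the InternLM test suite.

import torch
from torch import nn


def check_forward_determinism(module, inputs, num_runs=10):
    """Hypothetical helper (not from the InternLM tests): repeat the same
    forward pass and assert the outputs are bit-identical across runs."""
    output_list = []
    for _ in range(num_runs):
        output_list.append(module(inputs))

    # torch.equal requires exact element-wise equality, unlike torch.allclose
    first_output = output_list[0]
    for i in range(1, num_runs):
        assert torch.equal(first_output, output_list[i])
    return first_output


# toy usage: an nn.Linear stands in for the embedding/block/head/norm modules under test
layer = nn.Linear(4, 4)
x = torch.randn(2, 4)
check_forward_determinism(layer, x)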
@@ -31,7 +31,15 @@ def check_embedding(args):
 
     # create input
     input_ids = torch.tensor([[0, 2], [1, 3]]).to(device)
-    result = embedding(input_ids)
+    output_list = []
+    for _ in range(10):
+        result = embedding(input_ids)
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
     standard_result = torch.tensor(standard_list).to(device)
@@ -114,7 +114,6 @@ def check_block(args):
     # create input
     cu_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32).to(device)  # [0, 8, 16]
     indexes = torch.tensor([0, 1, 0, 1]).to(device)  # [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
-    hidden_states = torch.tensor([[0, 3, 2, 1]]).to(device)  # [[4, 118, 0, 1, 2, 3, 0, 1, 1, 97, 0, 0, 0, 0, 0, 0]]
     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
     hidden_states = torch.tensor(
@@ -130,19 +129,29 @@ def check_block(args):
 
     hidden_states = hidden_states.squeeze(0).to(device).requires_grad_()
 
-    # forward
-    for _, block in enumerate(blocks):
-        block = block.to(torch.bfloat16)
-        block = block.to(device)
-        hidden_states = block(
-            hidden_states,
-            cu_seqlens=cu_seqlens,
-            indexes=indexes,
-            inference_params=None,
-            max_seqlen=max_seqlen,
-        )
+    hid2 = hidden_states
+    output_list = []
+    for i in range(10):
+        hidden_states = hid2
+        # forward
+        for _, block in enumerate(blocks):
+            block = block.to(torch.bfloat16)
+            block = block.to(device)
+            hidden_states = block(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                indexes=indexes,
+                inference_params=None,
+                max_seqlen=max_seqlen,
+            )
+        result = hidden_states
+        output_list.append(result)
 
-    result = hidden_states
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
+
     standard_result = torch.tensor(
         [
             [-1.1621, 1.3111, 0.1509, 2.2697],
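Note the extra bookkeeping in check_block compared with the other tests: because hidden_states is overwritten as it passes through the block stack, the original input is saved in hid2 and restored at the top of every iteration, so each of the ten runs pushes the same activations through all blocks rather than re-feeding the previous run's output.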
@@ -248,8 +257,16 @@ def check_head(args):
         requires_grad=True,
     ).to(device)
 
-    # forward
-    result = head(hidden_states)
+    output_list = []
+    for _ in range(10):
+        # forward
+        result = head(hidden_states)
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     # check output
     assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
@@ -334,8 +351,16 @@ def check_gather_forward(args):
         requires_grad=True,
     ).to(device)
 
-    # forward
-    result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
+    output_list = []
+    for _ in range(10):
+        # forward
+        result = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     # check output
     assert torch.allclose(result, standard_result, rtol=rtol, atol=atol)
@@ -37,7 +37,15 @@ def check_norm(args):
     ).to(device)
 
     # forward
-    result = norm(hidden_states.float())
+    output_list = []
+    for _ in range(10):
+        result = norm(hidden_states.float())
+        output_list.append(result)
+
+    # check only forward logits
+    first_output = output_list[0]
+    for i in range(1, 10):
+        assert torch.equal(first_output, output_list[i])
 
     standard = torch.tensor(
         [
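Across all of these tests the new repeated-forward assertions use torch.equal, i.e. exact bitwise equality between runs, which is a stricter property than the pre-existing "# check output" comparisons that match the hard-coded reference tensors with torch.allclose under rtol/atol tolerances.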