fix(test): fix type_ids unpack bug (#530)

pull/531/head
Guoteng 2023-12-07 18:47:19 +08:00 committed by GitHub
parent 828033aed5
commit 81ffb3d824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 10 deletions

View File

@ -64,20 +64,25 @@ def do_warmup(args):
packed_length = micro_bsz * sql
for i in range(init_config.data.total_steps):
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
input_shape = batch[0]["type_ids"].shape
tokens_num = np.prod(input_shape)
type_ids_shape = batch[0]["type_ids"].shape
input_ids_shape = batch[0]["input_ids"].shape
tokens_num = np.prod(input_ids_shape)
# If flash attention (fa) is not used, 'type_ids' is unpacked when load_new_batch is called.
# However, 'input_ids' is unpacked in the pp/nopp engine.
if not init_config.model.use_flash_attn:
if answer[i] > 1:
assert input_shape == torch.Size(
[answer[i], micro_bsz, sql]
), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, sql]}"
else:
assert input_shape == torch.Size([micro_bsz, sql]), f"iter:{i}, {input_shape} != {[micro_bsz, sql]}"
assert type_ids_shape == torch.Size(
[answer[i], micro_bsz, sql]
), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], micro_bsz, sql])}"
else:
assert input_shape == torch.Size(
assert type_ids_shape == torch.Size(
[answer[i], packed_length]
), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}"
), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], packed_length])}"
assert input_ids_shape == torch.Size(
[answer[i], packed_length]
), f"iter:{i}, input_ids_shape: {input_ids_shape} != {torch.Size([answer[i], packed_length])}"
if gpc.get_global_rank() == 0:
print(