fix(test): fix type_ids unpack bug (#530)

pull/531/head
Guoteng 2023-12-07 18:47:19 +08:00 committed by GitHub
parent 828033aed5
commit 81ffb3d824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 10 deletions

View File

@ -64,20 +64,25 @@ def do_warmup(args):
packed_length = micro_bsz * sql packed_length = micro_bsz * sql
for i in range(init_config.data.total_steps): for i in range(init_config.data.total_steps):
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
input_shape = batch[0]["type_ids"].shape type_ids_shape = batch[0]["type_ids"].shape
tokens_num = np.prod(input_shape) input_ids_shape = batch[0]["input_ids"].shape
tokens_num = np.prod(input_ids_shape)
# If flash attention is not used, 'type_ids' is unpacked when load_new_batch is called.
# However, 'input_ids' is unpacked in the pp/nopp engine.
if not init_config.model.use_flash_attn: if not init_config.model.use_flash_attn:
if answer[i] > 1: assert type_ids_shape == torch.Size(
assert input_shape == torch.Size( [answer[i], micro_bsz, sql]
[answer[i], micro_bsz, sql] ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], micro_bsz, sql])}"
), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, sql]}"
else:
assert input_shape == torch.Size([micro_bsz, sql]), f"iter:{i}, {input_shape} != {[micro_bsz, sql]}"
else: else:
assert input_shape == torch.Size( assert type_ids_shape == torch.Size(
[answer[i], packed_length] [answer[i], packed_length]
), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}" ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], packed_length])}"
assert input_ids_shape == torch.Size(
[answer[i], packed_length]
), f"iter:{i}, input_ids_shape: {input_ids_shape} != {torch.Size([answer[i], packed_length])}"
if gpc.get_global_rank() == 0: if gpc.get_global_rank() == 0:
print( print(