mirror of https://github.com/InternLM/InternLM
fix(test): fix type_ids unpack bug (#530)
parent 828033aed5
commit 81ffb3d824
@@ -64,20 +64,25 @@ def do_warmup(args):
     packed_length = micro_bsz * sql
     for i in range(init_config.data.total_steps):
         batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
-        input_shape = batch[0]["type_ids"].shape
-        tokens_num = np.prod(input_shape)
+        type_ids_shape = batch[0]["type_ids"].shape
+        input_ids_shape = batch[0]["input_ids"].shape
+
+        tokens_num = np.prod(input_ids_shape)
+
+        # If flash attention is not used, 'type_ids' is unpacked when load_new_batch is called.
+        # However, 'input_ids' is unpacked in the pp/nopp engine.
         if not init_config.model.use_flash_attn:
-            if answer[i] > 1:
-                assert input_shape == torch.Size(
-                    [answer[i], micro_bsz, sql]
-                ), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, sql]}"
-            else:
-                assert input_shape == torch.Size([micro_bsz, sql]), f"iter:{i}, {input_shape} != {[micro_bsz, sql]}"
+            assert type_ids_shape == torch.Size(
+                [answer[i], micro_bsz, sql]
+            ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], micro_bsz, sql])}"
         else:
-            assert input_shape == torch.Size(
+            assert type_ids_shape == torch.Size(
                 [answer[i], packed_length]
-            ), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}"
 
+            ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], packed_length])}"
+            assert input_ids_shape == torch.Size(
+                [answer[i], packed_length]
+            ), f"iter:{i}, input_ids_shape: {input_ids_shape} != {torch.Size([answer[i], packed_length])}"
 
         if gpc.get_global_rank() == 0:
             print(
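For context on the shape assertions above (an illustration added here, not part of the commit): with sequence packing, packed_length = micro_bsz * sql, so a packed tensor of shape [n, packed_length] corresponds to an unpacked tensor of shape [n, micro_bsz, sql]. The short Python sketch below only demonstrates that reshape relationship; the concrete values for micro_bsz, sql, and n are made up for the example, with n standing in for answer[i] in the test.

import torch

# Standalone sketch, not InternLM code: relate the packed layout the test
# expects for 'input_ids' to the unpacked layout it expects for 'type_ids'
# when flash attention is disabled.
micro_bsz, sql, n = 2, 8, 3          # example values only
packed_length = micro_bsz * sql      # 16, mirroring packed_length in the test

packed = torch.zeros(n, packed_length, dtype=torch.long)   # packed layout, e.g. 'input_ids'
unpacked = packed.view(n, micro_bsz, sql)                   # unpacked layout, e.g. 'type_ids' without flash attention

assert packed.shape == torch.Size([n, packed_length])
assert unpacked.shape == torch.Size([n, micro_bsz, sql])

This is why the fixed test tracks type_ids_shape and input_ids_shape separately: load_new_batch unpacks type_ids when flash attention is off, while input_ids stays packed until the pp/nopp engine unpacks it.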