fix the type_ids when micro_num=1 and use_flash_attn=False (#516)

pull/530/head
ytxiong 2023-12-06 14:38:28 +08:00 committed by GitHub
parent 112c34ae09
commit 809ad9ebc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 7 deletions

View File

@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path):
return match_idxes[0]
def unpack_data(input_ids, cu_seqlens):
"""
input_ids: (n, packed_length)
Return:
output: (batch_size, max_length)
def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
"""
input_ids: if input_ids is not type_ids, the shape is (1, packed_length)
else the shape is (micro_num, packed_length)
is_type_ids: whether the input_ids is type_ids
Return:
output: if input_ids is not type ids, the shape is (micro_bsz, max_length)
else the shape is (micro_num, micro_bsz, max_length)
"""
bsz = input_ids.shape[0]
num_sequence = gpc.config.data["micro_bsz"]
@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens):
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
outputs[i] = output
if bsz == 1:
# if the input_ids is not type_ids, we need squeeze the first dimension if it is 1.
if bsz == 1 and not is_type_ids:
outputs = outputs.squeeze(0)
return outputs

View File

@ -368,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
if batch[0].get("type_ids", None) is not None:
# if use_flash_attn is False, we need to unpack type_ids
if not gpc.config.model.use_flash_attn:
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
return batch, train_iter