mirror of https://github.com/InternLM/InternLM
fix the type_ids when micro_num=1 and use_flash_attn=False (#516)
parent
112c34ae09
commit
809ad9ebc8
|
@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path):
|
|||
return match_idxes[0]
|
||||
|
||||
|
||||
def unpack_data(input_ids, cu_seqlens):
|
||||
"""
|
||||
input_ids: (n, packed_length)
|
||||
Return:
|
||||
output: (batch_size, max_length)
|
||||
def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
|
||||
"""
|
||||
input_ids: if input_ids is not type_ids, the shape is (1, packed_length)
|
||||
else the shape is (micro_num, packed_length)
|
||||
is_type_ids: whether the input_ids is type_ids
|
||||
|
||||
Return:
|
||||
output: if input_ids is not type ids, the shape is (micro_bsz, max_length)
|
||||
else the shape is (micro_num, micro_bsz, max_length)
|
||||
"""
|
||||
bsz = input_ids.shape[0]
|
||||
|
||||
num_sequence = gpc.config.data["micro_bsz"]
|
||||
|
@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens):
|
|||
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
|
||||
outputs[i] = output
|
||||
|
||||
if bsz == 1:
|
||||
# if the input_ids is not type_ids, we need squeeze the first dimension if it is 1.
|
||||
if bsz == 1 and not is_type_ids:
|
||||
outputs = outputs.squeeze(0)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -368,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
|
|||
if batch[0].get("type_ids", None) is not None:
|
||||
# if use_flash_attn is False, we need to unpack type_ids
|
||||
if not gpc.config.model.use_flash_attn:
|
||||
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
|
||||
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
|
||||
|
||||
return batch, train_iter
|
||||
|
||||
|
|
Loading…
Reference in New Issue