mirror of https://github.com/InternLM/InternLM
fix the type_ids when micro_num=1 and use_flash_attn=False (#516)
parent
112c34ae09
commit
809ad9ebc8
|
@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path):
|
||||||
return match_idxes[0]
|
return match_idxes[0]
|
||||||
|
|
||||||
|
|
||||||
def unpack_data(input_ids, cu_seqlens):
|
def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
|
||||||
"""
|
|
||||||
input_ids: (n, packed_length)
|
|
||||||
Return:
|
|
||||||
output: (batch_size, max_length)
|
|
||||||
"""
|
"""
|
||||||
|
input_ids: if input_ids is not type_ids, the shape is (1, packed_length)
|
||||||
|
else the shape is (micro_num, packed_length)
|
||||||
|
is_type_ids: whether the input_ids is type_ids
|
||||||
|
|
||||||
|
Return:
|
||||||
|
output: if input_ids is not type ids, the shape is (micro_bsz, max_length)
|
||||||
|
else the shape is (micro_num, micro_bsz, max_length)
|
||||||
|
"""
|
||||||
bsz = input_ids.shape[0]
|
bsz = input_ids.shape[0]
|
||||||
|
|
||||||
num_sequence = gpc.config.data["micro_bsz"]
|
num_sequence = gpc.config.data["micro_bsz"]
|
||||||
|
@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens):
|
||||||
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
|
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
|
||||||
outputs[i] = output
|
outputs[i] = output
|
||||||
|
|
||||||
if bsz == 1:
|
# if the input_ids is not type_ids, we need squeeze the first dimension if it is 1.
|
||||||
|
if bsz == 1 and not is_type_ids:
|
||||||
outputs = outputs.squeeze(0)
|
outputs = outputs.squeeze(0)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
@ -368,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
|
||||||
if batch[0].get("type_ids", None) is not None:
|
if batch[0].get("type_ids", None) is not None:
|
||||||
# if use_flash_attn is False, we need to unpack type_ids
|
# if use_flash_attn is False, we need to unpack type_ids
|
||||||
if not gpc.config.model.use_flash_attn:
|
if not gpc.config.model.use_flash_attn:
|
||||||
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
|
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)
|
||||||
|
|
||||||
return batch, train_iter
|
return batch, train_iter
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue