From 809ad9ebc8146a846ca9ce507eb518f4e33b405a Mon Sep 17 00:00:00 2001 From: ytxiong <45058324+yingtongxiong@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:38:28 +0800 Subject: [PATCH] fix the type_ids when micro_num=1 and use_flash_attn=False (#516) --- internlm/data/utils.py | 16 ++++++++++------ internlm/train/training_internlm.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/internlm/data/utils.py b/internlm/data/utils.py index fbcb6f7..92d08f3 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -24,13 +24,16 @@ def get_dataset_type_id(dataset_type_ids_map, path): return match_idxes[0] -def unpack_data(input_ids, cu_seqlens): - """ - input_ids: (n, packed_length) - Return: - output: (batch_size, max_length) +def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False): """ + input_ids: if input_ids is not type_ids, the shape is (1, packed_length) + else the shape is (micro_num, packed_length) + is_type_ids: whether the input_ids is type_ids + Return: + output: if input_ids is not type ids, the shape is (micro_bsz, max_length) + else the shape is (micro_num, micro_bsz, max_length) + """ bsz = input_ids.shape[0] num_sequence = gpc.config.data["micro_bsz"] @@ -45,7 +48,8 @@ def unpack_data(input_ids, cu_seqlens): output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]] outputs[i] = output - if bsz == 1: + # if the input_ids is not type_ids, we need squeeze the first dimension if it is 1. + if bsz == 1 and not is_type_ids: outputs = outputs.squeeze(0) return outputs diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 89e2d06..474bfd2 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -368,7 +368,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai if batch[0].get("type_ids", None) is not None: # if use_flash_attn is False, we need to unpack type_ids if not gpc.config.model.use_flash_attn: - batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"]) + batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True) return batch, train_iter