From 6fd9e8686409ff6f96f49dd63570dfcadee2284e Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Mon, 29 Jul 2024 01:29:18 +0000 Subject: [PATCH] fix style --- .../ColossalChat/coati/dataset/tokenization_utils.py | 2 +- applications/ColossalChat/coati/dataset/utils.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/applications/ColossalChat/coati/dataset/tokenization_utils.py b/applications/ColossalChat/coati/dataset/tokenization_utils.py index 2cbf11d1f..9eb2eba87 100755 --- a/applications/ColossalChat/coati/dataset/tokenization_utils.py +++ b/applications/ColossalChat/coati/dataset/tokenization_utils.py @@ -193,7 +193,7 @@ def apply_rlhf_data_format(template: Conversation, tokenizer: Any): template.messages[: 2 * target_turn], prompt, template.end_of_assistant ) # no truncation applied - tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10)) + tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None) loss_mask = [0] * len(tokenized) label_decode = [] diff --git a/applications/ColossalChat/coati/dataset/utils.py b/applications/ColossalChat/coati/dataset/utils.py index cf767b444..42c3191db 100755 --- a/applications/ColossalChat/coati/dataset/utils.py +++ b/applications/ColossalChat/coati/dataset/utils.py @@ -119,17 +119,18 @@ def tokenize_and_concatenate( loss_ends = [] for s, r in zip(text, require_loss): tokenized = tokenizer(s, add_special_tokens=False)["input_ids"] - if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0: + if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0: if r: loss_starts.append(len(input_ids)) loss_ends.append(len(input_ids) + len(tokenized)) input_ids.extend(tokenized) - if loss_starts[0] >= max_length: + if max_length and loss_starts[0] >= max_length: return None, None, None if discard_non_loss_tokens_at_tail: input_ids = input_ids[: loss_ends[-1]] - input_ids = input_ids[:max_length] - loss_ends[-1] = min(max_length, loss_ends[-1]) + if max_length: + input_ids = input_ids[:max_length] + loss_ends[-1] = min(max_length, loss_ends[-1]) return input_ids, loss_starts, loss_ends