fix style

pull/5922/head
YeAnbang 2024-07-29 01:29:18 +00:00
parent de1bf08ed0
commit 6fd9e86864
2 changed files with 6 additions and 5 deletions

View File

@@ -193,7 +193,7 @@ def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
)
# no truncation applied
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
loss_mask = [0] * len(tokenized)
label_decode = []

View File

@@ -119,17 +119,18 @@ def tokenize_and_concatenate(
loss_ends = []
for s, r in zip(text, require_loss):
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
if r:
loss_starts.append(len(input_ids))
loss_ends.append(len(input_ids) + len(tokenized))
input_ids.extend(tokenized)
if loss_starts[0] >= max_length:
if max_length and loss_starts[0] >= max_length:
return None, None, None
if discard_non_loss_tokens_at_tail:
input_ids = input_ids[: loss_ends[-1]]
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
if max_length:
input_ids = input_ids[:max_length]
loss_ends[-1] = min(max_length, loss_ends[-1])
return input_ids, loss_starts, loss_ends