mirror of https://github.com/hpcaitech/ColossalAI
fix style
parent
de1bf08ed0
commit
6fd9e86864
|
@ -193,7 +193,7 @@ def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
|
|||
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
|
||||
)
|
||||
# no truncation applied
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
|
||||
|
||||
loss_mask = [0] * len(tokenized)
|
||||
label_decode = []
|
||||
|
|
|
@ -119,17 +119,18 @@ def tokenize_and_concatenate(
|
|||
loss_ends = []
|
||||
for s, r in zip(text, require_loss):
|
||||
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
|
||||
if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
|
||||
if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
|
||||
if r:
|
||||
loss_starts.append(len(input_ids))
|
||||
loss_ends.append(len(input_ids) + len(tokenized))
|
||||
input_ids.extend(tokenized)
|
||||
if loss_starts[0] >= max_length:
|
||||
if max_length and loss_starts[0] >= max_length:
|
||||
return None, None, None
|
||||
if discard_non_loss_tokens_at_tail:
|
||||
input_ids = input_ids[: loss_ends[-1]]
|
||||
input_ids = input_ids[:max_length]
|
||||
loss_ends[-1] = min(max_length, loss_ends[-1])
|
||||
if max_length:
|
||||
input_ids = input_ids[:max_length]
|
||||
loss_ends[-1] = min(max_length, loss_ends[-1])
|
||||
return input_ids, loss_starts, loss_ends
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue