mirror of https://github.com/hpcaitech/ColossalAI
fix style
parent
de1bf08ed0
commit
6fd9e86864
|
@ -193,7 +193,7 @@ def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
|
||||||
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
|
template.messages[: 2 * target_turn], prompt, template.end_of_assistant
|
||||||
)
|
)
|
||||||
# no truncation applied
|
# no truncation applied
|
||||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=int(1e10))
|
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss, max_length=None)
|
||||||
|
|
||||||
loss_mask = [0] * len(tokenized)
|
loss_mask = [0] * len(tokenized)
|
||||||
label_decode = []
|
label_decode = []
|
||||||
|
|
|
@ -119,17 +119,18 @@ def tokenize_and_concatenate(
|
||||||
loss_ends = []
|
loss_ends = []
|
||||||
for s, r in zip(text, require_loss):
|
for s, r in zip(text, require_loss):
|
||||||
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
|
tokenized = tokenizer(s, add_special_tokens=False)["input_ids"]
|
||||||
if len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
|
if not max_length or len(input_ids) + len(tokenized) <= max_length or len(loss_ends) == 0:
|
||||||
if r:
|
if r:
|
||||||
loss_starts.append(len(input_ids))
|
loss_starts.append(len(input_ids))
|
||||||
loss_ends.append(len(input_ids) + len(tokenized))
|
loss_ends.append(len(input_ids) + len(tokenized))
|
||||||
input_ids.extend(tokenized)
|
input_ids.extend(tokenized)
|
||||||
if loss_starts[0] >= max_length:
|
if max_length and loss_starts[0] >= max_length:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
if discard_non_loss_tokens_at_tail:
|
if discard_non_loss_tokens_at_tail:
|
||||||
input_ids = input_ids[: loss_ends[-1]]
|
input_ids = input_ids[: loss_ends[-1]]
|
||||||
input_ids = input_ids[:max_length]
|
if max_length:
|
||||||
loss_ends[-1] = min(max_length, loss_ends[-1])
|
input_ids = input_ids[:max_length]
|
||||||
|
loss_ends[-1] = min(max_length, loss_ends[-1])
|
||||||
return input_ids, loss_starts, loss_ends
|
return input_ids, loss_starts, loss_ends
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue