mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
139 lines
4.8 KiB
139 lines
4.8 KiB
import io
|
|
import json
|
|
from typing import Any, Dict, List
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn.functional as F
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
|
|
def is_rank_0() -> bool:
    """Report whether this process should act as the "main" process.

    Returns True when no torch.distributed process group is initialized, or when
    ``dist.get_rank()`` (default process group) is 0.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
|
|
|
|
|
|
def _make_r_io_base(f, mode: str):
|
|
if not isinstance(f, io.IOBase):
|
|
f = open(f, mode=mode)
|
|
return f
|
|
|
|
|
|
def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: anything ``open()`` accepts (e.g. a path), or an already-open ``io.IOBase`` object.
        mode: mode used to open ``f`` when it is not already a file object.

    Returns:
        The parsed JSON value (a dict when the file holds a JSON object).

    Note:
        The file object is always closed afterwards, including one passed in by the caller.
    """
    f = _make_r_io_base(f, mode)
    # try/finally so the handle is released even when the JSON is malformed and json.load raises.
    try:
        return json.load(f)
    finally:
        f.close()
|
|
|
|
|
|
def read_string_by_schema(data: Dict[str, Any], schema: str) -> str:
    """
    Read a string field of a dataset element by a dotted schema path.

    Args:
        data: Dict[str, Any], one dataset element.
        schema: cascaded field names separated by '.', e.g. "person.name.first" accesses
            data["person"]["name"]["first"].

    Returns:
        The string stored at that path, or "" if any key along the path is missing
        (or maps to None).

    Raises:
        AssertionError: (when assertions are enabled) if the value at the full path is not a string.
    """
    result = data
    for key in schema.split("."):
        result = result.get(key, None)
        # Stop at the first missing level: continuing would call .get on None.
        if result is None:
            return ""
    assert isinstance(result, str), f"dataset element is not a string: {result}"
    return result
|
|
|
|
|
|
def pad_to_max_len(
    sequence: List[torch.Tensor], max_length: int, padding_value: int, batch_first: bool = True, padding_side="left"
):
    """
    Stack variable-length tensors into one tensor whose sequence dimension has size ``max_length``.

    Args:
        sequence: list of tensors whose first dimension is the sequence length; any trailing
            dimensions must match across the list (as required by ``pad_sequence``).
        max_length: target size of the sequence dimension. If it is smaller than the longest
            sequence, the result is cropped: left padding keeps the last ``max_length``
            positions, right padding keeps the first ``max_length`` positions.
        padding_value: value written into padded positions.
        batch_first: if True the result is ``[batch, max_length, *]``, otherwise
            ``[max_length, batch, *]``.
        padding_side: "left" or "right", the side on which padding is inserted.

    Raises:
        RuntimeError: if ``padding_side`` is neither "left" nor "right".
    """
    if padding_side not in ("left", "right"):
        raise RuntimeError(f"`padding_side` can only be `left` or `right`, " f"but now `{padding_side}`")
    pad_left = padding_side == "left"
    # pad_sequence places the sequence axis at dim 1 when batch_first, otherwise at dim 0.
    seq_dim = 1 if batch_first else 0
    if pad_left:
        # Left padding = reverse every sequence, right-pad, then reverse the sequence axis back.
        sequence = [seq.flip(dims=(0,)) for seq in sequence]
    padded = torch.nn.utils.rnn.pad_sequence(sequences=sequence, batch_first=batch_first, padding_value=padding_value)
    to_pad = max_length - padded.size(seq_dim)
    # F.pad takes (before, after) pairs starting from the LAST dimension: leave every dimension
    # after seq_dim untouched and extend (or, if to_pad < 0, crop) the end of seq_dim.
    pad_spec = (0, 0) * (padded.dim() - seq_dim - 1) + (0, to_pad)
    padded = F.pad(padded, pad_spec, value=padding_value)
    if pad_left:
        padded = torch.flip(padded, dims=(seq_dim,))
    return padded
|
|
|
|
|
|
def chuncate_sequence(sequence: List[torch.Tensor], max_length: int, dtype: Any):
    """
    Truncate every sequence in a batch to at most ``max_length`` elements and convert it to a tensor.

    Args:
        sequence: list of per-example sequences (Python lists of numbers or 1-D tensors).
        max_length: maximum number of leading elements kept from each sequence.
        dtype: torch dtype of the returned tensors.

    Returns:
        List[torch.Tensor]: one tensor per input sequence, in the same order.
    """
    # Build each tensor directly in the target dtype. Going through torch.Tensor(...) would first
    # materialise a float32 tensor, silently rounding integers above 2**24 (e.g. large token ids)
    # before the cast. Slicing a sequence that is already short enough is a no-op, so no length
    # check is needed.
    return [torch.as_tensor(seq[:max_length], dtype=dtype) for seq in sequence]
|
|
|
|
|
|
def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, start_index: int = 0) -> int:
    """Locate the first window of ``seq`` equal to ``subseq``.

    Only windows starting at or after ``start_index`` are considered.

    Returns:
        The start index of the first matching window, -1 when there is none, and 0 when
        ``subseq`` is None.
    """
    if subseq is None:
        return 0
    width = len(subseq)
    last_start = len(seq) - width
    candidates = range(start_index, last_start + 1)
    return next((pos for pos in candidates if torch.all(seq[pos : pos + width] == subseq)), -1)
|
|
|
|
|
|
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
    """
    Tokenize each text segment separately (without special tokens) and join the resulting ids.

    Args:
        tokenizer (PreTrainedTokenizer): called as ``tokenizer(segment, add_special_tokens=False)``;
            its ``["input_ids"]`` entry is used.
        text (List[str]): the segments to tokenize, in order.
        require_loss (List[bool]): per-segment flag marking segments whose tokens need loss.
            Segments are paired with flags via ``zip``, so extra items in the longer list are ignored.

    Returns:
        Tuple[List[int], List[int], List[int]]: the concatenated input ids, plus the start and
        (exclusive) end offsets of each loss-requiring segment within them.
    """
    input_ids = []
    loss_starts, loss_ends = [], []
    for segment, needs_loss in zip(text, require_loss):
        segment_ids = tokenizer(segment, add_special_tokens=False)["input_ids"]
        begin = len(input_ids)
        input_ids.extend(segment_ids)
        if needs_loss:
            loss_starts.append(begin)
            loss_ends.append(len(input_ids))
    return input_ids, loss_starts, loss_ends
|
|
|
|
|
|
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str):
    """
    Separate a chat-templated prompt into chunks along each message's content, preparing data
    for ``tokenize_and_concatenate``.

    Each message contributes two chunks: the template text between the previous match and this
    message's content, then the content itself. A final chunk holds any text after the last
    message, so ``"".join(chunks) == prompt``.

    Args:
        messages: dicts with "role" and "content" keys, in the order their contents appear in
            ``prompt``.
        prompt: the templated prompt, which must contain every message's content verbatim, in order.

    Returns:
        (chunks, require_loss): ``require_loss[i]`` is True only for the content chunk of a
        message whose role is "assistant" (case-insensitive).

    Raises:
        ValueError: if a message's content is not found in ``prompt`` after the previous match.
    """
    start_idx = 0
    chunks = []
    require_loss = []
    for idx, line in enumerate(messages):
        content = line["content"]
        first_occur = prompt.find(content, start_idx)
        if first_occur == -1:
            raise ValueError(
                f"content of message {idx} (role {line['role']!r}) not found in the templated prompt "
                f"at or after index {start_idx}"
            )
        end_idx = first_occur + len(content)
        split_idx = first_occur
        # A space directly before the content is moved into the content chunk so it is tokenized
        # together with the content. Only look at template text between the previous match and this
        # one: at index 0, prompt[-1] would wrap around to the prompt's last character.
        if first_occur > start_idx and prompt[first_occur - 1] == " ":
            split_idx -= 1
        chunks.append(prompt[start_idx:split_idx])
        chunks.append(prompt[split_idx:end_idx])
        require_loss.extend((False, line["role"].lower() == "assistant"))
        start_idx = end_idx
    chunks.append(prompt[start_idx:])
    require_loss.append(False)
    return chunks, require_loss
|