You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/applications/ColossalChat/coati/dataset/utils.py

139 lines
4.8 KiB

import io
import json
from typing import Any, Dict, List
import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
def is_rank_0() -> bool:
    """Return True when running as the main process.

    Without an initialized ``torch.distributed`` process group every process
    counts as the main one; otherwise only global rank 0 does.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jload(f, mode="r"):
    """Load a JSON document from a path or an open file object.

    Args:
        f: a path, or an already-open ``io.IOBase`` file object.
        mode: mode used to open ``f`` when it is a path.

    Returns:
        The deserialized JSON value (a dict when the document is a JSON object).

    Note:
        The file object is always closed afterwards, including one supplied by
        the caller.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close the handle even when json.load raises (e.g. malformed JSON),
        # so it is not leaked.
        f.close()
def read_string_by_schema(data: Dict[str, Any], schema: str) -> str:
    """
    Read a (possibly nested) string field of a dataset record by schema.

    Args:
        data: Dict[str, Any], a single dataset record.
        schema: cascaded field names separated by '.'. e.g. person.name.first will access
            data['person']['name']['first']

    Returns:
        The string stored at ``schema``. Returns "" if any key along the path is missing,
        maps to None, or the path tries to descend into a value that is not a mapping.

    Raises:
        AssertionError: if the value found at the end of the path is not a string.
    """
    from collections.abc import Mapping

    result = data
    for key in schema.split("."):
        # Descending into a non-mapping (e.g. a string or list) means the field
        # does not exist. Treat it like a missing key instead of crashing on `.get`.
        if not isinstance(result, Mapping):
            return ""
        result = result.get(key, None)
        if result is None:
            return ""
    # Explicit raise (same exception type as the former `assert`) so the check
    # is not stripped when Python runs with -O.
    if not isinstance(result, str):
        raise AssertionError(f"dataset element is not a string: {result}")
    return result
def pad_to_max_len(
    sequence: List[torch.Tensor], max_length: int, padding_value: int, batch_first: bool = True, padding_side="left"
):
    """
    Pad a batch of variable-length 1-D tensors to exactly ``max_length``.

    Args:
        sequence: list of 1-D tensors, one per sample; lengths may differ.
        max_length: target length of the sequence dimension.
        padding_value: value written into padded positions.
        batch_first: if True the result is [batch_size, max_length], otherwise
            [max_length, batch_size].
        padding_side: "left" puts the padding before each sequence, "right" after it.

    Returns:
        torch.Tensor holding the padded batch.

    Raises:
        RuntimeError: if ``padding_side`` is neither "left" nor "right".

    Note:
        If a sequence is longer than ``max_length``, the pad amount becomes negative
        and ``F.pad`` crops instead. Left padding then keeps the last ``max_length``
        positions; right padding keeps the first ``max_length``.
    """
    if padding_side not in ("left", "right"):
        raise RuntimeError(f"`padding_side` can only be `left` or `right`, " f"but now `{padding_side}`")

    # Always build a [batch_size, seq_len] tensor so that padding/flipping act on
    # the sequence dimension; transpose only at the end for batch_first=False.
    if padding_side == "left":
        # Left padding: reverse each sequence, right-pad, then reverse back.
        reversed_sequence = [seq.flip(dims=(0,)) for seq in sequence]
        padded = torch.nn.utils.rnn.pad_sequence(
            sequences=reversed_sequence, batch_first=True, padding_value=padding_value
        )
        padded = F.pad(padded, (0, max_length - padded.size(1)), value=padding_value)
        padded = torch.flip(padded, dims=(1,))
    else:
        padded = torch.nn.utils.rnn.pad_sequence(sequences=sequence, batch_first=True, padding_value=padding_value)
        padded = F.pad(padded, (0, max_length - padded.size(1)), value=padding_value)

    if batch_first:
        return padded
    return padded.transpose(0, 1).contiguous()
def chuncate_sequence(sequence: List[torch.Tensor], max_length: int, dtype: Any):
    """
    Truncate every sequence in a batch to at most ``max_length`` elements and convert it to a tensor.

    Args:
        sequence: a batch of 1-D sequences (Python lists or tensors); lengths may differ.
        max_length: maximum number of leading elements kept from each sequence.
        dtype: torch dtype of the returned tensors (e.g. torch.long).

    Returns:
        List[torch.Tensor]: one tensor of dtype ``dtype`` per input sequence.
    """
    # Build each tensor directly in the target dtype. Going through the legacy
    # float32 ``torch.Tensor`` constructor first rounds integers above 2**24
    # (e.g. large token ids) before the cast. Slicing leaves sequences that are
    # already <= max_length unchanged, so no length check is needed.
    return [torch.as_tensor(seq[:max_length], dtype=dtype) for seq in sequence]
def find_first_occurrence_subsequence(seq: torch.Tensor, subseq: torch.Tensor, start_index: int = 0) -> int:
    """
    Locate the first contiguous match of ``subseq`` inside ``seq``.

    Args:
        seq: tensor to search in.
        subseq: pattern to look for; ``None`` short-circuits to position 0.
        start_index: first position considered as a match start.

    Returns:
        The index of the first match at or after ``start_index``, 0 when ``subseq`` is None,
        or -1 if there is no match.
    """
    if subseq is None:
        return 0
    width = len(subseq)
    last_start = len(seq) - width
    candidate = start_index
    while candidate <= last_start:
        window = seq[candidate : candidate + width]
        if torch.all(window == subseq):
            return candidate
        candidate += 1
    return -1
def tokenize_and_concatenate(tokenizer: PreTrainedTokenizer, text: List[str], require_loss: List[bool]):
    """
    Tokenize a list of texts with the provided tokenizer and concatenate the tokenized outputs.

    Each chunk is tokenized on its own with ``add_special_tokens=False``. The result may
    therefore differ from tokenizing the joined string in one call.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenization.
        text (List[str]): The list of texts to tokenize.
        require_loss (List[bool]): For each text, whether its tokens require loss calculation.

    Returns:
        Tuple[List[int], List[int], List[int]]: the concatenated input ids, the start positions
        of loss spans, and the (exclusive) end positions of loss spans.

    Raises:
        ValueError: if ``text`` and ``require_loss`` have different lengths.
    """
    text = list(text)
    require_loss = list(require_loss)
    # zip() would silently drop trailing chunks on a length mismatch, which
    # corrupts the loss spans without any error.
    if len(text) != len(require_loss):
        raise ValueError(
            f"`text` and `require_loss` must have the same length, got {len(text)} and {len(require_loss)}"
        )
    input_ids = []
    loss_starts = []
    loss_ends = []
    for chunk, needs_loss in zip(text, require_loss):
        token_ids = tokenizer(chunk, add_special_tokens=False)["input_ids"]
        if needs_loss:
            # Half-open span [start, end) over the concatenated ids.
            loss_starts.append(len(input_ids))
            loss_ends.append(len(input_ids) + len(token_ids))
        input_ids.extend(token_ids)
    return input_ids, loss_starts, loss_ends
def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: str):
    """
    Separate a templated prompt into chunks along the human/assistant message contents,
    preparing data for tokenize_and_concatenate.

    Args:
        messages: conversation turns, each a dict with "role" and "content" keys, in the
            order their contents appear in ``prompt``.
        prompt: the full prompt text rendered from ``messages``.

    Returns:
        Tuple[List[str], List[bool]]: ``chunks``, whose concatenation equals ``prompt``, and a
        parallel ``require_loss`` list that is True only for chunks holding assistant contents.

    Raises:
        ValueError: if a message's content cannot be found in ``prompt`` after the previous message.
    """
    start_idx = 0
    chunks = []
    require_loss = []
    for line in messages:
        content = line["content"]
        first_occur = prompt.find(content, start_idx)
        if first_occur == -1:
            # Without this check, -1 would be used as an index and produce misaligned chunks.
            raise ValueError(
                f"Message content not found in the templated prompt after index {start_idx}: {content!r}"
            )
        # A single space right before the content is moved into the content chunk,
        # presumably so the tokenizer sees it attached to the content's first token.
        # The `> start_idx` guard prevents reading prompt[-1] when the content starts at
        # index 0, and prevents re-emitting a character that belongs to the previous chunk.
        if first_occur > start_idx and prompt[first_occur - 1] == " ":
            split_at = first_occur - 1
        else:
            split_at = first_occur
        end = first_occur + len(content)
        chunks.append(prompt[start_idx:split_at])
        chunks.append(prompt[split_at:end])
        start_idx = end
        # The template text before the content never takes loss; the content does only for the assistant.
        require_loss.append(False)
        require_loss.append(line["role"].lower() == "assistant")
    chunks.append(prompt[start_idx:])
    require_loss.append(False)
    return chunks, require_loss