mirror of https://github.com/hpcaitech/ColossalAI
[Chat] fix the tokenizer "int too big to convert" error in SFT training (#3453)
* Add RoBERTa for RLHF Stage 2 & 3 (test)

  RoBERTa for RLHF Stage 2 & 3 (still in testing)

* Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)"

  This reverts commit 06741d894d.

* Add RoBERTa for RLHF stage 2 & 3

  1. add roberta folder under model folder
  2. add roberta option in train_reward_model.py
  3. add some test in testci

* Update test_ci.sh

* Revert "Update test_ci.sh"

  This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a.

* update roberta with coati

* chat ci update

* Revert "chat ci update"

  This reverts commit 17ae7ae01fa752bd3289fc39069868fde99cf846.

* [Chat] fix the tokenizer "int too big to convert" error in SFT training

  fix the tokenizer error during SFT training using Bloom and OPT
parent 46c009dba4
commit 72cb4dd433
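For context, a minimal sketch of the failure this commit fixes, assuming a stock Hugging Face fast tokenizer (the bigscience/bloom-560m checkpoint below is only an illustration, not part of the commit): when a checkpoint does not pin model_max_length, the attribute falls back to a huge sentinel value, and passing that sentinel as max_length to the Rust-backed tokenizer raises "OverflowError: int too big to convert". The patch below threads an explicit max_length through the dataset code instead.

```python
# Minimal sketch of the failure mode; not code from this commit.
# The bigscience/bloom-560m checkpoint is assumed purely for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
print(tokenizer.model_max_length)  # huge sentinel when the config sets no real limit

texts = ["Below is an instruction that describes a task."]

# Old behaviour (commented out): the sentinel overflows inside the fast tokenizer
# and raises "OverflowError: int too big to convert".
# tokenizer(texts, return_tensors="pt", padding="longest",
#           max_length=tokenizer.model_max_length, truncation=True)

# New behaviour: pass an explicit cap, as the patched _tokenize_fn now does.
batch = tokenizer(texts, return_tensors="pt", padding="longest",
                  max_length=512, truncation=True)
print(batch["input_ids"].shape)
```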
@@ -78,14 +78,14 @@ class SFTDataset(Dataset):
         # return dict(self.prompts[idx], self.prompts[idx])


-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
     """Tokenize a list of strings."""
     tokenized_list = [
         tokenizer(
             text,
             return_tensors="pt",
             padding="longest",
-            max_length=tokenizer.model_max_length,
+            max_length=max_length,
             truncation=True,
         ) for text in strings
     ]
@@ -105,10 +105,11 @@ def preprocess(
     sources: Sequence[str],
     targets: Sequence[str],
     tokenizer: transformers.PreTrainedTokenizer,
+    max_length: int,
 ) -> Dict:
     """Preprocess the data by tokenizing."""
     examples = [s + t for s, t in zip(sources, targets)]
-    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
     input_ids = examples_tokenized["input_ids"]
     labels = copy.deepcopy(input_ids)
     for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
@@ -119,7 +120,7 @@ def preprocess(
 class SupervisedDataset(Dataset):
     """Dataset for supervised fine-tuning."""

-    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None):
+    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512):
         super(SupervisedDataset, self).__init__()
         logger.info("Loading data...")
         list_data_dict = jload(data_path)
@@ -138,7 +139,7 @@ class SupervisedDataset(Dataset):
         targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

         logger.info("Tokenizing inputs... This may take some time...")
-        data_dict = preprocess(sources, targets, tokenizer)
+        data_dict = preprocess(sources, targets, tokenizer, max_length)

         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]
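After this change, SupervisedDataset accepts a max_length (default 512) that is forwarded through preprocess and _tokenize_fn in place of tokenizer.model_max_length. A hypothetical usage sketch of the updated constructor follows; the import path, checkpoint, and JSON file name are placeholders, not taken from the diff.

```python
# Hypothetical usage of the patched SupervisedDataset; the import path,
# checkpoint name and JSON file below are placeholders, not from the commit.
import transformers
from coati.dataset import SupervisedDataset  # assumed module path

tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token

# max_length now caps every tokenizer call inside _tokenize_fn,
# replacing the overflowing tokenizer.model_max_length.
dataset = SupervisedDataset(data_path="alpaca_data.json",   # placeholder file
                            tokenizer=tokenizer,
                            max_datasets_size=1000,
                            max_length=512)
print(len(dataset))
```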
@@ -71,6 +71,7 @@ def train(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
     tokenizer.pad_token = tokenizer.eos_token
+    max_len = args.max_len
     if args.model == 'llama':
         tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)

@@ -99,13 +100,14 @@ def train(args):
         train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
         eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')

-        train_dataset = SFTDataset(train_data, tokenizer)
-        eval_dataset = SFTDataset(eval_data, tokenizer)
+        train_dataset = SFTDataset(train_data, tokenizer, max_len)
+        eval_dataset = SFTDataset(eval_data, tokenizer, max_len)

     else:
         train_dataset = SupervisedDataset(tokenizer=tokenizer,
                                           data_path=args.dataset,
-                                          max_datasets_size=args.max_datasets_size)
+                                          max_datasets_size=args.max_datasets_size,
+                                          max_length=max_len)
         eval_dataset = None
         data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

@@ -176,6 +178,7 @@ if __name__ == '__main__':
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--max_epochs', type=int, default=3)
     parser.add_argument('--batch_size', type=int, default=4)
+    parser.add_argument('--max_len', type=int, default=512)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
     parser.add_argument('--lr', type=float, default=5e-6)
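In the training script, the new --max_len flag is read into max_len and passed to both dataset branches. A condensed sketch of the Hugging Face dataset branch follows; the dataset name, checkpoint, and coati import path are assumptions for illustration rather than values fixed by the diff.

```python
# Condensed sketch of the --max_len plumbing in train(); dataset name,
# checkpoint, and import path are illustrative assumptions.
import transformers
from datasets import load_dataset
from coati.dataset import SFTDataset  # assumed module path

tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token

max_len = 512  # value of --max_len (argparse default 512)

# The cap is now passed positionally to SFTDataset, matching the diff above.
train_data = load_dataset("yizhongw/self_instruct", "super_natural_instructions", split="train")
train_dataset = SFTDataset(train_data, tokenizer, max_len)
print(len(train_dataset))
```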