[chat] fix bugs in stage 3 training (#3759)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/3770/head
Yuanchen 2023-05-17 17:44:05 +08:00 committed by GitHub
parent 5dd573c6b6
commit 05759839bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 14 deletions

View File

@ -45,7 +45,7 @@ class PromptDataset(Dataset):
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
def __len__(self):
return len(self.keyed_prompt)
return len(self.keyed_prompt["input_ids"])
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return {k: v[i] for k, v in self.keyed_prompt.items()}

View File

@ -153,7 +153,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
--rm_path /your/rm/model/path
```
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) to generate the prompt dataset.
Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning.
### Arg List

View File

@ -1,12 +0,0 @@
jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild
reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3
data = ''
with open(jsonl_file, 'r', encoding="utf-8") as f1:
for jsonstr in f1.readlines():
jsonstr = '\t' + jsonstr.strip('\n') + ',\n'
data = data + jsonstr
data = '[\n' + data + ']'
with open(reformat_file, 'w') as f2:
f2.write(data)

View File

@ -0,0 +1,30 @@
import argparse
import random
import json
random.seed(42)
def sample(args):
with open(args.dataset_path, mode='r') as f:
dataset_list = json.load(f)
sampled_dataset = [{"instruction": sample["instruction"], "id":idx}
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))]
with open(args.save_path, mode='w') as f:
json.dump(sampled_dataset, f, indent=4,
default=str, ensure_ascii=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default=None,
required=True, help="path to the pretrain dataset")
parser.add_argument('--save_path', type=str, default='prompt.json',
help="path to save the prompt dataset")
parser.add_argument('--sample_size', type=int,
default=16384, help="size of the prompt dataset")
args = parser.parse_args()
sample(args)