mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix bugs in stage 3 training (#3759)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>pull/3770/head
parent
5dd573c6b6
commit
05759839bd
|
@ -45,7 +45,7 @@ class PromptDataset(Dataset):
|
|||
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keyed_prompt)
|
||||
return len(self.keyed_prompt["input_ids"])
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
return {k: v[i] for k, v in self.keyed_prompt.items()}
|
||||
|
|
|
@ -153,7 +153,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
|
|||
--rm_path /your/rm/model/path
|
||||
```
|
||||
|
||||
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
|
||||
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) which samples `instinwild_en.json` or `instinwild_ch.json` in [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) to generate the prompt dataset.
|
||||
Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning.
|
||||
|
||||
### Arg List
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild
|
||||
reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3
|
||||
|
||||
data = ''
|
||||
with open(jsonl_file, 'r', encoding="utf-8") as f1:
|
||||
for jsonstr in f1.readlines():
|
||||
jsonstr = '\t' + jsonstr.strip('\n') + ',\n'
|
||||
data = data + jsonstr
|
||||
data = '[\n' + data + ']'
|
||||
|
||||
with open(reformat_file, 'w') as f2:
|
||||
f2.write(data)
|
|
@ -0,0 +1,30 @@
|
|||
import argparse
|
||||
|
||||
import random
|
||||
import json
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
def sample(args):
|
||||
with open(args.dataset_path, mode='r') as f:
|
||||
dataset_list = json.load(f)
|
||||
|
||||
sampled_dataset = [{"instruction": sample["instruction"], "id":idx}
|
||||
for idx, sample in enumerate(random.sample(dataset_list, args.sample_size))]
|
||||
|
||||
with open(args.save_path, mode='w') as f:
|
||||
json.dump(sampled_dataset, f, indent=4,
|
||||
default=str, ensure_ascii=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_path', type=str, default=None,
|
||||
required=True, help="path to the pretrain dataset")
|
||||
parser.add_argument('--save_path', type=str, default='prompt.json',
|
||||
help="path to save the prompt dataset")
|
||||
parser.add_argument('--sample_size', type=int,
|
||||
default=16384, help="size of the prompt dataset")
|
||||
args = parser.parse_args()
|
||||
sample(args)
|
Loading…
Reference in New Issue