mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix bugs in stage 3 training (#3759)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
pull/3770/head
parent
5dd573c6b6
commit
05759839bd
|
@ -45,7 +45,7 @@ class PromptDataset(Dataset):
|
||||||
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
|
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.keyed_prompt)
|
return len(self.keyed_prompt["input_ids"])
|
||||||
|
|
||||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||||
return {k: v[i] for k, v in self.keyed_prompt.items()}
|
return {k: v[i] for k, v in self.keyed_prompt.items()}
|
||||||
|
|
|
@ -153,7 +153,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
|
||||||
--rm_path /your/rm/model/path
|
--rm_path /your/rm/model/path
|
||||||
```
|
```
|
||||||
|
|
||||||
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
|
Prompt dataset: the instruction dataset shown in the figure above, which contains the instructions. For example, you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/generate_prompt_dataset.py) to sample `instinwild_en.json` or `instinwild_ch.json` from [InstructionWild](https://github.com/XueFuzhao/InstructionWild/tree/main/data#instructwild-data) and generate the prompt dataset.
|
||||||
Pretrain dataset: the pretrain dataset, which includes each instruction and its corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from stage 1 supervised instruction tuning.
|
Pretrain dataset: the pretrain dataset, which includes each instruction and its corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) from stage 1 supervised instruction tuning.
|
||||||
|
|
||||||
### Arg List
|
### Arg List
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
# Reformat a JSON-lines file (one JSON object per line) into a single JSON
# array so it can be used as the Prompt dataset in Stage 3.
jsonl_file = 'seed_prompts_xx.jsonl'    # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild
reformat_file = 'prompts_xx.jsonl'    # reformat jsonl file used as Prompt dataset in Stage3

with open(jsonl_file, 'r', encoding="utf-8") as f1:
    # Skip blank lines: they would otherwise become empty array entries.
    records = ['\t' + line.strip('\n') for line in f1 if line.strip()]

# Join with separators *between* records only; a comma after the last record
# (as a per-line ",\n" suffix would produce) makes the array invalid JSON.
data = '[\n' + ',\n'.join(records) + '\n]'

# Write with the same encoding the input was read with, so non-ASCII prompts
# survive regardless of the platform default codec.
with open(reformat_file, 'w', encoding="utf-8") as f2:
    f2.write(data)
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
def sample(args):
    """Build a prompt dataset by randomly sampling instructions from a JSON file.

    Loads the JSON document at ``args.dataset_path`` (expected to be an array
    of objects that each carry an ``"instruction"`` key), draws
    ``args.sample_size`` distinct records with ``random.sample`` and writes a
    list of ``{"instruction": ..., "id": ...}`` objects to ``args.save_path``.
    The ``id`` is the record's position in the sampled list, starting at 0.

    Args:
        args: object exposing ``dataset_path``, ``save_path`` and
            ``sample_size`` attributes (e.g. an ``argparse.Namespace``).

    Raises:
        ValueError: propagated from ``random.sample`` when ``sample_size`` is
            negative or larger than the number of records.
        KeyError: if a sampled record has no ``"instruction"`` key.
    """
    # Read as UTF-8 explicitly: the instruction data may contain non-ASCII
    # text and the platform default codec is not guaranteed to be UTF-8.
    with open(args.dataset_path, mode='r', encoding='utf-8') as f:
        dataset_list = json.load(f)

    picked = random.sample(dataset_list, args.sample_size)
    sampled_dataset = [{"instruction": record["instruction"], "id": idx}
                       for idx, record in enumerate(picked)]

    # ensure_ascii=False emits raw non-ASCII characters, so the output file
    # must be opened with an encoding that can represent them.
    with open(args.save_path, mode='w', encoding='utf-8') as f:
        json.dump(sampled_dataset, f, indent=4,
                  default=str, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: gather the sampling options and run sample().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--dataset_path',
        type=str,
        default=None,
        required=True,
        help="path to the pretrain dataset",
    )
    cli.add_argument(
        '--save_path',
        type=str,
        default='prompt.json',
        help="path to save the prompt dataset",
    )
    cli.add_argument(
        '--sample_size',
        type=int,
        default=16384,
        help="size of the prompt dataset",
    )
    sample(cli.parse_args())
|
Loading…
Reference in New Issue