ColossalAI/applications/ColossalChat/tests/generate_dummy_datasets_for...


[ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com>
2024-03-29 06:12:29 +00:00
import argparse
import json
import os

# Seed record for SFT data: one human turn followed by one assistant reply.
sft_seed = {
    "messages": [
        {"from": "human", "content": "Give three tips for staying healthy."},
        {
            "from": "assistant",
            "content": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule.",
        },
    ]
}

# Seed record for prompt data (selected with --data_type prompt).
prompt_seed = {
    "messages": [
        {"from": "human", "content": "Describe the impacts of climate change on communities living in coastal areas."},
        {
            "from": "assistant",
            "content": "Climate change has caused an increase in sea levels, which has caused coastal erosion and flooding of low-lying areas. This has led to displacement of people from their homes, as well as increased risk of epidemics of waterborne illnesses. Coastal cities have also seen an increase in extreme weather events such as hurricanes and tropical storms, which can cause extensive damage to infrastructure, homes, and businesses. As a result of climate change, some coastal areas are becoming uninhabitable, forcing communities to seek alternative living arrangements.",
        },
    ]
}

# Seed record for preference data: a multi-turn context plus one chosen and one rejected reply.
preference_seed = {
    "context": [
        {"from": "human", "content": "What kind of noises did dinosaurs make?"},
        {
            "from": "assistant",
            "content": "Humans and dinosaurs didn't live at the same time, so it's really hard to say. The best place to find out what noises dinosaurs made would be",
        },
        {"from": "human", "content": "yes they did"},
        {
            "from": "assistant",
            "content": "to guess, and that would probably require lots of reading and a certain amount of imagination, so we're not really prepared to do that.",
        },
        {"from": "human", "content": "you cant read"},
    ],
    "chosen": [{"from": "assistant", "content": "You can read?"}],
    "rejected": [{"from": "assistant", "content": "there's a lot of stuff humans don't know"}],
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        default=None,
        help="The output dir",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        required=True,
        default=None,
        help="The type of data",
    )
    args = parser.parse_args()

    # Pick the seed record that matches the requested data type.
    if args.data_type == "sft":
        seed = sft_seed
    elif args.data_type == "prompt":
        seed = prompt_seed
    elif args.data_type == "preference":
        seed = preference_seed
    else:
        raise ValueError(f"Unknown data type {args.data_type}")

    # Serialize the seed once; every line of every output file is identical.
    line = json.dumps(seed, ensure_ascii=False) + "\n"

    # Write 1.jsonl, 2.jsonl and 3.jsonl into the output directory (which must already exist).
    for idx in [1, 2, 3]:
        with open(os.path.join(args.data_dir, f"{idx}.jsonl"), "w", encoding="utf8") as f:
            for _ in range(1000):
                f.write(line)
            # Append one final record after the loop (1001 lines per file in total).
            f.write(line)
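
The write step above can be checked by reading a generated file back. The following is a minimal sketch, not part of the script: `write_dummy_jsonl`, `read_jsonl`, and `demo_seed` are hypothetical names introduced here for illustration, and the helper creates the output directory itself, which the script does not do. It shows the same pattern (one JSON record repeated per line, files named `1.jsonl` to `3.jsonl`) and verifies that every line parses back to the seed.

```python
import json
import os
import tempfile


def write_dummy_jsonl(data_dir, seed, lines_per_file, num_files=3):
    """Write `seed` as one JSON line, repeated, into 1.jsonl .. <num_files>.jsonl."""
    # Assumption: creating the directory is a convenience added in this sketch.
    os.makedirs(data_dir, exist_ok=True)
    line = json.dumps(seed, ensure_ascii=False) + "\n"
    paths = []
    for idx in range(1, num_files + 1):
        path = os.path.join(data_dir, f"{idx}.jsonl")
        with open(path, "w", encoding="utf8") as f:
            f.writelines(line for _ in range(lines_per_file))
        paths.append(path)
    return paths


def read_jsonl(path):
    """Parse every non-empty line of a JSONL file into a Python object."""
    with open(path, encoding="utf8") as f:
        return [json.loads(row) for row in f if row.strip()]


# Hypothetical seed in the same "messages" layout as the SFT seed.
demo_seed = {
    "messages": [
        {"from": "human", "content": "hi"},
        {"from": "assistant", "content": "hello"},
    ]
}

with tempfile.TemporaryDirectory() as tmp:
    paths = write_dummy_jsonl(tmp, demo_seed, lines_per_file=5)
    file_names = sorted(os.path.basename(p) for p in paths)
    records = read_jsonl(paths[0])

print(file_names, len(records))
```

Serializing the seed once and reusing the string keeps the files byte-identical, which is convenient for CI tests that only need well-formed input of a predictable size.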