diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index 5ded6d843..2086ff003 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -8,7 +8,7 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
 from coati.models.gpt import GPTRM, GPTActor, GPTCritic
 from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
 from coati.models.opt import OPTRM, OPTActor, OPTCritic
-from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic
+from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
 from coati.trainer import PPOTrainer
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from coati.utils import prepare_llama_tokenizer_and_embedding
@@ -143,6 +143,8 @@ def main(args):
     prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
     if dist.is_initialized() and dist.get_world_size() > 1:
         prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
+    else:
+        prompt_sampler = None
     prompt_dataloader = DataLoader(prompt_dataset,
                                    shuffle=(prompt_sampler is None),
                                    sampler=prompt_sampler,
@@ -151,6 +153,8 @@ def main(args):
     pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
+    else:
+        pretrain_sampler = None
     pretrain_dataloader = DataLoader(pretrain_dataset,
                                      shuffle=(pretrain_sampler is None),
                                      sampler=pretrain_sampler,
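
Context for the two `else` hunks: in a single-process run, `dist.is_initialized() and dist.get_world_size() > 1` is false, so `prompt_sampler` and `pretrain_sampler` were never assigned before being read in the `DataLoader(...)` calls, which raises `UnboundLocalError` inside `main()`. Below is a minimal, self-contained sketch of the corrected sampler-fallback pattern; the helper name `make_dataloader` and the `ToyDataset` class are hypothetical stand-ins (the patch itself uses `PromptDataset` and `SupervisedDataset`).

import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ToyDataset(Dataset):
    """Hypothetical stand-in dataset for illustration only."""

    def __init__(self, n: int = 8):
        self.items = list(range(n))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def make_dataloader(dataset: Dataset, batch_size: int = 2) -> DataLoader:
    if dist.is_initialized() and dist.get_world_size() > 1:
        # Multi-process run: shard the dataset across ranks.
        sampler = DistributedSampler(dataset, shuffle=True, seed=42, drop_last=True)
    else:
        # Single-process run: without this branch, `sampler` stays unbound
        # and the DataLoader call below raises UnboundLocalError.
        sampler = None
    # DataLoader forbids shuffle=True together with an explicit sampler,
    # so shuffle only when no sampler is set, as in the patch.
    return DataLoader(dataset, shuffle=(sampler is None), sampler=sampler, batch_size=batch_size)


if __name__ == "__main__":
    # Runs fine without torch.distributed being initialized.
    for batch in make_dataloader(ToyDataset()):
        print(batch)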