Merge pull request #3621 from zhang-yi-chi/fix/chat-train-prompts-single-gpu

[chat] fix single gpu training bug in examples/train_prompts.py
pull/3652/head
Tong Li 2023-04-24 22:13:54 +08:00 committed by GitHub
commit e1b0a78afa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -8,7 +8,7 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer import PPOTrainer from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding from coati.utils import prepare_llama_tokenizer_and_embedding
@ -143,6 +143,8 @@ def main(args):
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384) prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset, prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None), shuffle=(prompt_sampler is None),
sampler=prompt_sampler, sampler=prompt_sampler,
@ -151,6 +153,8 @@ def main(args):
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384) pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1: if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset, pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None), shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler, sampler=pretrain_sampler,