Browse Source

[chat] fix enable single gpu training bug

pull/3621/head
zhang-yi-chi 2 years ago
parent
commit
739cfe3360
  1. 6
      applications/Chat/examples/train_prompts.py

6
applications/Chat/examples/train_prompts.py

@ -8,7 +8,7 @@ from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic, RoBERTaRM
from coati.trainer import PPOTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
@ -143,6 +143,8 @@ def main(args):
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1:
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None),
sampler=prompt_sampler,
@ -151,6 +153,8 @@ def main(args):
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,

Loading…
Cancel
Save