[chatgpt] fix ppo training hanging problem with gemini (#3162)

* [chatgpt] fix generation early stopping

* [chatgpt] fix train prompts example
Authored by ver217 on 2023-03-17 15:41:47 +08:00; committed by GitHub
parent 6ae8ed0407
commit c474fda282
2 changed files with 14 additions and 4 deletions


@@ -1,6 +1,7 @@
 from typing import Any, Callable, Optional
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 
 try:
@@ -27,6 +28,14 @@ def prepare_logits_processor(top_k: Optional[int] = None,
     return processor_list
 
 
+def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        # consider DP: aggregate the flags so all ranks agree on when to stop
+        unfinished_sequences = unfinished_sequences.clone()
+        dist.all_reduce(unfinished_sequences)
+    return unfinished_sequences.max() == 0
+
+
 def sample(model: nn.Module,
            input_ids: torch.Tensor,
            max_length: int,
@@ -74,7 +83,7 @@ def sample(model: nn.Module,
         unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
 
         # stop when each sentence is finished if early_stopping=True
-        if early_stopping and unfinished_sequences.max() == 0:
+        if early_stopping and _is_sequence_finished(unfinished_sequences):
             break
 
     return input_ids

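Why the early-stopping change matters: under the Gemini (ZeRO) strategy every model forward involves collective communication, so a data-parallel rank that breaks out of the generation loop early leaves the other ranks blocked inside a collective. The new helper all-reduces the per-rank "unfinished" flags so that every rank reaches the same stop decision. Below is a minimal standalone sketch of that idea; the gloo backend, the two-process setup, and the per-rank flag values are illustrative assumptions, not part of the commit.

# Standalone sketch (not from the commit): all-reducing the "unfinished" flags
# keeps data-parallel ranks in lockstep when deciding to stop generation.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    # Mirrors the _is_sequence_finished helper added in the diff: sum the
    # flags across ranks so every rank sees the same stopping decision.
    if dist.is_initialized() and dist.get_world_size() > 1:
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 has hit EOS on every sequence, rank 1 has not.
    flags = torch.tensor([0, 0]) if rank == 0 else torch.tensor([1, 0])
    local_stop = bool(flags.max() == 0)            # True on rank 0, False on rank 1
    global_stop = is_sequence_finished(flags)      # False on both ranks
    print(f"rank {rank}: local_stop={local_stop} global_stop={bool(global_stop)}")
    # With the old check, rank 0 would break out of the generation loop here
    # while rank 1 kept calling the model; since Gemini runs collectives inside
    # the forward pass, rank 1 would then wait forever.
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)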

@@ -46,7 +46,6 @@ def main(args):
 
     initial_model = deepcopy(actor)
     reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
-
     # configure optimizer
     if args.strategy.startswith('colossalai'):
         actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
@@ -70,7 +69,9 @@ def main(args):
     dataset = pd.read_csv(args.prompt_path)['prompt']
 
     def tokenize_fn(texts):
-        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+        # MUST pad to max_length so that inputs on all ranks have the same length
+        # Different lengths lead to different numbers of generation steps, which can hang Gemini
+        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
         return {k: v.cuda() for k, v in batch.items()}
 
     (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
@@ -101,7 +102,7 @@ def main(args):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
-    # save model checkpoint after fitting
+    # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
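
Why the padding change matters: with padding=True the tokenizer pads each batch only to the longest prompt in that batch, so different ranks can receive different input lengths and therefore run different numbers of generation steps, which again stalls Gemini's collectives. padding='max_length' gives every rank the same fixed shape. The sketch below shows the difference; the 'gpt2' tokenizer, the pad-token assignment, and the example prompts are illustrative assumptions, not taken from the training script.

# Standalone sketch (not from the commit): dynamic vs. fixed-length padding.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

rank_prompts = {
    "rank0": ["a short prompt"],
    "rank1": ["a much longer prompt that produces quite a few more tokens than the short one does"],
}

for name, texts in rank_prompts.items():
    dynamic = tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True)
    fixed = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
    print(name,
          "padding=True ->", tuple(dynamic["input_ids"].shape),
          "| padding='max_length' ->", tuple(fixed["input_ids"].shape))

# padding=True yields a different sequence length per rank; padding='max_length'
# yields (1, 96) everywhere, so all ranks take the same number of generation steps.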