mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] fix ppo training hanging problem with gemini (#3162)
* [chatgpt] fix generation early stopping
* [chatgpt] fix train prompts example

pull/3166/head
parent 6ae8ed0407
commit c474fda282

@@ -1,6 +1,7 @@
 from typing import Any, Callable, Optional

 import torch
+import torch.distributed as dist
 import torch.nn as nn

 try:
@@ -27,6 +28,14 @@ def prepare_logits_processor(top_k: Optional[int] = None,
     return processor_list


+def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        # consider DP
+        unfinished_sequences = unfinished_sequences.clone()
+        dist.all_reduce(unfinished_sequences)
+    return unfinished_sequences.max() == 0
+
+
 def sample(model: nn.Module,
            input_ids: torch.Tensor,
            max_length: int,
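The new `_is_sequence_finished` helper sums the per-rank `unfinished_sequences` flags with an all-reduce, so every data-parallel rank gets the same answer to "has everyone finished?". Below is a minimal standalone sketch of the same pattern; the `gloo` backend, the two-process launch, and the example flag values are assumptions for the demo, not part of the patch.

import torch
import torch.distributed as dist


def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    # Same helper as in the patch: sum the flags across ranks so the decision matches everywhere.
    if dist.is_initialized() and dist.get_world_size() > 1:
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)  # default reduce op is SUM
    return unfinished_sequences.max() == 0


if __name__ == '__main__':
    # Hypothetical demo: launch with `torchrun --nproc_per_node=2 demo.py`.
    dist.init_process_group('gloo')
    rank = dist.get_rank()
    # Rank 0 has finished all of its sequences; rank 1 still has one running.
    flags = torch.tensor([0, 0]) if rank == 0 else torch.tensor([1, 0])
    # With a purely local check, rank 0 would stop generating while rank 1 kept
    # calling the model, and the model's collective ops would wait forever.
    print(f'rank {rank}: finished on all ranks = {bool(_is_sequence_finished(flags))}')
    dist.destroy_process_group()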
@@ -74,7 +83,7 @@ def sample(model: nn.Module,
         unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

         # stop when each sentence is finished if early_stopping=True
-        if early_stopping and unfinished_sequences.max() == 0:
+        if early_stopping and _is_sequence_finished(unfinished_sequences):
             break

     return input_ids
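With the early-stopping check routed through the helper, all ranks leave the sampling loop on the same step. The toy sketch below shows that loop shape; the random "model" stand-in, the flag bookkeeping names, and the example call are illustrative assumptions, not the real `sample` implementation.

import torch
import torch.distributed as dist


def _all_finished(unfinished_sequences: torch.Tensor) -> bool:
    # Mirrors the patch's `_is_sequence_finished`: agree across data-parallel ranks.
    if dist.is_initialized() and dist.get_world_size() > 1:
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return bool(unfinished_sequences.max() == 0)


def sample_sketch(input_ids: torch.Tensor, max_length: int, eos_token_id: int,
                  early_stopping: bool = True) -> torch.Tensor:
    # One flag per sequence: 1 = still generating, 0 = already emitted EOS.
    unfinished_sequences = input_ids.new_ones(input_ids.size(0))
    while input_ids.size(1) < max_length:
        # Stand-in for the model forward pass + token sampling.
        next_tokens = torch.randint(0, 8, (input_ids.size(0),), device=input_ids.device)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
        # With the old local check, a rank whose batch finished first would break out
        # of this loop while other ranks still ran the sharded model.
        if early_stopping and _all_finished(unfinished_sequences):
            break
    return input_ids


# Example: sample_sketch(torch.zeros(2, 1, dtype=torch.long), max_length=16, eos_token_id=2)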
@@ -46,7 +46,6 @@ def main(args):
     initial_model = deepcopy(actor)
     reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
-

     # configure optimizer
     if args.strategy.startswith('colossalai'):
         actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
@@ -70,7 +69,9 @@ def main(args):
     dataset = pd.read_csv(args.prompt_path)['prompt']

     def tokenize_fn(texts):
-        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
+        # MUST padding to max length to ensure inputs of all ranks have the same length
+        # Different length may lead to hang when using gemini, as different generation steps
+        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
         return {k: v.cuda() for k, v in batch.items()}

     (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
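The `tokenize_fn` change in the train prompts example swaps dynamic padding (`padding=True`, pad to the longest prompt in the local batch) for fixed-length padding (`padding='max_length'`). Each data-parallel rank tokenizes its own prompts, so dynamic padding can give ranks different input lengths; they then run different numbers of generation steps and the collective communication Gemini performs during the forward pass can fall out of sync and hang. A small shape illustration follows; the 'gpt2' tokenizer and the sample prompts are assumptions for the demo.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

rank0_texts = ['short prompt']
rank1_texts = ['a noticeably longer prompt that yields more tokens than the other rank sees']

# padding=True pads only to the longest sequence within each local batch,
# so the two ranks end up with different input_ids lengths.
a = tokenizer(rank0_texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
b = tokenizer(rank1_texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
print(a['input_ids'].shape, b['input_ids'].shape)  # different second dimensions

# padding='max_length' pads every batch to the same fixed length (96), so generation
# starts from identical shapes and takes the same number of steps on every rank.
a = tokenizer(rank0_texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
b = tokenizer(rank1_texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
assert a['input_ids'].shape == b['input_ids'].shape == (1, 96)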