mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt] optimize generation kwargs (#2717)
* [chatgpt] ppo trainer use default generate args
* [chatgpt] example remove generation preparing fn
* [chatgpt] benchmark remove generation preparing fn
* [chatgpt] fix ci
parent 21d6a48f4d
commit 9c0943ecdb
@@ -34,6 +34,7 @@ jobs:
      - name: Execute Examples
        run: |
          cd applications/ChatGPT
          ./examples/test_ci.sh
        env:
          NCCL_SHM_DISABLE: 1

@@ -35,6 +35,7 @@ jobs:
      - name: Execute Unit Testing
        run: |
          cd applications/ChatGPT
          pytest tests/
        env:
          NCCL_SHM_DISABLE: 1

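Aside, not part of this commit: NCCL_SHM_DISABLE=1 disables NCCL's shared-memory transport, a common workaround when CI containers mount only a small /dev/shm. The same switch can also be set from a Python launcher before torch.distributed creates its first NCCL communicator; a minimal sketch:

import os

# Read by NCCL at communicator creation time, so it must be set before
# torch.distributed.init_process_group(backend='nccl') runs.
os.environ.setdefault('NCCL_SHM_DISABLE', '1')
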
@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import GPTActor, GPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy

@@ -151,8 +150,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])

     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())

@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy

@@ -144,8 +143,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=opt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])

     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
 import torch.nn as nn
 from chatgpt.experience_maker import Experience, NaiveExperienceMaker
 from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
+from chatgpt.nn.generation_utils import update_model_kwargs_fn
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from torch.optim import Optimizer

@@ -59,6 +60,7 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
+        self._set_default_generate_kwargs(generate_kwargs, actor)
         actor = Actor(strategy.setup_model(actor.model))
         critic = strategy.setup_model(critic)
         reward_model = strategy.setup_model(reward_model)

@@ -102,3 +104,11 @@ class PPOTrainer(Trainer):
         self.critic_optim.zero_grad()

         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+
+    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        # use huggingface models method directly
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+
+        if 'update_model_kwargs_fn' not in generate_kwargs:
+            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

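For context, a minimal self-contained sketch of what the new _set_default_generate_kwargs fallback buys callers, assuming a HuggingFace GPT-2 backbone with transformers installed (this snippet is illustrative, not part of the commit):

from transformers import GPT2LMHeadModel

from chatgpt.nn.generation_utils import update_model_kwargs_fn

model = GPT2LMHeadModel.from_pretrained('gpt2')
generate_kwargs = {'max_length': 128, 'do_sample': True, 'top_k': 50}

# Same resolution the trainer now performs internally: fill in only the keys
# the caller left out, preferring the HuggingFace-provided method.
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(model, 'prepare_inputs_for_generation'):
    generate_kwargs['prepare_inputs_fn'] = model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs:
    generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn

print(sorted(generate_kwargs))  # both *_fn entries are now present

Because only missing keys are filled in, a prepare_inputs_fn passed explicitly, as the example scripts did before this change, still takes precedence.
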
@@ -3,12 +3,6 @@ from copy import deepcopy

 import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import (
-    bloom_prepare_inputs_fn,
-    gpt_prepare_inputs_fn,
-    opt_prepare_inputs_fn,
-    update_model_kwargs_fn,
-)
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam

@@ -66,36 +60,33 @@ def main(args):
     if args.model == 'gpt2':
         tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
         tokenizer.pad_token = tokenizer.eos_token
-        prepare_inputs_fn = gpt_prepare_inputs_fn
     elif args.model == 'bloom':
         tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-        prepare_inputs_fn = bloom_prepare_inputs_fn
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-        prepare_inputs_fn = opt_prepare_inputs_fn
     else:
         raise ValueError(f'Unsupported model "{args.model}"')

     # configure trainer
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=preprocess_batch,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn)
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        tokenizer=preprocess_batch,
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )

     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
     trainer.fit(random_prompts,

@@ -3,7 +3,6 @@ from copy import deepcopy

 import pandas as pd
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam

@@ -70,24 +69,24 @@ def main(args):
         return {k: v.cuda() for k, v in batch.items()}

     # configure trainer
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=tokenize_fn,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn)
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        tokenizer=tokenize_fn,
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )

     trainer.fit(dataset,
                 num_episodes=args.num_episodes,