[chatgpt] optimize generation kwargs (#2717)

* [chatgpt] ppo trainer use default generate args

* [chatgpt] example remove generation preparing fn

* [chatgpt] benchmark remove generation preparing fn

* [chatgpt] fix ci
ver217 authored on 2023-02-15 13:59:58 +08:00, committed by GitHub
parent 21d6a48f4d
commit 9c0943ecdb
7 changed files with 48 additions and 52 deletions
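
The gist of the change: PPOTrainer now fills in the two generation hooks (prepare_inputs_fn and update_model_kwargs_fn) itself, so the example and benchmark scripts no longer import model-specific helpers from chatgpt.nn.generation_utils. The sketch below is a minimal, standalone illustration of that defaulting behaviour, not the repo's code; it assumes only that transformers is installed, and the helper name set_default_generate_kwargs is made up for the example.

from transformers import GPT2LMHeadModel

def set_default_generate_kwargs(generate_kwargs: dict, model) -> dict:
    # Same idea as the new PPOTrainer._set_default_generate_kwargs: when the
    # caller did not pass prepare_inputs_fn, reuse the Hugging Face model's
    # own prepare_inputs_for_generation method.
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(model, 'prepare_inputs_for_generation'):
        generate_kwargs['prepare_inputs_fn'] = model.prepare_inputs_for_generation
    return generate_kwargs

model = GPT2LMHeadModel.from_pretrained('gpt2')
kwargs = set_default_generate_kwargs({'do_sample': True, 'top_k': 50}, model)
print(kwargs['prepare_inputs_fn'])  # a bound method of the GPT-2 model

Because the fallback is keyed on hasattr, any Hugging Face causal LM wrapped by an Actor gets a working default, while callers can still override both hooks explicitly.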

Changed file 1 of 7: CI workflow (example tests)

@@ -34,6 +34,7 @@ jobs:
       - name: Execute Examples
         run: |
+          cd applications/ChatGPT
           ./examples/test_ci.sh
         env:
           NCCL_SHM_DISABLE: 1

Changed file 2 of 7: CI workflow (unit tests)

@@ -35,6 +35,7 @@ jobs:
       - name: Execute Unit Testing
         run: |
+          cd applications/ChatGPT
           pytest tests/
         env:
           NCCL_SHM_DISABLE: 1

Changed file 3 of 7: GPT benchmark script

@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import GPTActor, GPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -151,8 +150,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())

Changed file 4 of 7: OPT benchmark script

@@ -5,7 +5,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from chatgpt.nn import OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import opt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.callbacks import PerformanceEvaluator
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
@@ -144,8 +143,6 @@ def main(args):
                          top_k=50,
                          pad_token_id=tokenizer.pad_token_id,
                          eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=opt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn,
                          callbacks=[performance_evaluator])
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
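
Dropping gpt_prepare_inputs_fn and opt_prepare_inputs_fn from the benchmarks is possible because the underlying Hugging Face classes already expose an equivalent hook, which PPOTrainer now picks up by default (see the trainer change below). A quick check, assuming only that transformers is installed:

from transformers import BloomForCausalLM, GPT2LMHeadModel, OPTForCausalLM

# Each class exposes prepare_inputs_for_generation, either defined on the class
# itself or inherited from GenerationMixin, so the new default applies to every
# model family used in these scripts.
for cls in (GPT2LMHeadModel, OPTForCausalLM, BloomForCausalLM):
    print(cls.__name__, hasattr(cls, 'prepare_inputs_for_generation'))  # True for each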

Changed file 5 of 7: PPOTrainer

@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
 import torch.nn as nn
 from chatgpt.experience_maker import Experience, NaiveExperienceMaker
 from chatgpt.nn import Actor, Critic, PolicyLoss, ValueLoss
+from chatgpt.nn.generation_utils import update_model_kwargs_fn
 from chatgpt.replay_buffer import NaiveReplayBuffer
 from torch.optim import Optimizer
@@ -59,6 +60,7 @@ class PPOTrainer(Trainer):
                  dataloader_pin_memory: bool = True,
                  callbacks: List[Callback] = [],
                  **generate_kwargs) -> None:
+        self._set_default_generate_kwargs(generate_kwargs, actor)
         actor = Actor(strategy.setup_model(actor.model))
         critic = strategy.setup_model(critic)
         reward_model = strategy.setup_model(reward_model)
@@ -102,3 +104,11 @@ class PPOTrainer(Trainer):
         self.critic_optim.zero_grad()
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+
+    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
+        # use huggingface models method directly
+        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(actor.model, 'prepare_inputs_for_generation'):
+            generate_kwargs['prepare_inputs_fn'] = actor.model.prepare_inputs_for_generation
+
+        if 'update_model_kwargs_fn' not in generate_kwargs:
+            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
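
For context, here is a hedged sketch of how a sampling loop typically consumes the two hooks that _set_default_generate_kwargs fills in. It is not the repo's chatgpt.nn generation code, only an illustration of why a Hugging Face model's prepare_inputs_for_generation plus a generic update_model_kwargs_fn are enough to drive cached, step-by-step decoding; it assumes torch and a transformers release from around the time of this commit (early 2023).

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

def update_model_kwargs_fn(outputs, model_kwargs):
    # Illustrative default: carry the KV cache forward and grow the attention mask.
    model_kwargs['past_key_values'] = outputs.past_key_values
    mask = model_kwargs['attention_mask']
    model_kwargs['attention_mask'] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
    return model_kwargs

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

prepare_inputs_fn = model.prepare_inputs_for_generation  # the default this commit picks up
input_ids = tokenizer('PPO fine-tuning is', return_tensors='pt').input_ids
model_kwargs = {'attention_mask': torch.ones_like(input_ids)}

with torch.no_grad():
    for _ in range(8):  # greedy-decode a few tokens
        model_inputs = prepare_inputs_fn(input_ids, **model_kwargs)
        outputs = model(**model_inputs)
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

print(tokenizer.decode(input_ids[0]))

The trainer forwards these hooks to the actor's generation step together with the sampling arguments (do_sample, top_k, temperature, pad/eos token ids) seen in the examples below.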

Changed file 6 of 7: example training script (random dummy prompts)

@@ -3,12 +3,6 @@ from copy import deepcopy
 import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import (
-    bloom_prepare_inputs_fn,
-    gpt_prepare_inputs_fn,
-    opt_prepare_inputs_fn,
-    update_model_kwargs_fn,
-)
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
@@ -66,36 +60,33 @@ def main(args):
     if args.model == 'gpt2':
         tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
         tokenizer.pad_token = tokenizer.eos_token
-        prepare_inputs_fn = gpt_prepare_inputs_fn
     elif args.model == 'bloom':
         tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
         tokenizer.pad_token = tokenizer.eos_token
-        prepare_inputs_fn = bloom_prepare_inputs_fn
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-        prepare_inputs_fn = opt_prepare_inputs_fn
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
     # configure trainer
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=preprocess_batch,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn)
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        tokenizer=preprocess_batch,
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
     random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
     trainer.fit(random_prompts,

Changed file 7 of 7: example training script (prompts dataset)

@@ -3,7 +3,6 @@ from copy import deepcopy
 import pandas as pd
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
-from chatgpt.nn.generation_utils import gpt_prepare_inputs_fn, update_model_kwargs_fn
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from torch.optim import Adam
@@ -70,24 +69,24 @@ def main(args):
         return {k: v.cuda() for k, v in batch.items()}
     # configure trainer
-    trainer = PPOTrainer(strategy,
-                         actor,
-                         critic,
-                         reward_model,
-                         initial_model,
-                         actor_optim,
-                         critic_optim,
-                         max_epochs=args.max_epochs,
-                         train_batch_size=args.train_batch_size,
-                         tokenizer=tokenize_fn,
-                         max_length=128,
-                         do_sample=True,
-                         temperature=1.0,
-                         top_k=50,
-                         pad_token_id=tokenizer.pad_token_id,
-                         eos_token_id=tokenizer.eos_token_id,
-                         prepare_inputs_fn=gpt_prepare_inputs_fn,
-                         update_model_kwargs_fn=update_model_kwargs_fn)
+    trainer = PPOTrainer(
+        strategy,
+        actor,
+        critic,
+        reward_model,
+        initial_model,
+        actor_optim,
+        critic_optim,
+        max_epochs=args.max_epochs,
+        train_batch_size=args.train_batch_size,
+        tokenizer=tokenize_fn,
+        max_length=128,
+        do_sample=True,
+        temperature=1.0,
+        top_k=50,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
     trainer.fit(dataset,
                 num_episodes=args.num_episodes,