[chat] refactor actor class (#3968)

* refactor: separate log_probs fn from Actor forward fn

* refactor: separate generate fn from Actor class

* feat: update unwrap_model and get_base_model
* unwrap_model returns model not wrapped by Strategy
* get_base_model returns HF model for Actor, Critic and RewardModel

* feat: simplify Strategy.prepare

* style: remove get_base_model method of Actor

* perf: tokenize text in batches

* refactor: move calc_action_log_probs to utils of model

* test: update test with new forward fn

* style: rename forward fn args

* fix: do not unwrap model in save_model fn of naive strategy

* test: add gemini test for train_prompts

* fix: fix _set_default_generate_kwargs
pull/3970/head
Wenhao Chen 2023-06-13 13:31:56 +08:00 committed by GitHub
parent b3ab7fbabf
commit 9d02590c9a
14 changed files with 151 additions and 120 deletions

View File

@@ -35,14 +35,14 @@ class PromptDataset(Dataset):
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
list_data_dict = list_data_dict[:max_datasets_size]
for data_dict in list_data_dict:
token = tokenizer(data_dict["instruction"],
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in token.items():
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
def __len__(self):
return len(self.keyed_prompt["input_ids"])
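
The change above replaces one tokenizer call per example with a single call over the whole instruction list. The following is a minimal sketch of the same per-example vs. batched pattern with a stock Hugging Face tokenizer; the "gpt2" model name and the sample instructions are illustrative assumptions, not part of the diff.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # gpt2 has no pad token by default

instructions = ["Write a haiku about GPUs.", "Summarize PPO in one sentence."]

# Old pattern: one tokenizer call per instruction.
per_example = [
    tokenizer(text, return_tensors="pt", max_length=16,
              padding="max_length", truncation=True)
    for text in instructions
]

# New pattern: a single call over the list returns stacked (batch, max_length)
# tensors; `.unbind()` splits them back into per-example rows.
batched = tokenizer(instructions, return_tensors="pt", max_length=16,
                    padding="max_length", truncation=True)
rows = batched["input_ids"].unbind()
assert len(rows) == len(instructions) and rows[0].shape == (16,)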

View File

@@ -74,21 +74,18 @@ class SFTDataset(Dataset):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
max_length: int
) -> Dict[str, torch.Tensor]:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
tokenized_list = tokenizer(
strings, return_tensors="pt", padding="longest",
max_length=max_length, truncation=True
)
input_ids = labels = tokenized_list["input_ids"]
input_ids_lens = labels_lens = \
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
return dict(
input_ids=input_ids,
labels=labels,
@@ -105,7 +102,10 @@ def preprocess(
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
examples_tokenized, sources_tokenized = [
_tokenize_fn(strings, tokenizer, max_length)
for strings in (examples, sources)
]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
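
The rewritten `_tokenize_fn` also vectorizes the sequence-length computation: instead of calling `.ne(pad_token_id).sum().item()` once per example, it sums over the last dimension of the batched `input_ids`. A toy check of that expression; the pad id and token values below are made up for illustration.

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])
lens = input_ids.ne(pad_token_id).sum(dim=-1)
print(lens)   # tensor([3, 2]) -> number of non-pad tokens per example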

View File

@@ -1,5 +1,6 @@
import torch
from coati.models.utils import compute_reward, normalize
from coati.models.generation import generate_with_actor
from coati.models.utils import calc_action_log_probs, compute_reward, normalize
from .base import Experience, ExperienceMaker
@@ -16,13 +17,16 @@ class NaiveExperienceMaker(ExperienceMaker):
self.initial_model.eval()
self.reward_model.eval()
sequences, attention_mask, action_mask = self.actor.generate(input_ids,
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
input_ids,
return_action_mask=True,
**generate_kwargs)
num_actions = action_mask.size(1)
action_log_probs = self.actor(sequences, num_actions, attention_mask)
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
actor_output = self.actor(sequences, attention_mask)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask)
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
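
The hunk above shows the new two-step pattern: the actor (and initial model) forward pass now returns the raw model output, and the action log probs come from the `calc_action_log_probs` utility. Below is a self-contained sketch of that call shape, assuming `coati` is importable; `ToyActor` is a stand-in that only mimics the Actor contract (a forward returning a dict with 'logits'), not the real class.

import torch
from coati.models.utils import calc_action_log_probs

class ToyActor(torch.nn.Module):
    # Stand-in for coati's Actor: forward returns an output with a 'logits' entry.
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)

    def forward(self, input_ids, attention_mask=None, **model_kwargs):
        return {'logits': self.head(self.embed(input_ids))}

actor = ToyActor()
sequences = torch.randint(0, 16, (2, 10))   # (batch, prompt + generated tokens)
num_actions = 4                             # last 4 tokens count as actions

# Old: action_log_probs = actor(sequences, num_actions, attention_mask)
# New: forward returns the raw output, log probs come from the utility.
actor_output = actor(sequences)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
assert action_log_probs.shape == (2, num_actions)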

View File

@@ -1,3 +1,5 @@
from typing import Union
import torch.nn as nn
from .actor import Actor
@@ -5,10 +7,10 @@ from .critic import Critic
from .reward_model import RewardModel
def get_base_model(model: nn.Module) -> nn.Module:
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
For Actor, its base model is ``actor.model``, which is usually a ``transformers.PreTrainedModel``.
For Critic and RewardModel, the base model is the module itself.
For Actor, Critic and RewardModel, return ``model.model``,
which is usually a ``transformers.PreTrainedModel``.
Args:
model (nn.Module): model to get base model from
@@ -16,9 +18,9 @@ def get_base_model(model: nn.Module) -> nn.Module:
Returns:
nn.Module: the base model
"""
if isinstance(model, Actor):
return model.get_base_model()
return model
assert isinstance(model, (Actor, Critic, RewardModel)), \
f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
return model.model
__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']

View File

@@ -1,12 +1,9 @@
from typing import Optional, Tuple, Union
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..generation import generate
from ..lora import LoRAModule
from ..utils import log_probs_from_logits
class Actor(LoRAModule):
@@ -24,42 +21,16 @@ class Actor(LoRAModule):
self.model = model
self.convert_to_lora()
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
def forward(self,
sequences: torch.LongTensor,
num_actions: int,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns action log probs
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
) -> torch.Tensor:
"""Returns model output.
"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
def get_base_model(self):
return self.model
output = self.model(
input_ids,
attention_mask=attention_mask,
**model_kwargs
)
return output

View File

@@ -1,8 +1,10 @@
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
try:
from transformers.generation_logits_process import (
@@ -55,9 +57,8 @@ def sample(model: nn.Module,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
'input_ids': input_ids
}
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
if prepare_inputs_fn is not None else {'input_ids': input_ids}
outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :]
@@ -144,3 +145,35 @@ def generate(model: nn.Module,
raise NotImplementedError
else:
raise ValueError("Unsupported generation mode")
@torch.no_grad()
def generate_with_actor(actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
sequences = generate(actor_model, input_ids, **kwargs)
# calculate auxiliary tensors
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
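
A hedged usage sketch for `generate_with_actor`. The model choice, the prompt, and the `prepare_inputs_fn`/`update_model_kwargs_fn` wiring mirror what `_set_default_generate_kwargs` in PPOTrainer pulls from the underlying Hugging Face model; none of these particular values come from the diff, and exact behavior depends on the transformers version coati targets.

from transformers import GPT2Tokenizer
from coati.models.gpt import GPTActor
from coati.models.generation import generate_with_actor

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
actor = GPTActor(pretrained='gpt2').eval()   # assumed constructor kwarg
hf_model = actor.model   # the wrapped transformers model, for its generation helpers

input_ids = tokenizer(["The quick brown fox"], return_tensors='pt')['input_ids']

sequences, attention_mask, action_mask = generate_with_actor(
    actor, input_ids,
    return_action_mask=True,
    max_length=32,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    prepare_inputs_fn=hf_model.prepare_inputs_for_generation,
    update_model_kwargs_fn=hf_model._update_model_kwargs_for_generation)

print(sequences.shape, attention_mask.shape, action_mask.shape)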

View File

@@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return log_probs_labels.squeeze(-1)
def calc_action_log_probs(output: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int
) -> torch.Tensor:
"""Calculate action log probs.
Args:
output (torch.Tensor): Output tensor of Actor.forward.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
Returns:
torch.Tensor: Action log probs.
"""
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
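
For reference, the quantity `calc_action_log_probs` returns can be reproduced with plain PyTorch: the log probability of each observed next token under the logits (the usual shift-by-one alignment), kept only for the last `num_actions` positions. A toy check with made-up shapes follows; it is an illustration of the semantics, not the library code itself.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 6, 11)            # (batch, seq_len, vocab) - made up
sequences = torch.randint(0, 11, (2, 6))  # made-up token ids
num_actions = 3                           # last 3 tokens are the "actions"

log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)            # predicts token t+1
per_token = log_probs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = per_token[:, -num_actions:]                  # shape (2, 3)
print(action_log_probs.shape)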

View File

@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic
from coati.models.base import Actor, Critic, get_base_model
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
@@ -165,7 +166,8 @@ class PPOTrainer(Trainer):
self.critic.train()
# policy loss
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
@@ -175,8 +177,8 @@
if self.ptx_coef != 0:
batch = next(iter(self.pretrain_dataloader))
batch = to_device(batch, self.device)
ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
@@ -200,14 +202,15 @@
return {'reward': experience.reward.mean().item()}
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy.unwrap_model(actor)
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
unwrapped_model = strategy.unwrap_model(actor)
hf_model = get_base_model(unwrapped_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
return new_kwargs
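
`_set_default_generate_kwargs` now goes through two explicit layers: `Strategy.unwrap_model` strips whatever wrapper the strategy added in `prepare` (DDP, ZeroDDP, or nothing), and `get_base_model` then reaches the Hugging Face model inside the Actor wrapper. A sketch of that layering; NaiveStrategy and the gpt2 actor are convenient assumptions for illustration only.

from coati.models.base import get_base_model
from coati.models.gpt import GPTActor
from coati.trainer.strategies import NaiveStrategy   # assumed to be exported here

strategy = NaiveStrategy()
actor = GPTActor(pretrained='gpt2')       # assumed constructor kwarg
actor = strategy.prepare(actor)           # NaiveStrategy adds no extra wrapper
wrapper = strategy.unwrap_model(actor)    # back to the Actor wrapper itself
hf_model = get_base_model(wrapper)        # the transformers model inside
print(type(hf_model).__name__)            # e.g. GPT2LMHeadModel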

View File

@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ class Strategy(ABC):
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
"""
def prepare_model(model: nn.Module):
if isinstance(model, Actor):
return Actor(self.setup_model(model.get_base_model()))
return self.setup_model(model)
rets = []
for arg in models_or_model_optim_pairs:
if isinstance(arg, tuple):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = prepare_model(model)
optimizer = self.setup_optimizer(optimizer, get_base_model(model))
model = self.setup_model(model)
optimizer = self.setup_optimizer(optimizer, model)
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(prepare_model(arg))
rets.append(self.setup_model(arg))
else:
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
@@ -93,16 +87,15 @@
@staticmethod
def unwrap_model(model: nn.Module) -> nn.Module:
"""Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
For Actor, it will unwrap `actor.model`.
"""Get the unwrapped model from a wrapped model made by Strategy.prepare.
Args:
model (nn.Module): the model to unwrap
Returns:
nn.Module: the original model (usually a huggingface model)
nn.Module: the original model
"""
return get_base_model(model)
return model
@abstractmethod
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
@@ -133,4 +126,4 @@ class Strategy(ABC):
@abstractmethod
def get_model_state_dict_shard(self, model: nn.Module, **config):
pass
pass

View File

@@ -5,7 +5,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.models.base import get_base_model
from torch.optim import Optimizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -153,14 +152,13 @@ class ColossalAIStrategy(DDPStrategy):
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
return
base_model = get_base_model(model)
if self.stage == 3:
assert isinstance(base_model, ZeroDDP)
assert isinstance(model, ZeroDDP)
# for stage 3, state_dict() method should be called on every rank
state_dict = base_model.state_dict(only_rank_0=only_rank0)
state_dict = model.state_dict(only_rank_0=only_rank0)
else:
# only_rank0 is false or rank == 0
state_dict = base_model.state_dict()
state_dict = model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
@@ -172,11 +170,10 @@
torch.save(optimizer.state_dict(), path)
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
if self.stage == 3:
assert isinstance(base_model, ZeroDDP)
return base_model.module
return base_model
assert isinstance(model, ZeroDDP)
return model.module
return model
def save_pretrained(self,
model: nn.Module,
@@ -196,5 +193,5 @@ class ColossalAIStrategy(DDPStrategy):
# if isinstance(module, LoraLinear):
# module.merge_weights = True
# module.eval()
base_model: ZeroDDP = get_base_model(model)
yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
assert isinstance(model, ZeroDDP)
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)

View File

@@ -69,8 +69,8 @@ class DDPStrategy(NaiveStrategy):
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: DDP = super().unwrap_model(model)
return base_model.module
assert isinstance(model, DDP)
return model.module
def save_pretrained(self,
model: nn.Module,

View File

@@ -58,14 +58,13 @@ class NaiveStrategy(Strategy):
collate_fn=replay_buffer.collate_fn)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
base_model = get_base_model(model)
state_dict = base_model.state_dict()
state_dict = model.state_dict()
torch.save(state_dict, path)
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
base_model = get_base_model(model)
unwrapped_model = self.unwrap_model(model)
state_dict = torch.load(path, map_location=map_location)
base_model.load_state_dict(state_dict, strict=strict)
unwrapped_model.load_state_dict(state_dict, strict=strict)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
torch.save(optimizer.state_dict(), path)

View File

@@ -121,6 +121,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
--rm_pretrain 'gpt2' \
--rm_path ${BASE}/rm_ckpt_gpt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
--pretrain 'gpt2' --model gpt2 \
--rm_pretrain 'gpt2' \
--rm_path ${BASE}/rm_ckpt_gpt.pt \
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_gpt.pt
rm -rf ${BASE}/actor_checkpoint_prompts.pt

View File

@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
from coati.models.gpt import GPTActor
from coati.models.utils import calc_action_log_probs
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
@@ -43,7 +44,8 @@ def run_test_checkpoint(strategy):
def run_step():
data = get_data(BATCH_SIZE)
action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
actor_output = actor(data['input_ids'], data['attention_mask'])
action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
loss = action_log_probs.sum()
strategy.backward(loss, actor, actor_optim)
strategy.optimizer_step(actor_optim)