mirror of https://github.com/hpcaitech/ColossalAI
[chat] refactor actor class (#3968)
* refactor: separate log_probs fn from Actor forward fn * refactor: separate generate fn from Actor class * feat: update unwrap_model and get_base_model * unwrap_model returns model not wrapped by Strategy * get_base_model returns HF model for Actor, Critic and RewardModel * feat: simplify Strategy.prepare * style: remove get_base_model method of Actor * perf: tokenize text in batches * refactor: move calc_action_log_probs to utils of model * test: update test with new forward fn * style: rename forward fn args * fix: do not unwrap model in save_model fn of naive strategy * test: add gemini test for train_prompts * fix: fix _set_default_generate_kwargspull/3970/head
parent
b3ab7fbabf
commit
9d02590c9a
|
@ -35,14 +35,14 @@ class PromptDataset(Dataset):
|
|||
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
for data_dict in list_data_dict:
|
||||
token = tokenizer(data_dict["instruction"],
|
||||
return_tensors='pt',
|
||||
max_length=max_length,
|
||||
padding='max_length',
|
||||
truncation=True)
|
||||
for k, tensor in token.items():
|
||||
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
|
||||
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
|
||||
tokens = tokenizer(instructions,
|
||||
return_tensors='pt',
|
||||
max_length=max_length,
|
||||
padding='max_length',
|
||||
truncation=True)
|
||||
for k, tensor in tokens.items():
|
||||
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keyed_prompt["input_ids"])
|
||||
|
|
|
@ -74,21 +74,18 @@ class SFTDataset(Dataset):
|
|||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
|
||||
|
||||
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
|
||||
def _tokenize_fn(strings: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
max_length: int
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Tokenize a list of strings."""
|
||||
tokenized_list = [
|
||||
tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
) for text in strings
|
||||
]
|
||||
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
||||
input_ids_lens = labels_lens = [
|
||||
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
||||
]
|
||||
tokenized_list = tokenizer(
|
||||
strings, return_tensors="pt", padding="longest",
|
||||
max_length=max_length, truncation=True
|
||||
)
|
||||
input_ids = labels = tokenized_list["input_ids"]
|
||||
input_ids_lens = labels_lens = \
|
||||
tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
|
@ -105,7 +102,10 @@ def preprocess(
|
|||
) -> Dict:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
examples = [s + t for s, t in zip(sources, targets)]
|
||||
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
|
||||
examples_tokenized, sources_tokenized = [
|
||||
_tokenize_fn(strings, tokenizer, max_length)
|
||||
for strings in (examples, sources)
|
||||
]
|
||||
input_ids = examples_tokenized["input_ids"]
|
||||
labels = copy.deepcopy(input_ids)
|
||||
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from coati.models.utils import compute_reward, normalize
|
||||
from coati.models.generation import generate_with_actor
|
||||
from coati.models.utils import calc_action_log_probs, compute_reward, normalize
|
||||
|
||||
from .base import Experience, ExperienceMaker
|
||||
|
||||
|
@ -16,13 +17,16 @@ class NaiveExperienceMaker(ExperienceMaker):
|
|||
self.initial_model.eval()
|
||||
self.reward_model.eval()
|
||||
|
||||
sequences, attention_mask, action_mask = self.actor.generate(input_ids,
|
||||
sequences, attention_mask, action_mask = generate_with_actor(self.actor,
|
||||
input_ids,
|
||||
return_action_mask=True,
|
||||
**generate_kwargs)
|
||||
num_actions = action_mask.size(1)
|
||||
|
||||
action_log_probs = self.actor(sequences, num_actions, attention_mask)
|
||||
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
|
||||
actor_output = self.actor(sequences, attention_mask)
|
||||
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
|
||||
base_model_output = self.initial_model(sequences, attention_mask)
|
||||
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
|
||||
value = self.critic(sequences, action_mask, attention_mask)
|
||||
r = self.reward_model(sequences, attention_mask)
|
||||
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .actor import Actor
|
||||
|
@ -5,10 +7,10 @@ from .critic import Critic
|
|||
from .reward_model import RewardModel
|
||||
|
||||
|
||||
def get_base_model(model: nn.Module) -> nn.Module:
|
||||
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
|
||||
"""Get the base model of our wrapper classes.
|
||||
For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
|
||||
For Critic and RewardModel, it's base model is itself.
|
||||
For Actor, Critic and RewardModel, return ``model.model``,
|
||||
it's usually a ``transformers.PreTrainedModel``.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to get base model from
|
||||
|
@ -16,9 +18,9 @@ def get_base_model(model: nn.Module) -> nn.Module:
|
|||
Returns:
|
||||
nn.Module: the base model
|
||||
"""
|
||||
if isinstance(model, Actor):
|
||||
return model.get_base_model()
|
||||
return model
|
||||
assert isinstance(model, (Actor, Critic, RewardModel)), \
|
||||
f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
|
||||
return model.model
|
||||
|
||||
|
||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..generation import generate
|
||||
from ..lora import LoRAModule
|
||||
from ..utils import log_probs_from_logits
|
||||
|
||||
|
||||
class Actor(LoRAModule):
|
||||
|
@ -24,42 +21,16 @@ class Actor(LoRAModule):
|
|||
self.model = model
|
||||
self.convert_to_lora()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
sequences = generate(self.model, input_ids, **kwargs)
|
||||
attention_mask = None
|
||||
pad_token_id = kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
||||
|
||||
def forward(self,
|
||||
sequences: torch.LongTensor,
|
||||
num_actions: int,
|
||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Returns action log probs
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**model_kwargs, # HACK: `generate` method may pass more kwargs
|
||||
) -> torch.Tensor:
|
||||
"""Returns model output.
|
||||
"""
|
||||
output = self.model(sequences, attention_mask=attention_mask)
|
||||
logits = output['logits']
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
def get_base_model(self):
|
||||
return self.model
|
||||
output = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
**model_kwargs
|
||||
)
|
||||
return output
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
|
@ -55,9 +57,8 @@ def sample(model: nn.Module,
|
|||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(input_ids.size(1), max_length):
|
||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
|
||||
'input_ids': input_ids
|
||||
}
|
||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
|
||||
if prepare_inputs_fn is not None else {'input_ids': input_ids}
|
||||
outputs = model(**model_inputs)
|
||||
|
||||
next_token_logits = outputs['logits'][:, -1, :]
|
||||
|
@ -144,3 +145,35 @@ def generate(model: nn.Module,
|
|||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("Unsupported generation mode")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_with_actor(actor_model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
|
||||
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
"""Generate token sequence with actor model. Refer to `generate` for more details.
|
||||
"""
|
||||
# generate sequences
|
||||
sequences = generate(actor_model, input_ids, **kwargs)
|
||||
|
||||
# calculate auxiliary tensors
|
||||
attention_mask = None
|
||||
pad_token_id = kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
||||
|
|
|
@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
|
|||
return log_probs_labels.squeeze(-1)
|
||||
|
||||
|
||||
def calc_action_log_probs(output: torch.Tensor,
|
||||
sequences: torch.LongTensor,
|
||||
num_actions: int
|
||||
) -> torch.Tensor:
|
||||
"""Calculate action log probs.
|
||||
|
||||
Args:
|
||||
output (torch.Tensor): Output tensor of Actor.forward.
|
||||
sequences (torch.LongTensor): Input sequences.
|
||||
num_actions (int): Number of actions.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Action log probs.
|
||||
"""
|
||||
logits = output['logits']
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
|
||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
tensor = tensor.sum(dim=dim)
|
||||
|
|
|
@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic
|
||||
from coati.models.base import Actor, Critic, get_base_model
|
||||
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
|
||||
from coati.models.utils import calc_action_log_probs
|
||||
from coati.replay_buffer import NaiveReplayBuffer
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
@ -165,7 +166,8 @@ class PPOTrainer(Trainer):
|
|||
self.critic.train()
|
||||
# policy loss
|
||||
num_actions = experience.action_mask.size(1)
|
||||
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
|
||||
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
|
||||
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
|
||||
actor_loss = self.actor_loss_fn(action_log_probs,
|
||||
experience.action_log_probs,
|
||||
experience.advantages,
|
||||
|
@ -175,8 +177,8 @@ class PPOTrainer(Trainer):
|
|||
if self.ptx_coef != 0:
|
||||
batch = next(iter(self.pretrain_dataloader))
|
||||
batch = to_device(batch, self.device)
|
||||
ptx_log_probs = self.actor.get_base_model()(batch['input_ids'],
|
||||
attention_mask=batch['attention_mask'])['logits']
|
||||
ptx_log_probs = self.actor(batch['input_ids'],
|
||||
attention_mask=batch['attention_mask'])['logits']
|
||||
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
|
||||
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
|
||||
|
||||
|
@ -200,14 +202,15 @@ class PPOTrainer(Trainer):
|
|||
return {'reward': experience.reward.mean().item()}
|
||||
|
||||
|
||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = strategy.unwrap_model(actor)
|
||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
|
||||
unwrapper_model = strategy.unwrap_model(actor)
|
||||
hf_model = get_base_model(unwrapper_model)
|
||||
new_kwargs = {**generate_kwargs}
|
||||
# use huggingface models method directly
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
||||
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
|
||||
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
|
||||
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
|
||||
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
|
||||
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
|
||||
|
||||
return new_kwargs
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.models.base import Actor, get_base_model
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -69,21 +68,16 @@ class Strategy(ABC):
|
|||
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
|
||||
"""
|
||||
|
||||
def prepare_model(model: nn.Module):
|
||||
if isinstance(model, Actor):
|
||||
return Actor(self.setup_model(model.get_base_model()))
|
||||
return self.setup_model(model)
|
||||
|
||||
rets = []
|
||||
for arg in models_or_model_optim_pairs:
|
||||
if isinstance(arg, tuple):
|
||||
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
|
||||
model, optimizer = arg
|
||||
model = prepare_model(model)
|
||||
optimizer = self.setup_optimizer(optimizer, get_base_model(model))
|
||||
model = self.setup_model(model)
|
||||
optimizer = self.setup_optimizer(optimizer, model)
|
||||
rets.append((model, optimizer))
|
||||
elif isinstance(arg, nn.Module):
|
||||
rets.append(prepare_model(arg))
|
||||
rets.append(self.setup_model(model))
|
||||
else:
|
||||
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
|
||||
|
||||
|
@ -93,16 +87,15 @@ class Strategy(ABC):
|
|||
|
||||
@staticmethod
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||
"""Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
|
||||
For Actor, it will unwrap `actor.model`.
|
||||
"""Get the unwrapped model from a wrapped model made by Strategy.prepare.
|
||||
|
||||
Args:
|
||||
model (nn.Module): the model to unwrap
|
||||
|
||||
Returns:
|
||||
nn.Module: the original model (usually a huggingface model)
|
||||
nn.Module: the original model
|
||||
"""
|
||||
return get_base_model(model)
|
||||
return model
|
||||
|
||||
@abstractmethod
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
|
||||
|
@ -133,4 +126,4 @@ class Strategy(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
pass
|
||||
pass
|
||||
|
|
|
@ -5,7 +5,6 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.models.base import get_base_model
|
||||
from torch.optim import Optimizer
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
|
@ -153,14 +152,13 @@ class ColossalAIStrategy(DDPStrategy):
|
|||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
|
||||
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
|
||||
return
|
||||
base_model = get_base_model(model)
|
||||
if self.stage == 3:
|
||||
assert isinstance(base_model, ZeroDDP)
|
||||
assert isinstance(model, ZeroDDP)
|
||||
# for stage 3, state_dict() method should be called on every rank
|
||||
state_dict = base_model.state_dict(only_rank_0=only_rank0)
|
||||
state_dict = model.state_dict(only_rank_0=only_rank0)
|
||||
else:
|
||||
# only_rank0 is false or rank == 0
|
||||
state_dict = base_model.state_dict()
|
||||
state_dict = model.state_dict()
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return
|
||||
torch.save(state_dict, path)
|
||||
|
@ -172,11 +170,10 @@ class ColossalAIStrategy(DDPStrategy):
|
|||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
|
||||
if self.stage == 3:
|
||||
assert isinstance(base_model, ZeroDDP)
|
||||
return base_model.module
|
||||
return base_model
|
||||
assert isinstance(model, ZeroDDP)
|
||||
return model.module
|
||||
return model
|
||||
|
||||
def save_pretrained(self,
|
||||
model: nn.Module,
|
||||
|
@ -196,5 +193,5 @@ class ColossalAIStrategy(DDPStrategy):
|
|||
# if isinstance(module, LoraLinear):
|
||||
# module.merge_weights = True
|
||||
# module.eval()
|
||||
base_model: ZeroDDP = get_base_model(model)
|
||||
yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
|
||||
assert isinstance(model, ZeroDDP)
|
||||
yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
|
||||
|
|
|
@ -69,8 +69,8 @@ class DDPStrategy(NaiveStrategy):
|
|||
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
base_model: DDP = super().unwrap_model(model)
|
||||
return base_model.module
|
||||
assert isinstance(model, DDP)
|
||||
return model.module
|
||||
|
||||
def save_pretrained(self,
|
||||
model: nn.Module,
|
||||
|
|
|
@ -58,14 +58,13 @@ class NaiveStrategy(Strategy):
|
|||
collate_fn=replay_buffer.collate_fn)
|
||||
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
|
||||
base_model = get_base_model(model)
|
||||
state_dict = base_model.state_dict()
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
|
||||
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
|
||||
base_model = get_base_model(model)
|
||||
unwrapped_model = self.unwrap_model(model)
|
||||
state_dict = torch.load(path, map_location=map_location)
|
||||
base_model.load_state_dict(state_dict, strict=strict)
|
||||
unwrapped_model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
|
|
@ -121,6 +121,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
|
|||
--rm_pretrain 'gpt2' \
|
||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
|
||||
--pretrain 'gpt2' --model gpt2 \
|
||||
--rm_pretrain 'gpt2' \
|
||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
rm -rf ${BASE}/rm_ckpt_gpt.pt
|
||||
|
||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.utils import calc_action_log_probs
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
|
||||
|
@ -43,7 +44,8 @@ def run_test_checkpoint(strategy):
|
|||
def run_step():
|
||||
data = get_data(BATCH_SIZE)
|
||||
action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
|
||||
action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
|
||||
actor_output = actor(data['input_ids'], data['attention_mask'])
|
||||
action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
|
||||
loss = action_log_probs.sum()
|
||||
strategy.backward(loss, actor, actor_optim)
|
||||
strategy.optimizer_step(actor_optim)
|
||||
|
|
Loading…
Reference in New Issue