mirror of https://github.com/hpcaitech/ColossalAI
[chat] refactor trainer class (#4080)
* to: add SLTrainer
* refactor: refactor RMTrainer and SFTTrainer
* fix: fix init file
* feat: remove on_learn_epoch fn as not used
* fix: align with modified gemini arguments
* to: add OnPolicyTrainer
* revert: add _on_learn_epoch fn
* refactor: refactor PPOTrainer
* style: rename PPOTrainer argument
* fix: align with modified PPO arguments
* test: align with modified train_prompts arguments
* chore: modify train_prompts
* docs: align with modified arguments
* fix: remove unnecessary output
* fix: move dataloader to fit fn of SLTrainer
* fix: move dataloader to fit fn of OnPolicyTrainer
* fix: modify usage of prompt and pretrain dataloader

parent 711e2b4c00
commit b03d64d010
@@ -83,7 +83,7 @@ More details can be found in the latest news.
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ColossalChat%20Speed.jpg" width=450/>
</p>

> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --max_timesteps 1 --update_timesteps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32
> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32

## Install
@@ -137,6 +137,12 @@ def main(args):
    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
    dataloader = DataLoader(random_prompts,
                            batch_size=args.experience_batch_size,
                            shuffle=True,
                            collate_fn=preprocess_batch)

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
@@ -145,7 +151,6 @@ def main(args):
                         actor_optim,
                         critic_optim,
                         ptx_coef=0,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         offload_inference_models=args.offload_inference_models,
                         max_length=512,
@@ -157,17 +162,11 @@ def main(args):
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=[performance_evaluator])

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
    dataloader = DataLoader(random_prompts,
                            batch_size=args.experience_batch_size,
                            shuffle=True,
                            collate_fn=preprocess_batch)

    trainer.fit(dataloader,
                None,
    trainer.fit(prompt_dataloader=dataloader,
                pretrain_dataloader=None,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)
                num_update_steps=args.num_update_steps,
                num_collect_steps=args.num_collect_steps)

    print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')

@@ -183,9 +182,8 @@ if __name__ == '__main__':
                        ],
                        default='ddp')
    parser.add_argument('--num_episodes', type=int, default=3)
    parser.add_argument('--max_timesteps', type=int, default=8)
    parser.add_argument('--update_timesteps', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--num_collect_steps', type=int, default=8)
    parser.add_argument('--num_update_steps', type=int, default=1)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0)
@@ -1,6 +1,10 @@
from .base import Trainer
from .base import OnPolicyTrainer, SLTrainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
__all__ = [
    'SLTrainer', 'OnPolicyTrainer',
    'RewardModelTrainer', 'SFTTrainer',
    'PPOTrainer'
]
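For orientation only (this line is not part of the diff), the refactored package would now expose the trainer classes under their new names; the `coati.trainer` package path is assumed from the surrounding files:

```python
# Illustrative sketch: importing the exports listed in __all__ above.
from coati.trainer import OnPolicyTrainer, PPOTrainer, RewardModelTrainer, SFTTrainer, SLTrainer
```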
@@ -1,54 +1,108 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from contextlib import contextmanager
from typing import List

import torch
import torch.nn as nn
import tqdm
from coati.experience_maker import Experience
from coati.replay_buffer import NaiveReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from .callbacks import Callback
from .strategies import Strategy
from .utils import CycledDataLoader, is_rank_0


class Trainer(ABC):
class SLTrainer(ABC):
    """
    Base class for rlhf trainers.
    Base class for supervised learning trainers.

    Args:
        strategy (Strategy):the strategy to use for training
        max_epochs (int, defaults to 1): the number of epochs of training process
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
        generate_kwargs (dict, optional): the kwargs to use while model generating
        model (nn.Module): the model to train
        optim (Optimizer): the optimizer to use for training
    """

    def __init__(self,
                 strategy: Strategy,
                 max_epochs: int = 1,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
                 max_epochs: int,
                 model: nn.Module,
                 optimizer: Optimizer,
                 ) -> None:
        super().__init__()
        self.strategy = strategy
        self.max_epochs = max_epochs
        self.generate_kwargs = generate_kwargs
        self.model = model
        self.optimizer = optimizer

    @abstractmethod
    def _train(self, epoch):
        raise NotImplementedError()

    @abstractmethod
    def _eval(self, epoch):
        raise NotImplementedError()

    def _before_fit(self):
        self.no_epoch_bar = False

    def fit(self, *args, **kwargs):
        self._before_fit(*args, **kwargs)
        for epoch in tqdm.trange(self.max_epochs,
                                 desc="Epochs",
                                 disable=not is_rank_0() or self.no_epoch_bar
                                 ):
            self._train(epoch)
            self._eval(epoch)


class OnPolicyTrainer(ABC):
    """
    Base class for on-policy rl trainers, e.g. PPO.

    Args:
        strategy (Strategy):the strategy to use for training
        buffer (NaiveReplayBuffer): the buffer to collect experiences
        sample_buffer (bool, defaults to False): whether to sample from buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
    """

    def __init__(self,
                 strategy: Strategy,
                 buffer: NaiveReplayBuffer,
                 sample_buffer: bool,
                 dataloader_pin_memory: bool,
                 callbacks: List[Callback] = []
                 ) -> None:
        super().__init__()
        self.strategy = strategy
        self.buffer = buffer
        self.sample_buffer = sample_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks

    # TODO(ver217): maybe simplify these code using context
    def _on_fit_start(self) -> None:
    @contextmanager
    def _fit_ctx(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()
        try:
            yield
        finally:
            for callback in self.callbacks:
                callback.on_fit_end()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
    @contextmanager
    def _episode_ctx(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)
        try:
            yield
        finally:
            for callback in self.callbacks:
                callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:

@@ -73,3 +127,71 @@ class Trainer(ABC):
    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)

    @abstractmethod
    def _make_experience(self, collect_step: int):
        """
        Implement this method to make experience.
        """
        raise NotImplementedError()

    @abstractmethod
    def _learn(self, update_step: int):
        """
        Implement this method to learn from experience, either
        sample from buffer or transform buffer into dataloader.
        """
        raise NotImplementedError()

    def _collect_phase(self, collect_step: int):
        self._on_make_experience_start()
        experience = self._make_experience(collect_step)
        self._on_make_experience_end(experience)
        self.buffer.append(experience)

    def _update_phase(self, update_step: int):
        self._on_learn_epoch_start(update_step)
        self._learn(update_step)
        self._on_learn_epoch_end(update_step)

    def fit(self,
            prompt_dataloader: DataLoader,
            pretrain_dataloader: DataLoader,
            num_episodes: int,
            num_collect_steps: int,
            num_update_steps: int,
            ):
        """
        The main training loop of on-policy rl trainers.

        Args:
            prompt_dataloader (DataLoader): the dataloader to use for prompt data
            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
            num_episodes (int): the number of episodes to train
            num_collect_steps (int): the number of collect steps per episode
            num_update_steps (int): the number of update steps per episode
        """
        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

        with self._fit_ctx():
            for episode in tqdm.trange(num_episodes,
                                       desc="Episodes",
                                       disable=not is_rank_0()):
                with self._episode_ctx(episode):
                    for collect_step in tqdm.trange(num_collect_steps,
                                                    desc="Collect steps",
                                                    disable=not is_rank_0()):
                        self._collect_phase(collect_step)
                    if not self.sample_buffer:
                        # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                        # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                        # I only call strategy.setup_dataloader() to setup dataloader.
                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
                                                                         self.dataloader_pin_memory)
                    for update_step in tqdm.trange(num_update_steps,
                                                   desc="Update steps",
                                                   disable=not is_rank_0()):
                        self._update_phase(update_step)
                    # NOTE: this is for on-policy algorithms
                    self.buffer.clear()
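To make the new split concrete, here is a minimal, illustrative sketch (not part of the diff) of a supervised trainer built on the `SLTrainer` contract above. It assumes the refactored `coati.trainer.base` module from this commit is importable; the toy model, loss and dataloader are placeholders:

```python
# Illustrative sketch only: a toy subclass wired into the SLTrainer hooks above.
import torch.nn as nn
from torch.utils.data import DataLoader

from coati.trainer.base import SLTrainer  # assumes the refactored package is installed


class ToySLTrainer(SLTrainer):

    def _before_fit(self, train_dataloader: DataLoader):
        # fit(*args, **kwargs) forwards its arguments here, so the dataloader is
        # attached per fit() call instead of in __init__ (as RewardModelTrainer does).
        super()._before_fit()
        self.train_dataloader = train_dataloader

    def _train(self, epoch: int):
        # one optimization pass over the training dataloader
        self.model.train()
        for inputs, targets in self.train_dataloader:
            loss = nn.functional.mse_loss(self.model(inputs), targets)
            self.strategy.backward(loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()

    def _eval(self, epoch: int):
        # evaluation is optional in this toy example
        pass
```

Training would then be launched with `ToySLTrainer(strategy, max_epochs, model, optimizer).fit(train_dataloader)`, mirroring how `RewardModelTrainer` and `SFTTrainer` are driven later in this diff.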
@@ -1,6 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Dict, List

import torch
import torch.nn as nn
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, get_base_model

@@ -9,19 +8,32 @@ from coati.models.utils import calc_action_log_probs
from coati.replay_buffer import NaiveReplayBuffer
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.utils import get_current_device

from .base import Trainer
from .base import OnPolicyTrainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device


class PPOTrainer(Trainer):
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    unwrapper_model = strategy.unwrap_model(actor)
    hf_model = get_base_model(unwrapper_model)
    new_kwargs = {**generate_kwargs}
    # use huggingface models method directly
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation

    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation

    return new_kwargs


class PPOTrainer(OnPolicyTrainer):
    """
    Trainer for PPO algorithm.

@@ -35,14 +47,13 @@ class PPOTrainer(Trainer):
        critic_optim (Optimizer): the optimizer to use for critic model
        kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
        buffer_limit (int, defaults to 0): the max_size limitation of buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        vf_coef (float, defaults to 1.0): the coefficient of value loss
        ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        max_epochs (int, defaults to 1): the number of epochs of training process
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        sample_buffer (bool, defaults to False): whether to sample from buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
        callbacks (List[Callback], defaults to []): the callbacks to call during training process

@@ -65,25 +76,26 @@ class PPOTrainer(Trainer):
                 eps_clip: float = 0.2,
                 vf_coef: float = 1.0,
                 value_clip: float = 0.4,
                 max_epochs: int = 1,
                 sample_replay_buffer: bool = False,
                 sample_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 offload_inference_models: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
                 **generate_kwargs
                 ) -> None:
        if isinstance(strategy, ColossalAIStrategy):
            from colossalai.booster.plugin import GeminiPlugin
            assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
                "GeminiPlugin is not compatible with manual model.to('cpu')"

        experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
        buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(
            strategy, buffer,
            sample_buffer, dataloader_pin_memory,
            callbacks
        )

        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.sample_replay_buffer = sample_replay_buffer
        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
        self.offload_inference_models = offload_inference_models

        self.actor = actor

@@ -99,76 +111,20 @@ class PPOTrainer(Trainer):

        self.device = get_current_device()

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
    def _make_experience(self, collect_step: int) -> Experience:
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
        if isinstance(prompts, Tensor):
            return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
        elif isinstance(prompts, dict):
            return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')
            raise ValueError(f'Unsupported input type "{type(prompts)}"')

    def _learn(self):
        # replay buffer may be empty at first, we should rebuild at each training
        if not self.sample_replay_buffer:
            # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
            # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
        if self.sample_replay_buffer:
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                experience.to_device(self.device)
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
        else:
            for epoch in range(self.max_epochs):
                self._on_learn_epoch_start(epoch)
                if isinstance(dataloader.sampler, DistributedSampler):
                    dataloader.sampler.set_epoch(epoch)
                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
                for experience in pbar:
                    self._on_learn_batch_start()
                    experience.to_device(self.device)
                    metrics = self.training_step(experience)
                    self._on_learn_batch_end(metrics, experience)
                    pbar.set_postfix(metrics)
                self._on_learn_epoch_end(epoch)

    def fit(self,
            prompt_dataloader,
            pretrain_dataloader,
            num_episodes: int = 50000,
            max_timesteps: int = 500,
            update_timesteps: int = 5000) -> None:
        time = 0
        self.pretrain_dataloader = pretrain_dataloader
        self.prompt_dataloader = prompt_dataloader
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for timestep in tqdm(range(max_timesteps),
                                 desc=f'Episode [{episode+1}/{num_episodes}]',
                                 disable=not is_rank_0()):
                time += 1
                prompts = next(iter(self.prompt_dataloader))
                self._on_make_experience_start()
                if self.offload_inference_models:
                    # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
                    self.experience_maker.initial_model.to(self.device)
                    self.experience_maker.reward_model.to(self.device)
                experience = self._make_experience(prompts)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if time % update_timesteps == 0:
                    if self.offload_inference_models:
                        self.experience_maker.initial_model.to('cpu')
                        self.experience_maker.reward_model.to('cpu')
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    def training_step(self, experience: Experience) -> Dict[str, float]:
    def _training_step(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()
        self.critic.train()
        # policy loss

@@ -182,7 +138,7 @@ class PPOTrainer(Trainer):

        # ptx loss
        if self.ptx_coef != 0:
            batch = next(iter(self.pretrain_dataloader))
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            ptx_log_probs = self.actor(batch['input_ids'],
                                       attention_mask=batch['attention_mask'])['logits']

@@ -208,16 +164,29 @@ class PPOTrainer(Trainer):

        return {'reward': experience.reward.mean().item()}

    def _learn(self, update_step: int):
        if self.offload_inference_models:
            self.experience_maker.initial_model.to('cpu')
            self.experience_maker.reward_model.to('cpu')

def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    unwrapper_model = strategy.unwrap_model(actor)
    hf_model = get_base_model(unwrapper_model)
    new_kwargs = {**generate_kwargs}
    # use huggingface models method directly
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation

    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation

    return new_kwargs
        # buffer may be empty at first, we should rebuild at each training
        if self.sample_buffer:
            experience = self.buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            metrics = self._training_step(experience)
            self._on_learn_batch_end(metrics, experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(
                self.dataloader,
                desc=f'Train epoch [{update_step + 1}]',
                disable=not is_rank_0()
            )
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                metrics = self._training_step(experience)
                self._on_learn_batch_end(metrics, experience)
                pbar.set_postfix(metrics)
@@ -1,20 +1,19 @@
from datetime import datetime
from typing import Callable, List
from typing import Callable

import pandas as pd
import torch
import tqdm
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from .base import Trainer
from .callbacks import Callback
from .base import SLTrainer
from .strategies import Strategy
from .utils import is_rank_0


class RewardModelTrainer(Trainer):
class RewardModelTrainer(SLTrainer):
    """
    Trainer to use while training reward model.

@@ -24,12 +23,7 @@ class RewardModelTrainer(Trainer):
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training
        loss_fn (callable): the loss function to use for training
        train_dataloader (DataLoader): the dataloader to use for training
        valid_dataloader (DataLoader): the dataloader to use for validation
        eval_dataloader (DataLoader): the dataloader to use for evaluation
        batch_size (int, defaults to 1): the batch size while training
        max_epochs (int, defaults to 2): the number of epochs to train
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
    """

    def __init__(
@@ -39,87 +33,79 @@ class RewardModelTrainer(Trainer):
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        loss_fn: Callable,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        eval_dataloader: DataLoader,
        max_epochs: int = 1,
        callbacks: List[Callback] = [],
    ) -> None:
        super().__init__(strategy, max_epochs, callbacks=callbacks)
        super().__init__(strategy, max_epochs, model, optim)

        self.loss_fn = loss_fn
        self.scheduler = lr_scheduler

    def _eval(self, epoch):
        if self.eval_dataloader is not None:
            self.model.eval()
            dist, on, cnt = 0, 0, 0
            with torch.no_grad():
                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
                    for i in range(len(chosen_reward)):
                        cnt += 1
                        if chosen_reward[i] > reject_reward[i]:
                            on += 1
                    dist += (chosen_reward - reject_reward).mean().item()
                self.dist = dist / len(self.eval_dataloader)
                self.acc = on / cnt

            if is_rank_0():
                log = pd.DataFrame(
                    [[(epoch + 1) * len(self.train_dataloader),
                      self.loss.item(), self.dist, self.acc]],
                    columns=['step', 'loss', 'dist', 'acc']
                )
                log.to_csv('log.csv', mode='a', header=False, index=False)

    def _train(self, epoch):
        self.model.train()
        step_bar = tqdm.trange(
            len(self.train_dataloader),
            desc='Train step of epoch %d' % epoch,
            disable=not is_rank_0()
        )
        cnt = 0
        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
            chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
            reject_reward = self.model(reject_ids, attention_mask=r_mask)
            self.loss = self.loss_fn(chosen_reward, reject_reward)
            self.strategy.backward(self.loss, self.model, self.optimizer)
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()
            cnt += 1
            if cnt % 100 == 0:
                self.scheduler.step()
            step_bar.update()
        step_bar.close()

    def _before_fit(self,
                    train_dataloader: DataLoader,
                    valid_dataloader: DataLoader,
                    eval_dataloader: DataLoader):
        """
        Args:
            train_dataloader (DataLoader): the dataloader to use for training
            valid_dataloader (DataLoader): the dataloader to use for validation
            eval_dataloader (DataLoader): the dataloader to use for evaluation
        """
        super()._before_fit()
        self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.eval_dataloader = eval_dataloader

        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optim
        self.scheduler = lr_scheduler

    def eval_acc(self, dataloader):
        dist = 0
        on = 0
        cnt = 0
        self.model.eval()
        with torch.no_grad():
            for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                reject_reward = self.model(reject_ids, attention_mask=r_mask)
                for i in range(len(chosen_reward)):
                    cnt += 1
                    if chosen_reward[i] > reject_reward[i]:
                        on += 1
                dist += (chosen_reward - reject_reward).mean().item()
            dist_mean = dist / len(dataloader)
            acc = on / cnt
        self.model.train()
        return dist_mean, acc

    def fit(self):
        time = datetime.now()
        epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
        for epoch in range(self.max_epochs):
            step_bar = tqdm(range(self.train_dataloader.__len__()),
                            desc='Train step of epoch %d' % epoch,
                            disable=not is_rank_0())
            # train
            self.model.train()
            cnt = 0
            acc = 0
            dist = 0
            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                reject_reward = self.model(reject_ids, attention_mask=r_mask)
                loss = self.loss_fn(chosen_reward, reject_reward)
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                cnt += 1
                if cnt == 100:
                    self.scheduler.step()
                    dist, acc = self.eval_acc(self.valid_dataloader)
                    cnt = 0
                    if is_rank_0():
                        log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
                                           columns=['step', 'loss', 'dist', 'acc'])
                        log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
                step_bar.update()
                step_bar.set_postfix({'dist': dist, 'acc': acc})

            # eval
            dist, acc = self.eval_acc(self.eval_dataloader)
            if is_rank_0():
                log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
                                   columns=['step', 'loss', 'dist', 'acc'])
                log.to_csv('log.csv', mode='a', header=False, index=False)
            epoch_bar.update()
            step_bar.set_postfix({'dist': dist, 'acc': acc})
            step_bar.close()
@@ -1,21 +1,22 @@
import time
from typing import List
from typing import Optional

import torch
import torch.distributed as dist
import tqdm
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from .base import Trainer
from .callbacks import Callback
from colossalai.logging import DistributedLogger

from .base import SLTrainer
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device


class SFTTrainer(Trainer):
class SFTTrainer(SLTrainer):
    """
    Trainer to use while training reward model.

@@ -23,12 +24,9 @@ class SFTTrainer(Trainer):
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim(Optimizer): the optimizer to use for training
        train_dataloader: the dataloader to use for training
        eval_dataloader: the dataloader to use for evaluation
        batch_size (int, defaults to 1): the batch size while training
        lr_scheduler(_LRScheduler): the lr scheduler to use for training
        max_epochs (int, defaults to 2): the number of epochs to train
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
        accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
    """

    def __init__(
@@ -37,95 +35,92 @@ class SFTTrainer(Trainer):
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        train_dataloader: DataLoader,
        eval_dataloader: DataLoader = None,
        max_epochs: int = 2,
        accumulation_steps: int = 8,
        callbacks: List[Callback] = [],
    ) -> None:
        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
            from colossalai.booster.plugin import GeminiPlugin
            assert not isinstance(strategy.plugin, GeminiPlugin), \
                "Accumulation steps are not supported in stage 3 of ColossalAI"
        super().__init__(strategy, max_epochs, callbacks=callbacks)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.model = model
        self.optimizer = optim

        super().__init__(strategy, max_epochs, model, optim)

        self.accumulation_steps = accumulation_steps

        self.scheduler = lr_scheduler

    def fit(self, logger, use_wandb: bool = False):
    def _train(self, epoch: int):
        self.model.train()
        for batch_id, batch in enumerate(self.train_dataloader):

            batch = to_device(batch, torch.cuda.current_device())
            outputs = self.model(batch["input_ids"],
                                 attention_mask=batch["attention_mask"],
                                 labels=batch["labels"])

            loss = outputs.loss
            loss = loss / self.accumulation_steps

            self.strategy.backward(loss, self.model, self.optimizer)

            self.total_loss += loss.item()

            # gradient accumulation
            if (batch_id + 1) % self.accumulation_steps == 0:
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                self.scheduler.step()
                if is_rank_0() and self.use_wandb:
                    wandb.log({
                        "loss": self.total_loss / self.accumulation_steps,
                        "lr": self.scheduler.get_last_lr()[0],
                        "epoch": epoch,
                        "batch_id": batch_id
                    })
                self.total_loss = 0
                self.step_bar.update()

    def _eval(self, epoch: int):
        if self.eval_dataloader is not None:
            self.model.eval()
            with torch.no_grad():
                loss_sum, num_seen = 0, 0
                for batch in self.eval_dataloader:
                    batch = to_device(batch, torch.cuda.current_device())
                    outputs = self.model(batch["input_ids"],
                                         attention_mask=batch["attention_mask"],
                                         labels=batch["labels"])
                    loss = outputs.loss

                    loss_sum += loss.item()
                    num_seen += batch["input_ids"].size(0)

                loss_mean = loss_sum / num_seen
                if dist.get_rank() == 0:
                    self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')

    def _before_fit(self,
                    train_dataloader: DataLoader,
                    eval_dataloader: Optional[DataLoader] = None,
                    logger: Optional[DistributedLogger] = None,
                    use_wandb: bool = False):
        """
        Args:
            train_dataloader: the dataloader to use for training
            eval_dataloader: the dataloader to use for evaluation
        """
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.logger = logger
        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            wandb.watch(self.model)
        total_loss = 0
        # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
                        desc=f'steps',
                        disable=not is_rank_0())
        for epoch in range(self.max_epochs):

            # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
            # train
            self.model.train()
            for batch_id, batch in enumerate(self.train_dataloader):

                batch = to_device(batch, torch.cuda.current_device())
                outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])

                loss = outputs.loss

                if loss >= 2.5 and is_rank_0():
                    logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

                loss = loss / self.accumulation_steps

                self.strategy.backward(loss, self.model, self.optimizer)

                total_loss += loss.item()

                # gradient accumulation
                if (batch_id + 1) % self.accumulation_steps == 0:
                    self.strategy.optimizer_step(self.optimizer)
                    self.optimizer.zero_grad()
                    self.scheduler.step()
                    if is_rank_0() and use_wandb:
                        wandb.log({
                            "loss": total_loss / self.accumulation_steps,
                            "lr": self.scheduler.get_last_lr()[0],
                            "epoch": epoch,
                            "batch_id": batch_id
                        })
                    total_loss = 0
                    step_bar.update()

                # if batch_id % log_interval == 0:
                #     logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
                #     wandb.log({"loss": loss.item()})

                # process_bar.update()

            # eval
            if self.eval_dataloader is not None:
                self.model.eval()
                with torch.no_grad():
                    loss_sum = 0
                    num_seen = 0
                    for batch in self.eval_dataloader:
                        batch = to_device(batch, torch.cuda.current_device())
                        outputs = self.model(batch["input_ids"],
                                             attention_mask=batch["attention_mask"],
                                             labels=batch["labels"])
                        loss = outputs.loss

                        loss_sum += loss.item()
                        num_seen += batch["input_ids"].size(0)

                    loss_mean = loss_sum / num_seen
                    if dist.get_rank() == 0:
                        logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')

            # epoch_bar.update()
        self.total_loss = 0
        self.no_epoch_bar = True
        self.step_bar = tqdm.trange(
            len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
            desc=f'steps',
            disable=not is_rank_0()
        )
@@ -1,4 +1,3 @@
import functools
import warnings
from typing import Optional

@@ -103,7 +102,7 @@ class ColossalAIStrategy(DDPStrategy):
        # NOTE: dist should be initialized before calling get_current_device()
        if stage == 3:
            plugin_initializer = lambda: GeminiPlugin(
                # gemini_config
                # gemini_config
                device=get_current_device(),
                placement_policy=placement_policy,
                precision=precision,
@@ -113,20 +112,20 @@ class ColossalAIStrategy(DDPStrategy):
                search_range_m=search_range_m,
                hidden_dim=hidden_dim,
                min_chunk_size_m=min_chunk_size_m,
                # zero_optim_config
                # zero_optim_config
                gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                # optim_config
                # optim_config
                **optim_kwargs)
        else:
            plugin_initializer = lambda: LowLevelZeroPlugin(
                # zero_config
                # zero_config
                stage=stage,
                precision=precision,
                # zero_optim_config
                # zero_optim_config
                reduce_bucket_size_in_m=reduce_bucket_size,
                overlap_communication=overlap_communication,
                cpu_offload=(placement_policy == 'cpu'),
                # optim_config
                # optim_config
                **optim_kwargs)

        super().__init__(seed, plugin_initializer)
@@ -3,6 +3,33 @@ from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader


class CycledDataLoader:
    """
    Why do we need this class?
    In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain.
    However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...)
    NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
    """

    def __init__(self,
                 dataloader: DataLoader,
                 ) -> None:
        self.dataloader = dataloader

        self.count = 0
        self.dataloader_iter = iter(dataloader)

    def next(self):
        self.count += 1
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)


def is_rank_0() -> bool:
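As a quick illustration of the wrap-around behaviour described in the docstring above (not part of the diff; the data below is a placeholder and the `coati.trainer.utils` module path is assumed), a `CycledDataLoader` can be drawn from indefinitely without re-creating the underlying `DataLoader`:

```python
# Illustrative sketch only: CycledDataLoader restarts the exhausted iterator
# instead of rebuilding the DataLoader (and its workers) on every fetch.
import torch
from torch.utils.data import DataLoader, TensorDataset

from coati.trainer.utils import CycledDataLoader  # assumed module path

dataset = TensorDataset(torch.arange(6))              # placeholder data: 0..5
cycled = CycledDataLoader(DataLoader(dataset, batch_size=4))

first = cycled.next()     # first batch (elements 0..3)
second = cycled.next()    # remaining elements (4..5)
third = cycled.next()     # StopIteration is caught and iteration restarts from 0
```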
@@ -171,9 +171,8 @@ Pretrain dataset: the pretrain dataset including the instruction and correspondi
- --pretrain_dataset: path of the ptx dataset, type=str, default=None
- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
- --num_episodes: num of episodes for training, type=int, default=10
- --max_epochs: max epochs for training in one episode, type=int, default=5
- --max_timesteps: max episodes in one batch, type=int, default=10
- --update_timesteps: timesteps to update, type=int, default=10
- --num_update_steps: number of steps to update policy per episode, type=int
- --num_collect_steps: number of steps to collect experience per episode, type=int
- --train_batch_size: batch size while training, type=int, default=8
- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
- --experience_batch_size: batch size to make experience, type=int, default=8
@@ -171,7 +171,6 @@ def main(args):
                         critic_optim,
                         kl_coef=args.kl_coef,
                         ptx_coef=args.ptx_coef,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         experience_batch_size=args.experience_batch_size,
                         tokenizer=tokenize_fn,
@@ -186,8 +185,8 @@ def main(args):
    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)
                num_update_steps=args.num_update_steps,
                num_collect_steps=args.num_collect_steps)

    # save model checkpoint after fitting
    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
@@ -215,9 +214,8 @@ if __name__ == '__main__':
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--num_collect_steps', type=int, default=10)
    parser.add_argument('--num_update_steps', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
@@ -63,8 +63,8 @@ for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
        torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
            --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
            --strategy $strategy --model $model \
            --num_episodes 1 --max_timesteps 2 \
            --update_timesteps 2 --max_epochs 1 --train_batch_size 2
            --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
            --train_batch_size 2
    done
done
@@ -149,8 +149,8 @@ rm -rf ${BASE}/rm_ckpt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
    --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
    --strategy colossalai_zero2 --num_episodes 1 \
    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
    --pretrain 'facebook/opt-350m' --model opt \
    --rm_pretrain 'facebook/opt-350m' \
    --rm_path ${BASE}/rm_ckpt_opt.pt \
@@ -159,8 +159,8 @@ rm -rf ${BASE}/rm_ckpt_opt.pt

torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
    --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
    --strategy colossalai_zero2 --num_episodes 1 \
    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
    --pretrain 'gpt2' --model gpt2 \
    --rm_pretrain 'gpt2' \
    --rm_path ${BASE}/rm_ckpt_gpt.pt \
@@ -168,8 +168,8 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \

torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
    --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
    --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
    --strategy colossalai_gemini --num_episodes 1 \
    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
    --pretrain 'gpt2' --model gpt2 \
    --rm_pretrain 'gpt2' \
    --rm_path ${BASE}/rm_ckpt_gpt.pt \
@@ -177,7 +177,6 @@ def main(args):
                         critic_optim,
                         kl_coef=args.kl_coef,
                         ptx_coef=args.ptx_coef,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         max_length=args.max_seq_len,
                         use_cache=True,
@@ -192,8 +191,8 @@ def main(args):
    trainer.fit(prompt_dataloader=prompt_dataloader,
                pretrain_dataloader=pretrain_dataloader,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)
                num_collect_steps=args.num_collect_steps,
                num_update_steps=args.num_update_steps)

    # save model checkpoint after fitting
    strategy.save_model(actor, args.save_path, only_rank0=True)
@@ -220,9 +219,8 @@ if __name__ == '__main__':
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--num_collect_steps', type=int, default=10)
    parser.add_argument('--num_update_steps', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--ptx_batch_size', type=int, default=1)
    parser.add_argument('--experience_batch_size', type=int, default=8)
@@ -1,13 +1,13 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2

# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2

torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_prompts.py \
    --pretrain_dataset /path/to/data.json \
    --prompt_dataset /path/to/data.json \
    --strategy colossalai_zero2 \
    --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
    --train_batch_size 2
@@ -178,12 +178,11 @@ def train(args):
                         optim=optim,
                         lr_scheduler=lr_scheduler,
                         loss_fn=loss_fn,
                         train_dataloader=train_dataloader,
                         valid_dataloader=valid_dataloader,
                         eval_dataloader=eval_dataloader,
                         max_epochs=args.max_epochs)

    trainer.fit()
    trainer.fit(train_dataloader=train_dataloader,
                valid_dataloader=valid_dataloader,
                eval_dataloader=eval_dataloader)
    # save model checkpoint after fitting on only rank0
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
@@ -170,12 +170,13 @@ def train(args):
                         strategy=strategy,
                         optim=optim,
                         lr_scheduler=lr_scheduler,
                         train_dataloader=train_dataloader,
                         eval_dataloader=eval_dataloader,
                         max_epochs=args.max_epochs,
                         accumulation_steps=args.accumulation_steps)

    trainer.fit(logger=logger, use_wandb=args.use_wandb)
    trainer.fit(train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                logger=logger,
                use_wandb=args.use_wandb)

    # save model checkpoint after fitting on only rank0
    strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)