mirror of https://github.com/hpcaitech/ColossalAI
[chat] refactor trainer class (#4080)
* to: add SLTrainer
* refactor: refactor RMTrainer and SFTTrainer
* fix: fix init file
* feat: remove on_learn_epoch fn as not used
* fix: align with modified gemini arguments
* to: add OnPolicyTrainer
* revert: add _on_learn_epoch fn
* refactor: refactor PPOTrainer
* style: rename PPOTrainer argument
* fix: align with modified PPO arguments
* test: align with modified train_prompts arguments
* chore: modify train_prompts
* docs: align with modified arguments
* fix: remove unnecessary output
* fix: move dataloader to fit fn of SLTrainer
* fix: move dataloader to fit fn of OnPolicyTrainer
* fix: modify usage of prompt and pretrain dataloader
parent 711e2b4c00
commit b03d64d010
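This commit replaces the old max_timesteps/update_timesteps counters with an explicit per-episode loop of collect steps followed by update steps. A minimal, self-contained sketch of that control flow, illustrative only — the toy class and names below are hypothetical stand-ins, not code from this commit:

from typing import List


class ToyOnPolicyLoop:
    """Mimics the episode -> collect -> update control flow that OnPolicyTrainer.fit() introduces."""

    def __init__(self) -> None:
        self.buffer: List[str] = []    # stands in for NaiveReplayBuffer

    def _collect_phase(self, collect_step: int) -> None:
        # experience making (the real trainer appends Experience objects to its buffer)
        self.buffer.append(f"experience-{collect_step}")

    def _update_phase(self, update_step: int) -> None:
        # learning step (the real trainer consumes the buffer via a dataloader or sampling)
        print(f"update {update_step}: learning from {len(self.buffer)} experiences")

    def fit(self, num_episodes: int, num_collect_steps: int, num_update_steps: int) -> None:
        for episode in range(num_episodes):
            for collect_step in range(num_collect_steps):
                self._collect_phase(collect_step)
            for update_step in range(num_update_steps):
                self._update_phase(update_step)
            self.buffer.clear()    # on-policy: stale experience is discarded each episode


if __name__ == "__main__":
    ToyOnPolicyLoop().fit(num_episodes=1, num_collect_steps=2, num_update_steps=1)

The real OnPolicyTrainer.fit() in the diff below wraps the same skeleton with progress bars, callbacks, and buffer-to-dataloader setup.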
@@ -83,7 +83,7 @@ More details can be found in the latest news.
 <img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/ColossalChat%20Speed.jpg" width=450/>
 </p>

-> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --max_timesteps 1 --update_timesteps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32
+> DeepSpeedChat performance comes from its blog on 2023 April 12, ColossalChat performance can be reproduced on an AWS p4d.24xlarge node with 8 A100-40G GPUs with the following command: torchrun --standalone --nproc_per_node 8 benchmark_opt_lora_dummy.py --num_collect_steps 1 --use_kernels --strategy colossalai_zero2 --experience_batch_size 64 --train_batch_size 32

 ## Install
@@ -137,6 +137,12 @@ def main(args):

     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
+    dataloader = DataLoader(random_prompts,
+                            batch_size=args.experience_batch_size,
+                            shuffle=True,
+                            collate_fn=preprocess_batch)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
@@ -145,7 +151,6 @@ def main(args):
                          actor_optim,
                          critic_optim,
                          ptx_coef=0,
-                         max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
                          offload_inference_models=args.offload_inference_models,
                          max_length=512,
@@ -157,17 +162,11 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=[performance_evaluator])

-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
-    dataloader = DataLoader(random_prompts,
-                            batch_size=args.experience_batch_size,
-                            shuffle=True,
-                            collate_fn=preprocess_batch)
-
-    trainer.fit(dataloader,
-                None,
+    trainer.fit(prompt_dataloader=dataloader,
+                pretrain_dataloader=None,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_update_steps=args.num_update_steps,
+                num_collect_steps=args.num_collect_steps)

     print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
@@ -183,9 +182,8 @@ if __name__ == '__main__':
                         ],
                         default='ddp')
     parser.add_argument('--num_episodes', type=int, default=3)
-    parser.add_argument('--max_timesteps', type=int, default=8)
-    parser.add_argument('--update_timesteps', type=int, default=8)
-    parser.add_argument('--max_epochs', type=int, default=1)
+    parser.add_argument('--num_collect_steps', type=int, default=8)
+    parser.add_argument('--num_update_steps', type=int, default=1)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0)
@@ -1,6 +1,10 @@
-from .base import Trainer
+from .base import OnPolicyTrainer, SLTrainer
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer

-__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
+__all__ = [
+    'SLTrainer', 'OnPolicyTrainer',
+    'RewardModelTrainer', 'SFTTrainer',
+    'PPOTrainer'
+]
@@ -1,54 +1,108 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from contextlib import contextmanager
+from typing import List

-import torch
+import torch.nn as nn
+import tqdm
 from coati.experience_maker import Experience
+from coati.replay_buffer import NaiveReplayBuffer
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader

 from .callbacks import Callback
 from .strategies import Strategy
+from .utils import CycledDataLoader, is_rank_0


-class Trainer(ABC):
+class SLTrainer(ABC):
     """
-        Base class for rlhf trainers.
+        Base class for supervised learning trainers.

     Args:
         strategy (Strategy):the strategy to use for training
         max_epochs (int, defaults to 1): the number of epochs of training process
-        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
-        callbacks (List[Callback], defaults to []): the callbacks to call during training process
-        generate_kwargs (dict, optional): the kwargs to use while model generating
+        model (nn.Module): the model to train
+        optim (Optimizer): the optimizer to use for training
     """

     def __init__(self,
                  strategy: Strategy,
-                 max_epochs: int = 1,
-                 dataloader_pin_memory: bool = True,
-                 callbacks: List[Callback] = [],
-                 **generate_kwargs) -> None:
+                 max_epochs: int,
+                 model: nn.Module,
+                 optimizer: Optimizer,
+                 ) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
-        self.generate_kwargs = generate_kwargs
+        self.model = model
+        self.optimizer = optimizer
+
+    @abstractmethod
+    def _train(self, epoch):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _eval(self, epoch):
+        raise NotImplementedError()
+
+    def _before_fit(self):
+        self.no_epoch_bar = False
+
+    def fit(self, *args, **kwargs):
+        self._before_fit(*args, **kwargs)
+        for epoch in tqdm.trange(self.max_epochs,
+                                 desc="Epochs",
+                                 disable=not is_rank_0() or self.no_epoch_bar
+                                 ):
+            self._train(epoch)
+            self._eval(epoch)
+
+
+class OnPolicyTrainer(ABC):
+    """
+        Base class for on-policy rl trainers, e.g. PPO.
+
+    Args:
+        strategy (Strategy):the strategy to use for training
+        buffer (NaiveReplayBuffer): the buffer to collect experiences
+        sample_buffer (bool, defaults to False): whether to sample from buffer
+        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+        callbacks (List[Callback], defaults to []): the callbacks to call during training process
+    """
+
+    def __init__(self,
+                 strategy: Strategy,
+                 buffer: NaiveReplayBuffer,
+                 sample_buffer: bool,
+                 dataloader_pin_memory: bool,
+                 callbacks: List[Callback] = []
+                 ) -> None:
+        super().__init__()
+        self.strategy = strategy
+        self.buffer = buffer
+        self.sample_buffer = sample_buffer
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks

-    # TODO(ver217): maybe simplify these code using context
-    def _on_fit_start(self) -> None:
+    @contextmanager
+    def _fit_ctx(self) -> None:
         for callback in self.callbacks:
             callback.on_fit_start()
+        try:
+            yield
+        finally:
+            for callback in self.callbacks:
+                callback.on_fit_end()

-    def _on_fit_end(self) -> None:
-        for callback in self.callbacks:
-            callback.on_fit_end()
-
-    def _on_episode_start(self, episode: int) -> None:
+    @contextmanager
+    def _episode_ctx(self, episode: int) -> None:
         for callback in self.callbacks:
             callback.on_episode_start(episode)
-
-    def _on_episode_end(self, episode: int) -> None:
-        for callback in self.callbacks:
-            callback.on_episode_end(episode)
+        try:
+            yield
+        finally:
+            for callback in self.callbacks:
+                callback.on_episode_end(episode)

     def _on_make_experience_start(self) -> None:
         for callback in self.callbacks:
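SLTrainer above is a template-method base: fit() owns the epoch loop, while subclasses supply _train, _eval, and _before_fit. A standalone sketch of that pattern using hypothetical toy classes (not the coati API):

from abc import ABC, abstractmethod


class ToySLTrainer(ABC):
    """Mimics the epoch loop that SLTrainer.fit() runs."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    @abstractmethod
    def _train(self, epoch: int) -> None:
        ...

    @abstractmethod
    def _eval(self, epoch: int) -> None:
        ...

    def _before_fit(self) -> None:
        pass    # subclasses move dataloader/logging setup here

    def fit(self, *args, **kwargs) -> None:
        self._before_fit(*args, **kwargs)
        for epoch in range(self.max_epochs):
            self._train(epoch)
            self._eval(epoch)


class PrintingTrainer(ToySLTrainer):

    def _train(self, epoch: int) -> None:
        print(f"train epoch {epoch}")

    def _eval(self, epoch: int) -> None:
        print(f"eval epoch {epoch}")


if __name__ == "__main__":
    PrintingTrainer(max_epochs=2).fit()

RewardModelTrainer and SFTTrainer later in this diff plug into exactly this hook structure.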
@@ -73,3 +127,71 @@ class Trainer(ABC):
     def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
         for callback in self.callbacks:
             callback.on_learn_batch_end(metrics, experience)
+
+    @abstractmethod
+    def _make_experience(self, collect_step: int):
+        """
+        Implement this method to make experience.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _learn(self, update_step: int):
+        """
+        Implement this method to learn from experience, either
+        sample from buffer or transform buffer into dataloader.
+        """
+        raise NotImplementedError()
+
+    def _collect_phase(self, collect_step: int):
+        self._on_make_experience_start()
+        experience = self._make_experience(collect_step)
+        self._on_make_experience_end(experience)
+        self.buffer.append(experience)
+
+    def _update_phase(self, update_step: int):
+        self._on_learn_epoch_start(update_step)
+        self._learn(update_step)
+        self._on_learn_epoch_end(update_step)
+
+    def fit(self,
+            prompt_dataloader: DataLoader,
+            pretrain_dataloader: DataLoader,
+            num_episodes: int,
+            num_collect_steps: int,
+            num_update_steps: int,
+            ):
+        """
+        The main training loop of on-policy rl trainers.
+
+        Args:
+            prompt_dataloader (DataLoader): the dataloader to use for prompt data
+            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
+            num_episodes (int): the number of episodes to train
+            num_collect_steps (int): the number of collect steps per episode
+            num_update_steps (int): the number of update steps per episode
+        """
+        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
+        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
+
+        with self._fit_ctx():
+            for episode in tqdm.trange(num_episodes,
+                                       desc="Episodes",
+                                       disable=not is_rank_0()):
+                with self._episode_ctx(episode):
+                    for collect_step in tqdm.trange(num_collect_steps,
+                                                    desc="Collect steps",
+                                                    disable=not is_rank_0()):
+                        self._collect_phase(collect_step)
+                    if not self.sample_buffer:
+                        # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
+                        # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
+                        # I only call strategy.setup_dataloader() to setup dataloader.
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
+                                                                         self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps,
+                                                   desc="Update steps",
+                                                   disable=not is_rank_0()):
+                        self._update_phase(update_step)
+                    # NOTE: this is for on-policy algorithms
+                    self.buffer.clear()
@@ -1,6 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Dict, List

-import torch
 import torch.nn as nn
 from coati.experience_maker import Experience, NaiveExperienceMaker
 from coati.models.base import Actor, Critic, get_base_model
@@ -9,19 +8,32 @@ from coati.models.utils import calc_action_log_probs
 from coati.replay_buffer import NaiveReplayBuffer
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from colossalai.utils import get_current_device

-from .base import Trainer
+from .base import OnPolicyTrainer
 from .callbacks import Callback
 from .strategies import ColossalAIStrategy, Strategy
 from .utils import is_rank_0, to_device


-class PPOTrainer(Trainer):
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
+    unwrapper_model = strategy.unwrap_model(actor)
+    hf_model = get_base_model(unwrapper_model)
+    new_kwargs = {**generate_kwargs}
+    # use huggingface models method directly
+    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
+        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
+
+    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
+        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
+
+    return new_kwargs
+
+
+class PPOTrainer(OnPolicyTrainer):
     """
     Trainer for PPO algorithm.

@@ -35,14 +47,13 @@ class PPOTrainer(Trainer):
         critic_optim (Optimizer): the optimizer to use for critic model
         kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
         train_batch_size (int, defaults to 8): the batch size to use for training
-        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
-        buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
+        buffer_limit (int, defaults to 0): the max_size limitation of buffer
+        buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
         eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
         vf_coef (float, defaults to 1.0): the coefficient of value loss
         ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
-        max_epochs (int, defaults to 1): the number of epochs of training process
-        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
+        sample_buffer (bool, defaults to False): whether to sample from buffer
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
@@ -65,25 +76,26 @@ class PPOTrainer(Trainer):
                  eps_clip: float = 0.2,
                  vf_coef: float = 1.0,
                  value_clip: float = 0.4,
-                 max_epochs: int = 1,
-                 sample_replay_buffer: bool = False,
+                 sample_buffer: bool = False,
                  dataloader_pin_memory: bool = True,
                  offload_inference_models: bool = True,
                  callbacks: List[Callback] = [],
-                 **generate_kwargs) -> None:
+                 **generate_kwargs
+                 ) -> None:
         if isinstance(strategy, ColossalAIStrategy):
             from colossalai.booster.plugin import GeminiPlugin
             assert not (isinstance(strategy.plugin, GeminiPlugin) and offload_inference_models), \
                 "GeminiPlugin is not compatible with manual model.to('cpu')"

-        experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
-        replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
-        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
+        buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+        super().__init__(
+            strategy, buffer,
+            sample_buffer, dataloader_pin_memory,
+            callbacks
+        )

-        self.experience_maker = experience_maker
-        self.replay_buffer = replay_buffer
-        self.sample_replay_buffer = sample_replay_buffer
+        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
+        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         self.offload_inference_models = offload_inference_models

         self.actor = actor
@@ -99,76 +111,20 @@ class PPOTrainer(Trainer):

         self.device = get_current_device()

-    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
-        if isinstance(inputs, Tensor):
-            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
-        elif isinstance(inputs, dict):
-            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
+    def _make_experience(self, collect_step: int) -> Experience:
+        prompts = self.prompt_dataloader.next()
+        if self.offload_inference_models:
+            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
+            self.experience_maker.initial_model.to(self.device)
+            self.experience_maker.reward_model.to(self.device)
+        if isinstance(prompts, Tensor):
+            return self.experience_maker.make_experience(prompts, **self.generate_kwargs)
+        elif isinstance(prompts, dict):
+            return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
         else:
-            raise ValueError(f'Unsupported input type "{type(inputs)}"')
+            raise ValueError(f'Unsupported input type "{type(prompts)}"')

-    def _learn(self):
-        # replay buffer may be empty at first, we should rebuild at each training
-        if not self.sample_replay_buffer:
-            # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
-            # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
-            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
-        if self.sample_replay_buffer:
-            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-            for _ in pbar:
-                experience = self.replay_buffer.sample()
-                experience.to_device(self.device)
-                metrics = self.training_step(experience)
-                pbar.set_postfix(metrics)
-        else:
-            for epoch in range(self.max_epochs):
-                self._on_learn_epoch_start(epoch)
-                if isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(epoch)
-                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
-                for experience in pbar:
-                    self._on_learn_batch_start()
-                    experience.to_device(self.device)
-                    metrics = self.training_step(experience)
-                    self._on_learn_batch_end(metrics, experience)
-                    pbar.set_postfix(metrics)
-                self._on_learn_epoch_end(epoch)
-
-    def fit(self,
-            prompt_dataloader,
-            pretrain_dataloader,
-            num_episodes: int = 50000,
-            max_timesteps: int = 500,
-            update_timesteps: int = 5000) -> None:
-        time = 0
-        self.pretrain_dataloader = pretrain_dataloader
-        self.prompt_dataloader = prompt_dataloader
-        self._on_fit_start()
-        for episode in range(num_episodes):
-            self._on_episode_start(episode)
-            for timestep in tqdm(range(max_timesteps),
-                                 desc=f'Episode [{episode+1}/{num_episodes}]',
-                                 disable=not is_rank_0()):
-                time += 1
-                prompts = next(iter(self.prompt_dataloader))
-                self._on_make_experience_start()
-                if self.offload_inference_models:
-                    # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
-                    self.experience_maker.initial_model.to(self.device)
-                    self.experience_maker.reward_model.to(self.device)
-                experience = self._make_experience(prompts)
-                self._on_make_experience_end(experience)
-                self.replay_buffer.append(experience)
-                if time % update_timesteps == 0:
-                    if self.offload_inference_models:
-                        self.experience_maker.initial_model.to('cpu')
-                        self.experience_maker.reward_model.to('cpu')
-                    self._learn()
-                    self.replay_buffer.clear()
-            self._on_episode_end(episode)
-        self._on_fit_end()
-
-    def training_step(self, experience: Experience) -> Dict[str, float]:
+    def _training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
         self.critic.train()
         # policy loss
|
@ -182,7 +138,7 @@ class PPOTrainer(Trainer):
|
||||||
|
|
||||||
# ptx loss
|
# ptx loss
|
||||||
if self.ptx_coef != 0:
|
if self.ptx_coef != 0:
|
||||||
batch = next(iter(self.pretrain_dataloader))
|
batch = self.pretrain_dataloader.next()
|
||||||
batch = to_device(batch, self.device)
|
batch = to_device(batch, self.device)
|
||||||
ptx_log_probs = self.actor(batch['input_ids'],
|
ptx_log_probs = self.actor(batch['input_ids'],
|
||||||
attention_mask=batch['attention_mask'])['logits']
|
attention_mask=batch['attention_mask'])['logits']
|
||||||
|
@@ -208,16 +164,29 @@ class PPOTrainer(Trainer):

         return {'reward': experience.reward.mean().item()}

-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
-    unwrapper_model = strategy.unwrap_model(actor)
-    hf_model = get_base_model(unwrapper_model)
-    new_kwargs = {**generate_kwargs}
-    # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
-
-    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
-        new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
-
-    return new_kwargs
+    def _learn(self, update_step: int):
+        if self.offload_inference_models:
+            self.experience_maker.initial_model.to('cpu')
+            self.experience_maker.reward_model.to('cpu')
+
+        # buffer may be empty at first, we should rebuild at each training
+        if self.sample_buffer:
+            experience = self.buffer.sample()
+            self._on_learn_batch_start()
+            experience.to_device(self.device)
+            metrics = self._training_step(experience)
+            self._on_learn_batch_end(metrics, experience)
+        else:
+            if isinstance(self.dataloader.sampler, DistributedSampler):
+                self.dataloader.sampler.set_epoch(update_step)
+            pbar = tqdm(
+                self.dataloader,
+                desc=f'Train epoch [{update_step + 1}]',
+                disable=not is_rank_0()
+            )
+            for experience in pbar:
+                self._on_learn_batch_start()
+                experience.to_device(self.device)
+                metrics = self._training_step(experience)
+                self._on_learn_batch_end(metrics, experience)
+                pbar.set_postfix(metrics)
@@ -1,20 +1,19 @@
 from datetime import datetime
-from typing import Callable, List
+from typing import Callable

 import pandas as pd
 import torch
+import tqdm
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import tqdm

-from .base import Trainer
-from .callbacks import Callback
+from .base import SLTrainer
 from .strategies import Strategy
 from .utils import is_rank_0


-class RewardModelTrainer(Trainer):
+class RewardModelTrainer(SLTrainer):
     """
     Trainer to use while training reward model.

@@ -24,12 +23,7 @@ class RewardModelTrainer(Trainer):
         optim (Optimizer): the optimizer to use for training
         lr_scheduler (_LRScheduler): the lr scheduler to use for training
         loss_fn (callable): the loss function to use for training
-        train_dataloader (DataLoader): the dataloader to use for training
-        valid_dataloader (DataLoader): the dataloader to use for validation
-        eval_dataloader (DataLoader): the dataloader to use for evaluation
-        batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
-        callbacks (List[Callback], defaults to []): the callbacks to call during training process
     """

     def __init__(
@@ -39,87 +33,79 @@ class RewardModelTrainer(Trainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         loss_fn: Callable,
-        train_dataloader: DataLoader,
-        valid_dataloader: DataLoader,
-        eval_dataloader: DataLoader,
         max_epochs: int = 1,
-        callbacks: List[Callback] = [],
     ) -> None:
-        super().__init__(strategy, max_epochs, callbacks=callbacks)
+        super().__init__(strategy, max_epochs, model, optim)

+        self.loss_fn = loss_fn
+        self.scheduler = lr_scheduler
+
+    def _eval(self, epoch):
+        if self.eval_dataloader is not None:
+            self.model.eval()
+            dist, on, cnt = 0, 0, 0
+            with torch.no_grad():
+                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
+                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
+                    for i in range(len(chosen_reward)):
+                        cnt += 1
+                        if chosen_reward[i] > reject_reward[i]:
+                            on += 1
+                    dist += (chosen_reward - reject_reward).mean().item()
+            self.dist = dist / len(self.eval_dataloader)
+            self.acc = on / cnt
+
+            if is_rank_0():
+                log = pd.DataFrame(
+                    [[(epoch + 1) * len(self.train_dataloader),
+                      self.loss.item(), self.dist, self.acc]],
+                    columns=['step', 'loss', 'dist', 'acc']
+                )
+                log.to_csv('log.csv', mode='a', header=False, index=False)
+
+    def _train(self, epoch):
+        self.model.train()
+        step_bar = tqdm.trange(
+            len(self.train_dataloader),
+            desc='Train step of epoch %d' % epoch,
+            disable=not is_rank_0()
+        )
+        cnt = 0
+        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
+            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+            chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+            reject_reward = self.model(reject_ids, attention_mask=r_mask)
+            self.loss = self.loss_fn(chosen_reward, reject_reward)
+            self.strategy.backward(self.loss, self.model, self.optimizer)
+            self.strategy.optimizer_step(self.optimizer)
+            self.optimizer.zero_grad()
+            cnt += 1
+            if cnt % 100 == 0:
+                self.scheduler.step()
+            step_bar.update()
+        step_bar.close()
+
+    def _before_fit(self,
+                    train_dataloader: DataLoader,
+                    valid_dataloader: DataLoader,
+                    eval_dataloader: DataLoader):
+        """
+        Args:
+            train_dataloader (DataLoader): the dataloader to use for training
+            valid_dataloader (DataLoader): the dataloader to use for validation
+            eval_dataloader (DataLoader): the dataloader to use for evaluation
+        """
+        super()._before_fit()
+        self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
         self.train_dataloader = train_dataloader
         self.valid_dataloader = valid_dataloader
         self.eval_dataloader = eval_dataloader
-
-        self.model = model
-        self.loss_fn = loss_fn
-        self.optimizer = optim
-        self.scheduler = lr_scheduler
-
-    def eval_acc(self, dataloader):
-        dist = 0
-        on = 0
-        cnt = 0
-        self.model.eval()
-        with torch.no_grad():
-            for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
-                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
-                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
-                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
-                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
-                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                for i in range(len(chosen_reward)):
-                    cnt += 1
-                    if chosen_reward[i] > reject_reward[i]:
-                        on += 1
-                dist += (chosen_reward - reject_reward).mean().item()
-        dist_mean = dist / len(dataloader)
-        acc = on / cnt
-        self.model.train()
-        return dist_mean, acc
-
-    def fit(self):
-        time = datetime.now()
-        epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-        for epoch in range(self.max_epochs):
-            step_bar = tqdm(range(self.train_dataloader.__len__()),
-                            desc='Train step of epoch %d' % epoch,
-                            disable=not is_rank_0())
-            # train
-            self.model.train()
-            cnt = 0
-            acc = 0
-            dist = 0
-            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
-                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
-                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
-                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
-                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
-                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                loss = self.loss_fn(chosen_reward, reject_reward)
-                self.strategy.backward(loss, self.model, self.optimizer)
-                self.strategy.optimizer_step(self.optimizer)
-                self.optimizer.zero_grad()
-                cnt += 1
-                if cnt == 100:
-                    self.scheduler.step()
-                    dist, acc = self.eval_acc(self.valid_dataloader)
-                    cnt = 0
-                    if is_rank_0():
-                        log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
-                                           columns=['step', 'loss', 'dist', 'acc'])
-                        log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
-                step_bar.update()
-                step_bar.set_postfix({'dist': dist, 'acc': acc})
-
-            # eval
-            dist, acc = self.eval_acc(self.eval_dataloader)
-            if is_rank_0():
-                log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
-                                   columns=['step', 'loss', 'dist', 'acc'])
-                log.to_csv('log.csv', mode='a', header=False, index=False)
-            epoch_bar.update()
-            step_bar.set_postfix({'dist': dist, 'acc': acc})
-            step_bar.close()
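The _eval pass above tracks a running reward margin (dist) and the fraction of preference pairs ranked correctly (acc). A vectorized, standalone sketch of that per-batch bookkeeping on dummy reward tensors (illustrative only, not coati code):

import torch

# Dummy per-sample rewards for a batch of 4 preference pairs.
chosen_reward = torch.tensor([1.2, 0.3, 0.8, -0.1])
reject_reward = torch.tensor([0.4, 0.5, 0.1, -0.9])

# Equivalent to one batch of the per-element loop in RewardModelTrainer._eval.
acc = (chosen_reward > reject_reward).float().mean().item()    # pairwise ranking accuracy
dist = (chosen_reward - reject_reward).mean().item()           # mean reward margin

print(f"acc={acc:.2f}, dist={dist:.2f}")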
@@ -1,21 +1,22 @@
 import time
-from typing import List
+from typing import Optional

 import torch
 import torch.distributed as dist
+import tqdm
 import wandb
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import tqdm

-from .base import Trainer
-from .callbacks import Callback
+from colossalai.logging import DistributedLogger
+
+from .base import SLTrainer
 from .strategies import ColossalAIStrategy, Strategy
 from .utils import is_rank_0, to_device


-class SFTTrainer(Trainer):
+class SFTTrainer(SLTrainer):
     """
     Trainer to use while training reward model.

@@ -23,12 +24,9 @@ class SFTTrainer(Trainer):
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
-        train_dataloader: the dataloader to use for training
-        eval_dataloader: the dataloader to use for evaluation
-        batch_size (int, defaults to 1): the batch size while training
+        lr_scheduler(_LRScheduler): the lr scheduler to use for training
         max_epochs (int, defaults to 2): the number of epochs to train
-        callbacks (List[Callback], defaults to []): the callbacks to call during training process
-        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
+        accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
     """

     def __init__(
@@ -37,95 +35,92 @@ class SFTTrainer(Trainer):
         strategy: Strategy,
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
-        train_dataloader: DataLoader,
-        eval_dataloader: DataLoader = None,
         max_epochs: int = 2,
         accumulation_steps: int = 8,
-        callbacks: List[Callback] = [],
     ) -> None:
         if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy):
             from colossalai.booster.plugin import GeminiPlugin
             assert not isinstance(strategy.plugin, GeminiPlugin), \
                 "Accumulation steps are not supported in stage 3 of ColossalAI"
-        super().__init__(strategy, max_epochs, callbacks=callbacks)
-        self.train_dataloader = train_dataloader
-        self.eval_dataloader = eval_dataloader
-        self.model = model
-        self.optimizer = optim
+
+        super().__init__(strategy, max_epochs, model, optim)

         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler

-    def fit(self, logger, use_wandb: bool = False):
+    def _train(self, epoch: int):
+        self.model.train()
+        for batch_id, batch in enumerate(self.train_dataloader):
+
+            batch = to_device(batch, torch.cuda.current_device())
+            outputs = self.model(batch["input_ids"],
+                                 attention_mask=batch["attention_mask"],
+                                 labels=batch["labels"])
+
+            loss = outputs.loss
+            loss = loss / self.accumulation_steps
+
+            self.strategy.backward(loss, self.model, self.optimizer)
+
+            self.total_loss += loss.item()
+
+            # gradient accumulation
+            if (batch_id + 1) % self.accumulation_steps == 0:
+                self.strategy.optimizer_step(self.optimizer)
+                self.optimizer.zero_grad()
+                self.scheduler.step()
+                if is_rank_0() and self.use_wandb:
+                    wandb.log({
+                        "loss": self.total_loss / self.accumulation_steps,
+                        "lr": self.scheduler.get_last_lr()[0],
+                        "epoch": epoch,
+                        "batch_id": batch_id
+                    })
+                self.total_loss = 0
+                self.step_bar.update()
+
+    def _eval(self, epoch: int):
+        if self.eval_dataloader is not None:
+            self.model.eval()
+            with torch.no_grad():
+                loss_sum, num_seen = 0, 0
+                for batch in self.eval_dataloader:
+                    batch = to_device(batch, torch.cuda.current_device())
+                    outputs = self.model(batch["input_ids"],
+                                         attention_mask=batch["attention_mask"],
+                                         labels=batch["labels"])
+                    loss = outputs.loss
+
+                    loss_sum += loss.item()
+                    num_seen += batch["input_ids"].size(0)
+
+                loss_mean = loss_sum / num_seen
+                if dist.get_rank() == 0:
+                    self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
+
+    def _before_fit(self,
+                    train_dataloader: DataLoader,
+                    eval_dataloader: Optional[DataLoader] = None,
+                    logger: Optional[DistributedLogger] = None,
+                    use_wandb: bool = False):
+        """
+        Args:
+            train_dataloader: the dataloader to use for training
+            eval_dataloader: the dataloader to use for evaluation
+        """
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
+
+        self.logger = logger
+        self.use_wandb = use_wandb
         if use_wandb:
             wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
             wandb.watch(self.model)
-        total_loss = 0
-        # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
-        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
-                        desc=f'steps',
-                        disable=not is_rank_0())
-        for epoch in range(self.max_epochs):
-
-            # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
-            # train
-            self.model.train()
-            for batch_id, batch in enumerate(self.train_dataloader):
-
-                batch = to_device(batch, torch.cuda.current_device())
-                outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
-
-                loss = outputs.loss
-
-                if loss >= 2.5 and is_rank_0():
-                    logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
-
-                loss = loss / self.accumulation_steps
-
-                self.strategy.backward(loss, self.model, self.optimizer)
-
-                total_loss += loss.item()
-
-                # gradient accumulation
-                if (batch_id + 1) % self.accumulation_steps == 0:
-                    self.strategy.optimizer_step(self.optimizer)
-                    self.optimizer.zero_grad()
-                    self.scheduler.step()
-                    if is_rank_0() and use_wandb:
-                        wandb.log({
-                            "loss": total_loss / self.accumulation_steps,
-                            "lr": self.scheduler.get_last_lr()[0],
-                            "epoch": epoch,
-                            "batch_id": batch_id
-                        })
-                    total_loss = 0
-                    step_bar.update()
-
-                # if batch_id % log_interval == 0:
-                #     logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
-                #     wandb.log({"loss": loss.item()})
-
-                # process_bar.update()
-
-            # eval
-            if self.eval_dataloader is not None:
-                self.model.eval()
-                with torch.no_grad():
-                    loss_sum = 0
-                    num_seen = 0
-                    for batch in self.eval_dataloader:
-                        batch = to_device(batch, torch.cuda.current_device())
-                        outputs = self.model(batch["input_ids"],
-                                             attention_mask=batch["attention_mask"],
-                                             labels=batch["labels"])
-                        loss = outputs.loss
-
-                        loss_sum += loss.item()
-                        num_seen += batch["input_ids"].size(0)
-
-                    loss_mean = loss_sum / num_seen
-                    if dist.get_rank() == 0:
-                        logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
-
-            # epoch_bar.update()
+        self.total_loss = 0
+        self.no_epoch_bar = True
+        self.step_bar = tqdm.trange(
+            len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
+            desc=f'steps',
+            disable=not is_rank_0()
+        )
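SFTTrainer._train above relies on plain gradient accumulation: each mini-batch loss is scaled by 1/accumulation_steps and the optimizer only steps every accumulation_steps batches. A minimal standalone PyTorch sketch of the same pattern, using a toy model and random data rather than coati objects:

import torch
import torch.nn as nn

accumulation_steps = 4
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
data = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(8)]   # 8 toy mini-batches

optimizer.zero_grad()
for batch_id, (x, y) in enumerate(data):
    loss = nn.functional.mse_loss(model(x), y)
    loss = loss / accumulation_steps          # scale so accumulated grads match one large batch
    loss.backward()                           # strategy.backward(...) in the real trainer
    if (batch_id + 1) % accumulation_steps == 0:
        optimizer.step()                      # strategy.optimizer_step(...) in the real trainer
        optimizer.zero_grad()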
@@ -1,4 +1,3 @@
-import functools
 import warnings
 from typing import Optional

@@ -103,7 +102,7 @@ class ColossalAIStrategy(DDPStrategy):
         # NOTE: dist should be initialized before calling get_current_device()
         if stage == 3:
             plugin_initializer = lambda: GeminiPlugin(
                 # gemini_config
                 device=get_current_device(),
                 placement_policy=placement_policy,
                 precision=precision,
@@ -113,20 +112,20 @@ class ColossalAIStrategy(DDPStrategy):
                 search_range_m=search_range_m,
                 hidden_dim=hidden_dim,
                 min_chunk_size_m=min_chunk_size_m,
                 # zero_optim_config
                 gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                 # optim_config
                 **optim_kwargs)
         else:
             plugin_initializer = lambda: LowLevelZeroPlugin(
                 # zero_config
                 stage=stage,
                 precision=precision,
                 # zero_optim_config
                 reduce_bucket_size_in_m=reduce_bucket_size,
                 overlap_communication=overlap_communication,
                 cpu_offload=(placement_policy == 'cpu'),
                 # optim_config
                 **optim_kwargs)

         super().__init__(seed, plugin_initializer)
@@ -3,6 +3,33 @@ from typing import Any
 import torch
 import torch.distributed as dist
 from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader
+
+
+class CycledDataLoader:
+    """
+    Why do we need this class?
+    In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain.
+    However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...)
+    NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
+    """
+
+    def __init__(self,
+                 dataloader: DataLoader,
+                 ) -> None:
+        self.dataloader = dataloader
+
+        self.count = 0
+        self.dataloader_iter = iter(dataloader)
+
+    def next(self):
+        self.count += 1
+        try:
+            return next(self.dataloader_iter)
+        except StopIteration:
+            self.count = 0
+            self.dataloader_iter = iter(self.dataloader)
+            return next(self.dataloader_iter)
+
+
 def is_rank_0() -> bool:
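A short usage sketch for the CycledDataLoader added above. It assumes the class is importable from coati.trainer.utils, as the "from .utils import CycledDataLoader" line in the trainer base module suggests, and that the coati package from this repo is installed:

from torch.utils.data import DataLoader

from coati.trainer.utils import CycledDataLoader

prompts = list(range(10))                      # toy dataset standing in for prompt tensors
cycled = CycledDataLoader(DataLoader(prompts, batch_size=4, shuffle=False))

for _ in range(5):                             # more pulls than there are batches
    batch = cycled.next()                      # wraps around after the 3rd batch
    print(batch)

Keeping one iterator alive avoids re-spawning dataloader workers on every next(iter(dataloader)) call, which is the inefficiency the docstring calls out.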
@@ -171,9 +171,8 @@ Pretrain dataset: the pretrain dataset including the instruction and correspondi
 - --pretrain_dataset: path of the ptx dataset, type=str, default=None
 - --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
 - --num_episodes: num of episodes for training, type=int, default=10
-- --max_epochs: max epochs for training in one episode, type=int, default=5
-- --max_timesteps: max episodes in one batch, type=int, default=10
-- --update_timesteps: timesteps to update, type=int, default=10
+- --num_update_steps: number of steps to update policy per episode, type=int
+- --num_collect_steps: number of steps to collect experience per episode, type=int
 - --train_batch_size: batch size while training, type=int, default=8
 - --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
 - --experience_batch_size: batch size to make experience, type=int, default=8
@@ -171,7 +171,6 @@ def main(args):
                          critic_optim,
                          kl_coef=args.kl_coef,
                          ptx_coef=args.ptx_coef,
-                         max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
                          experience_batch_size=args.experience_batch_size,
                          tokenizer=tokenize_fn,
@@ -186,8 +185,8 @@ def main(args):
     trainer.fit(prompt_dataloader=prompt_dataloader,
                 pretrain_dataloader=pretrain_dataloader,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_update_steps=args.num_update_steps,
+                num_collect_steps=args.num_collect_steps)

     # save model checkpoint after fitting
     trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
@@ -215,9 +214,8 @@ if __name__ == '__main__':
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
+    parser.add_argument('--num_collect_steps', type=int, default=10)
+    parser.add_argument('--num_update_steps', type=int, default=5)
     parser.add_argument('--train_batch_size', type=int, default=2)
     parser.add_argument('--ptx_batch_size', type=int, default=1)
     parser.add_argument('--experience_batch_size', type=int, default=8)
@@ -63,8 +63,8 @@ for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
         torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
             --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
             --strategy $strategy --model $model \
-            --num_episodes 1 --max_timesteps 2 \
-            --update_timesteps 2 --max_epochs 1 --train_batch_size 2
+            --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
+            --train_batch_size 2
     done
 done

@@ -149,8 +149,8 @@ rm -rf ${BASE}/rm_ckpt.pt

 torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
     --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
-    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
-    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --strategy colossalai_zero2 --num_episodes 1 \
+    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
     --pretrain 'facebook/opt-350m' --model opt \
     --rm_pretrain 'facebook/opt-350m' \
     --rm_path ${BASE}/rm_ckpt_opt.pt \
@ -159,8 +159,8 @@ rm -rf ${BASE}/rm_ckpt_opt.pt
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
--strategy colossalai_zero2 --num_episodes 1 \
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
|
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||||
--pretrain 'gpt2' --model gpt2 \
|
--pretrain 'gpt2' --model gpt2 \
|
||||||
--rm_pretrain 'gpt2' \
|
--rm_pretrain 'gpt2' \
|
||||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||||
|
@ -168,8 +168,8 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
--strategy colossalai_gemini --num_episodes 1 \
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
|
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||||
--pretrain 'gpt2' --model gpt2 \
|
--pretrain 'gpt2' --model gpt2 \
|
||||||
--rm_pretrain 'gpt2' \
|
--rm_pretrain 'gpt2' \
|
||||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||||
|
|
|

@@ -177,7 +177,6 @@ def main(args):
                          critic_optim,
                          kl_coef=args.kl_coef,
                          ptx_coef=args.ptx_coef,
-                         max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
                          max_length=args.max_seq_len,
                          use_cache=True,
@@ -192,8 +191,8 @@ def main(args):
     trainer.fit(prompt_dataloader=prompt_dataloader,
                 pretrain_dataloader=pretrain_dataloader,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_collect_steps=args.num_collect_steps,
+                num_update_steps=args.num_update_steps)
 
     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
@@ -220,9 +219,8 @@ if __name__ == '__main__':
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
+    parser.add_argument('--num_collect_steps', type=int, default=10)
+    parser.add_argument('--num_update_steps', type=int, default=5)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--ptx_batch_size', type=int, default=1)
     parser.add_argument('--experience_batch_size', type=int, default=8)
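In train_prompts.py the same refactor is visible end to end: max_epochs leaves the PPOTrainer constructor, the prompt and pretrain dataloaders move from the constructor to fit(), and fit() takes the renamed num_collect_steps / num_update_steps arguments. A self-contained stand-in (not coati's PPOTrainer) that mirrors just this interface change:

    from typing import Iterable, Optional

    class OnPolicyTrainerSketch:
        """Stand-in with the post-refactor call shape: no max_epochs and no
        dataloaders in __init__; both the data and the schedule go to fit()."""

        def __init__(self, train_batch_size: int = 8) -> None:
            self.train_batch_size = train_batch_size

        def fit(self,
                prompt_dataloader: Iterable,
                pretrain_dataloader: Optional[Iterable],
                num_episodes: int,
                num_collect_steps: int,
                num_update_steps: int) -> None:
            for _ in range(num_episodes):
                for _, _prompts in zip(range(num_collect_steps), prompt_dataloader):
                    pass          # collect experience from a batch of prompts
                for _ in range(num_update_steps):
                    pass          # PPO update (plus ptx loss when pretrain data is given)

    # Usage mirroring the call site in the hunk above:
    trainer = OnPolicyTrainerSketch(train_batch_size=2)
    trainer.fit(prompt_dataloader=[["a prompt"]], pretrain_dataloader=None,
                num_episodes=1, num_collect_steps=2, num_update_steps=1)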

@@ -1,13 +1,13 @@
 set_n_least_used_CUDA_VISIBLE_DEVICES() {
     local n=${1:-"9999"}
     echo "GPU Memory Usage:"
-    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
-        | tail -n +2 \
-        | nl -v 0 \
-        | tee /dev/tty \
-        | sort -g -k 2 \
-        | awk '{print $1}' \
-        | head -n $n)
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+        tail -n +2 |
+        nl -v 0 |
+        tee /dev/tty |
+        sort -g -k 2 |
+        awk '{print $1}' |
+        head -n $n)
     export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
     echo "Now CUDA_VISIBLE_DEVICES is set to:"
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
 # torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
 
-torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_prompts.py \
+    --pretrain_dataset /path/to/data.json \
+    --prompt_dataset /path/to/data.json \
+    --strategy colossalai_zero2 \
+    --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
+    --train_batch_size 2

@@ -178,12 +178,11 @@ def train(args):
                      optim=optim,
                      lr_scheduler=lr_scheduler,
                      loss_fn=loss_fn,
-                     train_dataloader=train_dataloader,
-                     valid_dataloader=valid_dataloader,
-                     eval_dataloader=eval_dataloader,
                      max_epochs=args.max_epochs)
 
-    trainer.fit()
+    trainer.fit(train_dataloader=train_dataloader,
+                valid_dataloader=valid_dataloader,
+                eval_dataloader=eval_dataloader)
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
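For the supervised-side trainers the refactor is the mirror image: max_epochs stays on the constructor, while the train/valid/eval dataloaders move from the constructor to fit(). A minimal, self-contained sketch of that shape (a stand-in, not the coati reward-model trainer):

    from typing import Iterable, Optional

    class SLTrainerSketch:
        """Stand-in for the post-refactor supervised trainer interface."""

        def __init__(self, max_epochs: int) -> None:
            self.max_epochs = max_epochs          # epochs configured once, up front

        def fit(self,
                train_dataloader: Iterable,
                valid_dataloader: Optional[Iterable] = None,
                eval_dataloader: Optional[Iterable] = None) -> None:
            for _ in range(self.max_epochs):
                for _batch in train_dataloader:
                    pass                          # forward/backward/step on one batch
                if valid_dataloader is not None:
                    for _batch in valid_dataloader:
                        pass                      # validation pass per epoch

    SLTrainerSketch(max_epochs=1).fit(train_dataloader=[1, 2], valid_dataloader=[3])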

@@ -170,12 +170,13 @@ def train(args):
                      strategy=strategy,
                      optim=optim,
                      lr_scheduler=lr_scheduler,
-                     train_dataloader=train_dataloader,
-                     eval_dataloader=eval_dataloader,
                      max_epochs=args.max_epochs,
                      accumulation_steps=args.accumulation_steps)
 
-    trainer.fit(logger=logger, use_wandb=args.use_wandb)
+    trainer.fit(train_dataloader=train_dataloader,
+                eval_dataloader=eval_dataloader,
+                logger=logger,
+                use_wandb=args.use_wandb)
 
     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
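The SFT script follows the same pattern with one extra constructor knob: max_epochs and accumulation_steps stay on the trainer, while the train/eval dataloaders (plus logger and use_wandb) are passed to fit(). As a sketch of how those two constructor arguments typically interact, here is a plain epoch loop with gradient accumulation; the names below are generic torch objects, not the coati SFT trainer:

    def train_with_accumulation(model, optimizer, loss_fn, train_dataloader,
                                max_epochs: int, accumulation_steps: int) -> None:
        for _ in range(max_epochs):
            optimizer.zero_grad()
            for step, (inputs, targets) in enumerate(train_dataloader, start=1):
                loss = loss_fn(model(inputs), targets) / accumulation_steps  # scale so accumulated grads match a full batch
                loss.backward()
                if step % accumulation_steps == 0:    # optimizer update every accumulation_steps micro-batches
                    optimizer.step()
                    optimizer.zero_grad()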