[chat] refactor model save/load logic (#3654)

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test
Hongxin Liu 2023-04-27 18:41:49 +08:00 committed by GitHub
parent 6ef7011462
commit 842768a174
14 changed files with 155 additions and 181 deletions

View File

@@ -243,6 +243,7 @@ from coati.trainer import SFTTrainer
 model = LlamaLM(pretrained=args.pretrain)
 tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+(model, optim) = strategy.prepare((model, optim))
 trainer = SFTTrainer(model=model,
                      strategy=strategy,
                      optim=optim,
@@ -254,7 +255,11 @@ trainer = SFTTrainer(model=model,
                     )
 trainer.fit()
-trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+# this saves in pytorch format
+strategy.save_model(model, args.save_path, only_rank0=True)
+# this saves in HF format. ColossalAI strategy with stage-3 doesn't support this method
+strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer)
 ```
 </details>
@@ -263,7 +268,7 @@ trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
 Here are some examples that can allow you to train a 7B model on a single or multiple consumer-grade GPUs.
-If you only have a single 24G GPU, you can use the following script. `batch_size` and `lora_rank` are the most important parameters to successfully train the model.
+If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model.
 ```
 torchrun --standalone --nproc_per_node=1 train_sft.py \
     --pretrain "/path/to/LLaMa-7B/" \
@@ -278,6 +283,7 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \
     --max_datasets_size 512 \
     --max_epochs 1 \
     --lora_rank 16 \
+    --grad_checkpoint
 ```
 `colossalai_gemini` strategy can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. You can use the following script.
@@ -294,6 +300,7 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \
     --lr 2e-5 \
     --max_datasets_size 512 \
     --max_epochs 1 \
+    --grad_checkpoint
 ```
 If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows.
@@ -310,6 +317,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
     --lr 2e-5 \
     --max_datasets_size 512 \
     --max_epochs 1 \
+    --grad_checkpoint
 ```
 </details>
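
The two save calls added to the snippet above produce different artifacts: `save_model` writes a single PyTorch state-dict file, while `save_pretrained` writes a HuggingFace-format directory. A minimal loading sketch, assuming the same `strategy`, `LlamaLM`, `args` and paths as that snippet:

```python
# Loading back the two kinds of checkpoints produced above (names reused from the snippet).
from transformers import AutoModelForCausalLM, AutoTokenizer

# 1) State dict written by strategy.save_model(): rebuild the model, then load the weights.
model = LlamaLM(pretrained=args.pretrain)
strategy.load_model(model, args.save_path)   # torch.load + load_state_dict on the unwrapped model

# 2) Directory written by strategy.save_pretrained(): plain transformers loading.
hf_model = AutoModelForCausalLM.from_pretrained(args.save_path)
hf_tokenizer = AutoTokenizer.from_pretrained(args.save_path)
```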

View File

@@ -1,5 +1,24 @@
+import torch.nn as nn
+
 from .actor import Actor
 from .critic import Critic
 from .reward_model import RewardModel
-__all__ = ['Actor', 'Critic', 'RewardModel']
+
+
+def get_base_model(model: nn.Module) -> nn.Module:
+    """Get the base model of our wrapper classes.
+    For Actor, its base model is ``actor.model`` and it is usually a ``transformers.PreTrainedModel``.
+    For Critic and RewardModel, the base model is the model itself.
+
+    Args:
+        model (nn.Module): model to get base model from
+
+    Returns:
+        nn.Module: the base model
+    """
+    if isinstance(model, Actor):
+        return model.get_base_model()
+    return model
+
+
+__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
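
A small usage sketch of the new helper; `GPTActor` and `GPTRM` are assumed to come from coati's model zoo and are not part of this diff, only `get_base_model`'s behaviour is:

```python
# get_base_model strips the Actor wrapper but leaves other modules untouched.
from coati.models.base import get_base_model
from coati.models.gpt import GPTActor, GPTRM   # assumed model classes, for illustration only

actor = GPTActor()                             # Actor wrapping a transformers model
reward_model = GPTRM()                         # RewardModel is used as-is

assert get_base_model(actor) is actor.get_base_model()   # the unwrapped HF model
assert get_base_model(reward_model) is reward_model      # returned unchanged
```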

View File

@@ -199,15 +199,9 @@ class PPOTrainer(Trainer):
         return {'reward': experience.reward.mean().item()}

-    def save_model(self,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
-

 def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy._unwrap_actor(actor)
+    origin_model = strategy.unwrap_model(actor)
     new_kwargs = {**generate_kwargs}
     # use huggingface models method directly
     if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
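
Renaming the private `_unwrap_actor` to the public `unwrap_model` makes this pattern usable outside the trainer as well: whatever wrapper the strategy installed, `unwrap_model` hands back the underlying HuggingFace model so its generation helpers can be reused. A sketch, assuming `GPTActor` from coati's model zoo:

```python
# Reaching HuggingFace generation helpers through a (possibly wrapped) actor.
from coati.models.gpt import GPTActor                  # assumed model class
from coati.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()
actor = GPTActor()                                     # in training this would pass through strategy.prepare()

origin_model = strategy.unwrap_model(actor)            # the plain transformers model, wrapper or not
generate_kwargs = {}
if hasattr(origin_model, 'prepare_inputs_for_generation'):
    generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
```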

View File

@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Optional, List
+from typing import List, Optional

 import pandas as pd
 import torch
@@ -9,8 +9,8 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler
 from tqdm import tqdm
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

-from .callbacks import Callback
 from .base import Trainer
+from .callbacks import Callback
 from .strategies import Strategy
 from .utils import is_rank_0
@@ -41,20 +41,18 @@ class RewardModelTrainer(Trainer):
                  train_dataloader: DataLoader,
                  valid_dataloader: DataLoader,
                  eval_dataloader: DataLoader,
-                 batch_size: int = 1,
                  max_epochs: int = 1,
                  callbacks: List[Callback] = [],
                 ) -> None:
         super().__init__(strategy, max_epochs, callbacks=callbacks)
-        train_sampler = None
         self.train_dataloader = train_dataloader
         self.valid_dataloader = valid_dataloader
         self.eval_dataloader = eval_dataloader
-        self.model = strategy.setup_model(model)
+        self.model = model
         self.loss_fn = loss_fn
-        self.optimizer = strategy.setup_optimizer(optim, self.model)
+        self.optimizer = optim
         self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)

     def eval_acc(self, dataloader):
@@ -123,9 +121,3 @@ class RewardModelTrainer(Trainer):
                 epoch_bar.update()
             step_bar.set_postfix({'dist': dist, 'acc': acc})
         step_bar.close()
-
-    def save_model(self,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)

View File

@@ -49,8 +49,8 @@ class SFTTrainer(Trainer):
         super().__init__(strategy, max_epochs, callbacks=callbacks)
         self.train_dataloader = train_dataloader
         self.eval_dataloader = eval_dataloader
-        (self.model, self.optimizer) = strategy.prepare((model, optim))
+        self.model = model
+        self.optimizer = optim

         self.accimulation_steps = accimulation_steps
         num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
@@ -133,9 +133,3 @@ class SFTTrainer(Trainer):
                 logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
             # epoch_bar.update()
-
-    def save_model(self,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)

View File

@@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from typing import Any, List, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
-from coati.models.base import Actor, Critic, RewardModel
+from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -72,8 +71,8 @@ class Strategy(ABC):
         def prepare_model(model: nn.Module):
             if isinstance(model, Actor):
-                return Actor(self.setup_model(self._unwrap_model(model)))
-            return self.setup_model(self._unwrap_model(model))
+                return Actor(self.setup_model(model.get_base_model()))
+            return self.setup_model(model)

         rets = []
         for arg in models_or_model_optim_pairs:
@@ -81,7 +80,7 @@ class Strategy(ABC):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
                 model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
                 rets.append(prepare_model(arg))
@@ -93,31 +92,20 @@ class Strategy(ABC):
         return rets

     @staticmethod
-    def _unwrap_model(model: nn.Module) -> nn.Module:
-        """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
-        For Actor, it will unwrap `actor.model`.
-
-        Args:
-            model (nn.Module): an actor or a critic
-        """
-        if isinstance(model, Actor):
-            return model.model
-        return model
-
-    @staticmethod
-    def _unwrap_actor(actor: Actor) -> nn.Module:
-        """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
-
-        Args:
-            actor (Actor): a wrapped actor
-        """
-        return Strategy._unwrap_model(actor)
+    def unwrap_model(model: nn.Module) -> nn.Module:
+        """Get the unwrapped model from a wrapped model. Useful for getting the original huggingface model.
+
+        Args:
+            model (nn.Module): the model to unwrap
+
+        Returns:
+            nn.Module: the original model (usually a huggingface model)
+        """
+        return get_base_model(model)

     @abstractmethod
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         pass

     @abstractmethod
@@ -134,3 +122,11 @@ class Strategy(ABC):
     def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, 1, 0)
+
+    @abstractmethod
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        pass
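
With this split, persistence is owned by the strategy rather than the trainers: `save_model` always writes a raw PyTorch state dict, while `save_pretrained` exports a HuggingFace-style directory. A caller-side sketch, assuming `GPTActor` from coati's model zoo and a plain Adam optimizer (neither is part of this diff):

```python
# Caller-side view of the refactored Strategy API.
import torch
from coati.models.gpt import GPTActor              # assumed model class
from coati.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()
model = GPTActor()
optim = torch.optim.Adam(model.parameters(), lr=5e-6)
(model, optim) = strategy.prepare((model, optim))
# ... training loop ...

# PyTorch format: one state-dict file, supported by every strategy.
strategy.save_model(model, 'actor.pt', only_rank0=True)

# HuggingFace format: config + weights (+ tokenizer files if one is passed).
# ColossalAIStrategy at stage 3 rejects this call (see its override below).
strategy.save_pretrained(model, 'actor_hf', only_rank0=True)
```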

View File

@@ -5,10 +5,8 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import Actor, RewardModel
-from coati.models.lora import LoraLinear
+from coati.models.base import get_base_model
 from torch.optim import Optimizer
-from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 import colossalai
@@ -17,9 +15,7 @@ from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini.utils import get_static_torch_model

-from .base import Strategy
 from .ddp import DDPStrategy

 logger = get_dist_logger(__name__)
@@ -141,7 +137,7 @@ class ColossalAIStrategy(DDPStrategy):
         model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)

         if self.stage != 3 and self.precision == 'fp16':
-            model = model.half()
+            model = model.half().cuda()
         return model

     def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
@@ -154,47 +150,39 @@ class ColossalAIStrategy(DDPStrategy):
     def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()

-    @staticmethod
-    def _unwrap_actor(actor: Actor) -> nn.Module:
-        model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
-        if isinstance(model, ZeroDDP):
-            return model.module
-        return model
-
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = True,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        if only_rank0 and dist.get_rank() != 0:
-            return None
-        unwrapped_model = self._unwrap_model(model)
-        # TODO : better way to get torch model from gemini model
-        # to get torch model from gemini model
-        if isinstance(unwrapped_model, RewardModel):
-            state_dict = unwrapped_model.state_dict()
-            if only_rank0 and dist.get_rank() != 0:
-                return
-            torch.save(state_dict, path)
-        else:
-            try:
-                logger.info(f'Saving model to {path}', ranks=[0])
-                unwrapped_model.save_pretrained(path)
-                logger.info(f'Model saved to {path} Successfully', ranks=[0])
-                if tokenizer is not None:
-                    logger.info(f'Saving tokenizer to {path}', ranks=[0])
-                    tokenizer.save_pretrained(path)
-                    logger.info(f'Tokenizer saved to {path} Successfully', ranks=[0])
-            except AttributeError:
-                state_dict = unwrapped_model.state_dict()
-                if only_rank0 and dist.get_rank() != 0:
-                    return
-                torch.save(state_dict, path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
+        if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
+            return
+        base_model = get_base_model(model)
+        if self.stage == 3:
+            assert isinstance(base_model, ZeroDDP)
+            # for stage 3, state_dict() method should be called on every rank
+            state_dict = base_model.state_dict(only_rank_0=only_rank0)
+        else:
+            # only_rank0 is false or rank == 0
+            state_dict = base_model.state_dict()
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        torch.save(state_dict, path)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0:
             raise RuntimeError(
                 f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
         torch.save(optimizer.state_dict(), path)
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
+        if self.stage == 3:
+            assert isinstance(base_model, ZeroDDP)
+            return base_model.module
+        return base_model
+
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if self.stage == 3:
+            raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
+        super().save_pretrained(model, path, only_rank0, tokenizer)
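
One practical consequence of the stage-3 branch above: `state_dict(only_rank_0=...)` gathers the sharded weights collectively, so `save_model` has to run on every rank even though only rank 0 writes the file. A hedged calling-convention sketch; the constructor arguments and `GPTActor` are assumptions, only the save pattern comes from this diff:

```python
# Saving under ColossalAIStrategy with ZeRO/Gemini stage 3 (run via torchrun / colossalai launch).
from coati.models.gpt import GPTActor                  # assumed model class
from coati.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')   # args are assumptions
with strategy.model_init_context():                    # stage 3 builds the model in the strategy's init context
    actor = GPTActor()
actor = strategy.prepare(actor)                        # wrapped in ZeroDDP at stage 3

# Stage-3 state_dict() is a collective gather: every rank must call save_model();
# with only_rank0=True, only rank 0 actually writes the file afterwards.
strategy.save_model(actor, 'actor.pt', only_rank0=True)

# strategy.save_pretrained(actor, 'actor_hf')          # would raise RuntimeError at stage 3
```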

View File

@@ -6,14 +6,12 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor, RewardModel
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

-from .base import Strategy
 from .naive import NaiveStrategy
 from .sampler import DistributedSampler
@@ -68,34 +66,10 @@ class DDPStrategy(NaiveStrategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)

-    @staticmethod
-    def _unwrap_actor(actor: Actor) -> nn.Module:
-        model: DDP = Strategy._unwrap_actor(actor)
-        return model.module
-
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         if only_rank0 and dist.get_rank() != 0:
-            return None
-
-        if isinstance(model, RewardModel):
-            state_dict = model.state_dict()
-            if only_rank0 and dist.get_rank() != 0:
-                return
-            torch.save(state_dict, path)
-        else:
-            try:
-                model.save_pretrained(path)
-                if tokenizer is not None:
-                    tokenizer.save_pretrained(path)
-            except AttributeError:
-                state_dict = model.state_dict()
-                if only_rank0 and dist.get_rank() != 0:
-                    return
-                torch.save(state_dict, path)
+            return
+        super().save_model(model, path, only_rank0)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
@@ -104,3 +78,16 @@ class DDPStrategy(NaiveStrategy):
     def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        base_model: DDP = super().unwrap_model(model)
+        return base_model.module
+
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return
+        super().save_pretrained(model, path, only_rank0, tokenizer)
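
The DDP override composes with the base helper: `get_base_model` first strips the `Actor` wrapper, then `.module` strips the DDP wrapper. A sketch of the resulting chain; `GPTActor` and the torchrun launch environment are assumptions:

```python
# Unwrap chain under DDPStrategy: Actor -> DDP wrapper -> original transformers model.
# Run under torchrun; GPTActor is an assumed model class, not part of this diff.
from coati.models.gpt import GPTActor
from coati.trainer.strategies import DDPStrategy

strategy = DDPStrategy()
actor = strategy.prepare(GPTActor())        # prepare() wraps the base model in DDP, then in Actor again

ddp_model = actor.get_base_model()          # the DDP wrapper installed by prepare()
hf_model = strategy.unwrap_model(actor)     # == ddp_model.module, the original HuggingFace model
assert hf_model is ddp_model.module
```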

View File

@@ -3,10 +3,11 @@ from typing import Any, Optional
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import RewardModel
+from coati.models.base import get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from .base import Strategy
@@ -40,27 +41,15 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)

-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = False,
-                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
-        if isinstance(model, RewardModel):
-            state_dict = model.state_dict()
-            torch.save(state_dict, path)
-        else:
-            try:
-                model.save_pretrained(path)
-                if tokenizer is not None:
-                    tokenizer.save_pretrained(path)
-            except AttributeError:
-                state_dict = model.state_dict()
-                torch.save(state_dict, path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
+        base_model = get_base_model(model)
+        state_dict = base_model.state_dict()
+        torch.save(state_dict, path)

     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        unwrapped_model = self._unwrap_model(model)
+        base_model = get_base_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        unwrapped_model.load_state_dict(state_dict, strict=strict)
+        base_model.load_state_dict(state_dict, strict=strict)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
@@ -68,3 +57,14 @@ class NaiveStrategy(Strategy):
     def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
         state_dict = torch.load(path, map_location=map_location)
         optimizer.load_state_dict(state_dict)
+
+    def save_pretrained(self,
+                        model: nn.Module,
+                        path: str,
+                        only_rank0: bool = True,
+                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        unwrapped_model = self.unwrap_model(model)
+        assert isinstance(unwrapped_model, PreTrainedModel)
+        unwrapped_model.save_pretrained(path)
+        if tokenizer is not None:
+            tokenizer.save_pretrained(path)
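
The naive strategy makes the two persistence paths easiest to see end to end. A minimal single-process roundtrip sketch; `GPTActor` and the `'gpt2'` checkpoint are assumptions used only for illustration:

```python
# Single-process roundtrip: state-dict save/load plus HuggingFace export.
from transformers import AutoTokenizer
from coati.models.gpt import GPTActor            # assumed model class
from coati.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()
actor = GPTActor(pretrained='gpt2')              # real training would pass this through strategy.prepare()
tokenizer = AutoTokenizer.from_pretrained('gpt2')

strategy.save_model(actor, 'actor.pt')           # torch.save of the base model's state dict
strategy.load_model(actor, 'actor.pt')           # load_state_dict on the same base model

# save_pretrained asserts the unwrapped model is a transformers.PreTrainedModel,
# then writes config/weights plus the tokenizer files.
strategy.save_pretrained(actor, 'actor_hf', tokenizer=tokenizer)
```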

View File

@@ -66,6 +66,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
     --lr 2e-5 \
     --max_datasets_size 512 \
     --max_epochs 1 \
+    --grad_checkpoint
 ```

 ### Arg List
 - --strategy: the strategy using for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive'
@@ -78,6 +79,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
 - --batch_size: batch size while training, type=int, default=4
 - --lora_rank: low-rank adaptation matrices rank, type=int, default=0
 - --log_interval: how many steps to log, type=int, default=100
+- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False

 ## Stage2 - Training reward model
@@ -254,29 +256,6 @@ class CoatiActor(Actor):
         super().__init__(model, lora_rank, lora_train_bias)
 ```

-### LM model
-```
-from ..base import LM
-from transformers.models.coati import CoatiModel
-
-class GPTLM(LM):
-
-    def __init__(self,
-                 pretrained: Optional[str] = None,
-                 checkpoint: bool = False,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
-        if pretrained is not None:
-            model = CoatiModel.from_pretrained(pretrained)
-        else:
-            model = build_model() # load your own model if it is not support in transformers
-        super().__init__(model, lora_rank, lora_train_bias)
-
-    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
-        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
-```
-
 ### Reward model
 ```
 from ..base import RewardModel

View File

@@ -194,7 +194,7 @@ def main(args):
                 update_timesteps=args.update_timesteps)

     # save model checkpoint after fitting
-    trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
+    strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(actor_optim,

View File

@@ -124,11 +124,23 @@ def train(args):
         raise ValueError(f'Unsupported dataset "{args.dataset}"')

     if dist.is_initialized() and dist.get_world_size() > 1:
-        train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+        train_sampler = DistributedSampler(train_dataset,
+                                           shuffle=True,
+                                           seed=42,
+                                           drop_last=True,
+                                           rank=dist.get_rank(),
                                            num_replicas=dist.get_world_size())
-        valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+        valid_sampler = DistributedSampler(valid_dataset,
+                                           shuffle=True,
+                                           seed=42,
+                                           drop_last=True,
+                                           rank=dist.get_rank(),
                                            num_replicas=dist.get_world_size())
-        eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
+        eval_sampler = DistributedSampler(eval_dataset,
+                                          shuffle=True,
+                                          seed=42,
+                                          drop_last=True,
+                                          rank=dist.get_rank(),
                                           num_replicas=dist.get_world_size())
     else:
         train_sampler = None
@@ -141,13 +153,19 @@ def train(args):
                                   batch_size=args.batch_size,
                                   pin_memory=True)

-    valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
+    valid_dataloader = DataLoader(valid_dataset,
+                                  shuffle=(valid_sampler is None),
                                   sampler=valid_sampler,
-                                  batch_size=args.batch_size, pin_memory=True)
+                                  batch_size=args.batch_size,
+                                  pin_memory=True)

-    eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
-                                 sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
+    eval_dataloader = DataLoader(eval_dataset,
+                                 shuffle=(eval_sampler is None),
+                                 sampler=eval_sampler,
+                                 batch_size=args.batch_size,
+                                 pin_memory=True)

+    (model, optim) = strategy.prepare((model, optim))
     trainer = RewardModelTrainer(model=model,
                                  strategy=strategy,
                                  optim=optim,
@@ -155,12 +173,11 @@ def train(args):
                                  train_dataloader=train_dataloader,
                                  valid_dataloader=valid_dataloader,
                                  eval_dataloader=eval_dataloader,
-                                 batch_size=args.batch_size,
                                  max_epochs=args.max_epochs)

     trainer.fit()
     # save model checkpoint after fitting on only rank0
-    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+    strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(trainer.optimizer,
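
Because the stage-2 checkpoint is now written by `strategy.save_model`, it is a plain PyTorch state dict of the reward model; reloading it later (for example before PPO training) is a `torch.load` plus `load_state_dict`. A sketch assuming the `BLOOMRM` class from coati's model zoo and a placeholder checkpoint path:

```python
# Reloading the reward-model checkpoint produced by strategy.save_model() above.
import torch
from coati.models.bloom import BLOOMRM     # assumed model class; other coati reward models work the same way

reward_model = BLOOMRM(pretrained='bigscience/bloom-560m')
state_dict = torch.load('rm_ckpt.pt', map_location='cpu')   # placeholder path
reward_model.load_state_dict(state_dict)
reward_model.eval()
```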

View File

@@ -152,6 +152,7 @@ def train(args):
     else:
         eval_dataloader = None

+    (model, optim) = strategy.prepare((model, optim))
     trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
@@ -163,7 +164,7 @@ def train(args):
     trainer.fit(logger=logger, use_wandb=args.use_wandb)

     # save model checkpoint after fitting on only rank0
-    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
+    strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(trainer.optimizer,
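
Since the SFT script now exports a HuggingFace-format directory via `save_pretrained`, the stage-1 result can be reloaded with plain `transformers` APIs or passed as the `--pretrain` argument of the later stages. A sketch with a placeholder path:

```python
# Reloading the SFT output written by strategy.save_pretrained() above.
from transformers import AutoModelForCausalLM, AutoTokenizer

sft_dir = 'sft_output'    # placeholder for args.save_path
model = AutoModelForCausalLM.from_pretrained(sft_dir)
tokenizer = AutoTokenizer.from_pretrained(sft_dir)
```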

View File

@@ -82,7 +82,6 @@ def run_dist(rank, world_size, port, strategy):
     run_test_checkpoint(strategy)

-@pytest.mark.skip('temporarily skip until refactor strategy unwrap')
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])