[chat] refactor model save/load logic (#3654)

* [chat] strategy refactor unwrap model

* [chat] strategy refactor save model

* [chat] add docstr

* [chat] refactor trainer save model

* [chat] fix strategy typing

* [chat] refactor trainer save model

* [chat] update readme

* [chat] fix unit test
Hongxin Liu 2023-04-27 18:41:49 +08:00 committed by GitHub
parent 6ef7011462
commit 842768a174
14 changed files with 155 additions and 181 deletions

View File

@ -243,6 +243,7 @@ from coati.trainer import SFTTrainer
model = LlamaLM(pretrained=args.pretrain)
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
(model, optim) = strategy.prepare((model, optim))
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
@ -254,7 +255,11 @@ trainer = SFTTrainer(model=model,
)
trainer.fit()
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# this saves in pytorch format
strategy.save_model(model, args.save_path, only_rank0=True)
# this saves in HF format. ColossalAI strategy with stage-3 doesn't support this method
strategy.save_pretrained(model, args.save_path, only_rank0=True, tokenizer=tokenizer)
```
</details>
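For reference, loading the two formats back might look like the sketch below (hedged: `output.pt` and `output_hf/` are hypothetical paths, and loading the pytorch-format state dict assumes its keys match the bare transformers model):

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# HF format written by strategy.save_pretrained(): a regular transformers checkpoint directory
model = AutoModelForCausalLM.from_pretrained('output_hf')
tokenizer = AutoTokenizer.from_pretrained('output_hf')

# pytorch format written by strategy.save_model(): a plain state dict of the unwrapped model
state_dict = torch.load('output.pt', map_location='cpu')
model.load_state_dict(state_dict)
```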
@ -263,7 +268,7 @@ trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
Here are some examples that let you train a 7B model on one or more consumer-grade GPUs.
If you only have a single 24G GPU, you can use the following script. `batch_size` and `lora_rank` are the most important parameters to successfully train the model.
If you only have a single 24G GPU, you can use the following script. `batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model.
```
torchrun --standalone --nproc_per_node=1 train_sft.py \
--pretrain "/path/to/LLaMa-7B/" \
@ -278,6 +283,7 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \
--max_datasets_size 512 \
--max_epochs 1 \
--lora_rank 16 \
--grad_checkpoint
```
The `colossalai_gemini` strategy enables a single 24G GPU to train the whole model without LoRA if you have sufficient CPU memory. You can use the following script.
@ -294,6 +300,7 @@ torchrun --standalone --nproc_per_node=1 train_sft.py \
--lr 2e-5 \
--max_datasets_size 512 \
--max_epochs 1 \
--grad_checkpoint
```
If you have 4x32 GB GPUs, you can even train the whole 7B model using our `colossalai_zero2_cpu` strategy! The script is given as follows.
@ -310,6 +317,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--lr 2e-5 \
--max_datasets_size 512 \
--max_epochs 1 \
--grad_checkpoint
```
</details>

View File

@ -1,5 +1,24 @@
import torch.nn as nn
from .actor import Actor
from .critic import Critic
from .reward_model import RewardModel
__all__ = ['Actor', 'Critic', 'RewardModel']
def get_base_model(model: nn.Module) -> nn.Module:
"""Get the base model of our wrapper classes.
For Actor, its base model is ``actor.model`` and it is usually a ``transformers.PreTrainedModel``.
For Critic and RewardModel, the base model is the model itself.
Args:
model (nn.Module): model to get base model from
Returns:
nn.Module: the base model
"""
if isinstance(model, Actor):
return model.get_base_model()
return model
__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
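A minimal usage sketch of `get_base_model` (hedged: constructing `Actor` directly with `(model, lora_rank, lora_train_bias)` mirrors the custom-model example elsewhere in this commit, and the tiny GPT-2 config is only for illustration):

```
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from coati.models.base import Actor, get_base_model

hf_model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
actor = Actor(hf_model, lora_rank=0, lora_train_bias='none')  # hypothetical direct construction

# an Actor unwraps to the transformers model it wraps ...
assert get_base_model(actor) is hf_model
# ... while any other nn.Module is returned unchanged
plain = nn.Linear(4, 4)
assert get_base_model(plain) is plain
```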

View File

@ -199,15 +199,9 @@ class PPOTrainer(Trainer):
return {'reward': experience.reward.mean().item()}
def save_model(self,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
origin_model = strategy._unwrap_actor(actor)
origin_model = strategy.unwrap_model(actor)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, List
from typing import List, Optional
import pandas as pd
import torch
@ -9,8 +9,8 @@ from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .callbacks import Callback
from .base import Trainer
from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0
@ -41,20 +41,18 @@ class RewardModelTrainer(Trainer):
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader,
batch_size: int = 1,
max_epochs: int = 1,
callbacks: List[Callback] = [],
) -> None:
super().__init__(strategy, max_epochs, callbacks=callbacks)
train_sampler = None
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
self.eval_dataloader = eval_dataloader
self.model = strategy.setup_model(model)
self.model = model
self.loss_fn = loss_fn
self.optimizer = strategy.setup_optimizer(optim, self.model)
self.optimizer = optim
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
def eval_acc(self, dataloader):
@ -123,9 +121,3 @@ class RewardModelTrainer(Trainer):
epoch_bar.update()
step_bar.set_postfix({'dist': dist, 'acc': acc})
step_bar.close()
def save_model(self,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)

View File

@ -49,8 +49,8 @@ class SFTTrainer(Trainer):
super().__init__(strategy, max_epochs, callbacks=callbacks)
self.train_dataloader = train_dataloader
self.eval_dataloader = eval_dataloader
(self.model, self.optimizer) = strategy.prepare((model, optim))
self.model = model
self.optimizer = optim
self.accimulation_steps = accimulation_steps
num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
@ -133,9 +133,3 @@ class SFTTrainer(Trainer):
logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
# epoch_bar.update()
def save_model(self,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)

View File

@ -2,10 +2,9 @@ from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@ -72,8 +71,8 @@ class Strategy(ABC):
def prepare_model(model: nn.Module):
if isinstance(model, Actor):
return Actor(self.setup_model(self._unwrap_model(model)))
return self.setup_model(self._unwrap_model(model))
return Actor(self.setup_model(model.get_base_model()))
return self.setup_model(model)
rets = []
for arg in models_or_model_optim_pairs:
@ -81,7 +80,7 @@ class Strategy(ABC):
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
model, optimizer = arg
model = prepare_model(model)
optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
optimizer = self.setup_optimizer(optimizer, get_base_model(model))
rets.append((model, optimizer))
elif isinstance(arg, nn.Module):
rets.append(prepare_model(arg))
@ -93,31 +92,20 @@ class Strategy(ABC):
return rets
@staticmethod
def _unwrap_model(model: nn.Module) -> nn.Module:
"""Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
def unwrap_model(model: nn.Module) -> nn.Module:
"""Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
For Actor, it will unwrap `actor.model`.
Args:
model (nn.Module): an actor or a critic
"""
if isinstance(model, Actor):
return model.model
return model
model (nn.Module): the model to unwrap
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
"""Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
Args:
actor (Actor): a wrapped actor
Returns:
nn.Module: the original model (usually a huggingface model)
"""
return Strategy._unwrap_model(actor)
return get_base_model(model)
@abstractmethod
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
pass
@abstractmethod
@ -134,3 +122,11 @@ class Strategy(ABC):
def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, 1, 0)
@abstractmethod
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
pass
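Taken together, the refactored interface splits saving into two paths: `save_model` writes a plain state dict of the base model, while `save_pretrained` writes an HF-format checkpoint. A hedged, stand-alone sketch of that contract (not the library code itself), built only on `get_base_model`:

```
from typing import Optional

import torch
import torch.nn as nn
from coati.models.base import get_base_model
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def save_model_sketch(model: nn.Module, path: str) -> None:
    # pytorch format: a plain state dict of the unwrapped (base) model
    torch.save(get_base_model(model).state_dict(), path)


def save_pretrained_sketch(model: nn.Module,
                           path: str,
                           tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    # HF format: only meaningful when the base model is a transformers PreTrainedModel
    base = get_base_model(model)
    assert isinstance(base, PreTrainedModel)
    base.save_pretrained(path)
    if tokenizer is not None:
        tokenizer.save_pretrained(path)
```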

View File

@ -5,10 +5,8 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.models.base import Actor, RewardModel
from coati.models.lora import LoraLinear
from coati.models.base import get_base_model
from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import colossalai
@ -17,9 +15,7 @@ from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.utils import get_static_torch_model
from .base import Strategy
from .ddp import DDPStrategy
logger = get_dist_logger(__name__)
@ -141,7 +137,7 @@ class ColossalAIStrategy(DDPStrategy):
model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
if self.stage != 3 and self.precision == 'fp16':
model = model.half()
model = model.half().cuda()
return model
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
@ -154,47 +150,39 @@ class ColossalAIStrategy(DDPStrategy):
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
optimizer.step()
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
if isinstance(model, ZeroDDP):
return model.module
return model
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return None
unwrapped_model = self._unwrap_model(model)
# TODO : better way to get torch model from gemini model
# to get torch model from gemini model
if isinstance(unwrapped_model, RewardModel):
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
return
base_model = get_base_model(model)
if self.stage == 3:
assert isinstance(base_model, ZeroDDP)
# for stage 3, state_dict() method should be called on every rank
state_dict = base_model.state_dict(only_rank_0=only_rank0)
else:
try:
logger.info(f'Saving model to {path}', ranks=[0])
unwrapped_model.save_pretrained(path)
logger.info(f'Model saved to {path} Successfully', ranks=[0])
if tokenizer is not None:
logger.info(f'Saving tokenizer to {path}', ranks=[0])
tokenizer.save_pretrained(path)
logger.info(f'Tokenizer saved to {path} Successfully', ranks=[0])
except AttributeError:
state_dict = unwrapped_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
# only_rank0 is false or rank == 0
state_dict = base_model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0:
raise RuntimeError(
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
torch.save(optimizer.state_dict(), path)
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
if self.stage == 3:
assert isinstance(base_model, ZeroDDP)
return base_model.module
return base_model
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if self.stage == 3:
raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
super().save_pretrained(model, path, only_rank0, tokenizer)
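One practical consequence of the stage-3 branch above: `save_model()` has to be entered on every rank, since `ZeroDDP.state_dict(only_rank_0=...)` gathers the sharded parameters collectively while only rank 0 writes the file. A hedged usage sketch with placeholder names (`strategy`, `actor`, `save_path`):

```
def save_stage3_checkpoint(strategy, actor, save_path):
    # Call on ALL ranks (e.g. under torchrun): gathering the ZeRO stage-3 shards is a
    # collective operation, even though only rank 0 ends up writing the file.
    strategy.save_model(actor, save_path, only_rank0=True)
    # Note: strategy.save_pretrained(...) raises RuntimeError for stage 3, so the
    # pytorch-format state dict above is the only supported checkpoint in that case.
```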

View File

@ -6,14 +6,12 @@ import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.base import Actor, RewardModel
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Strategy
from .naive import NaiveStrategy
from .sampler import DistributedSampler
@ -68,34 +66,10 @@ class DDPStrategy(NaiveStrategy):
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
@staticmethod
def _unwrap_actor(actor: Actor) -> nn.Module:
model: DDP = Strategy._unwrap_actor(actor)
return model.module
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
if only_rank0 and dist.get_rank() != 0:
return None
if isinstance(model, RewardModel):
state_dict = model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
else:
try:
model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
except AttributeError:
state_dict = model.state_dict()
if only_rank0 and dist.get_rank() != 0:
return
torch.save(state_dict, path)
return
super().save_model(model, path, only_rank0)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
if only_rank0 and dist.get_rank() != 0:
@ -104,3 +78,16 @@ class DDPStrategy(NaiveStrategy):
def setup_sampler(self, dataset) -> DistributedSampler:
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
def unwrap_model(self, model: nn.Module) -> nn.Module:
base_model: DDP = super().unwrap_model(model)
return base_model.module
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if only_rank0 and dist.get_rank() != 0:
return
super().save_pretrained(model, path, only_rank0, tokenizer)

View File

@ -3,10 +3,11 @@ from typing import Any, Optional
import torch
import torch.nn as nn
import torch.optim as optim
from coati.models.base import RewardModel
from coati.models.base import get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from .base import Strategy
@ -40,27 +41,15 @@ class NaiveStrategy(Strategy):
pin_memory=pin_memory,
collate_fn=replay_buffer.collate_fn)
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = False,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
if isinstance(model, RewardModel):
state_dict = model.state_dict()
torch.save(state_dict, path)
else:
try:
model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
except AttributeError:
state_dict = model.state_dict()
torch.save(state_dict, path)
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
base_model = get_base_model(model)
state_dict = base_model.state_dict()
torch.save(state_dict, path)
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
unwrapped_model = self._unwrap_model(model)
base_model = get_base_model(model)
state_dict = torch.load(path, map_location=map_location)
unwrapped_model.load_state_dict(state_dict, strict=strict)
base_model.load_state_dict(state_dict, strict=strict)
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
torch.save(optimizer.state_dict(), path)
@ -68,3 +57,14 @@ class NaiveStrategy(Strategy):
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
state_dict = torch.load(path, map_location=map_location)
optimizer.load_state_dict(state_dict)
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, PreTrainedModel)
unwrapped_model.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
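A quick round-trip sketch for the `NaiveStrategy` methods above (hedged: assumes `NaiveStrategy` is importable from `coati.trainer.strategies` and can be constructed without arguments in a single-process run):

```
import torch.nn as nn
from coati.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()
model = nn.Linear(8, 8)

# save_model() writes a plain PyTorch state dict of the unwrapped model ...
strategy.save_model(model, 'model.pt', only_rank0=True)

# ... and load_model() restores it into a compatible module
clone = nn.Linear(8, 8)
strategy.load_model(clone, 'model.pt', strict=True)
```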

View File

@ -66,6 +66,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
--lr 2e-5 \
--max_datasets_size 512 \
--max_epochs 1 \
--grad_checkpoint
```
### Arg List
- --strategy: the strategy to use for training, choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive'
@ -78,6 +79,7 @@ torchrun --standalone --nproc_per_node=4 train_sft.py \
- --batch_size: batch size while training, type=int, default=4
- --lora_rank: low-rank adaptation matrices rank, type=int, default=0
- --log_interval: how many steps to log, type=int, default=100
- --grad_checkpoint: enable gradient checkpointing, type=bool, default=False
## Stage2 - Training reward model
@ -152,7 +154,7 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \
--rm_path /your/rm/model/path
```
Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild.
Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning.
### Arg List
@ -254,29 +256,6 @@ class CoatiActor(Actor):
super().__init__(model, lora_rank, lora_train_bias)
```
### LM model
```
from ..base import LM
from transformers.models.coati import CoatiModel
class GPTLM(LM):
def __init__(self,
pretrained: Optional[str] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
if pretrained is not None:
model = CoatiModel.from_pretrained(pretrained)
else:
model = build_model() # load your own model if it is not support in transformers
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
```
### Reward model
```
from ..base import RewardModel

View File

@ -194,7 +194,7 @@ def main(args):
update_timesteps=args.update_timesteps)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,

View File

@ -124,11 +124,23 @@ def train(args):
raise ValueError(f'Unsupported dataset "{args.dataset}"')
if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
train_sampler = DistributedSampler(train_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
valid_sampler = DistributedSampler(valid_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
valid_sampler = DistributedSampler(valid_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
eval_sampler = DistributedSampler(eval_dataset, shuffle=True, seed=42, drop_last=True, rank=dist.get_rank(),
eval_sampler = DistributedSampler(eval_dataset,
shuffle=True,
seed=42,
drop_last=True,
rank=dist.get_rank(),
num_replicas=dist.get_world_size())
else:
train_sampler = None
@ -141,13 +153,19 @@ def train(args):
batch_size=args.batch_size,
pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, shuffle=(valid_sampler is None),
valid_dataloader = DataLoader(valid_dataset,
shuffle=(valid_sampler is None),
sampler=valid_sampler,
batch_size=args.batch_size, pin_memory=True)
batch_size=args.batch_size,
pin_memory=True)
eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None),
sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True)
eval_dataloader = DataLoader(eval_dataset,
shuffle=(eval_sampler is None),
sampler=eval_sampler,
batch_size=args.batch_size,
pin_memory=True)
(model, optim) = strategy.prepare((model, optim))
trainer = RewardModelTrainer(model=model,
strategy=strategy,
optim=optim,
@ -155,12 +173,11 @@ def train(args):
train_dataloader=train_dataloader,
valid_dataloader=valid_dataloader,
eval_dataloader=eval_dataloader,
batch_size=args.batch_size,
max_epochs=args.max_epochs)
trainer.fit()
# save model checkpoint after fitting on only rank0
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,

View File

@ -152,6 +152,7 @@ def train(args):
else:
eval_dataloader = None
(model, optim) = strategy.prepare((model, optim))
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
@ -163,7 +164,7 @@ def train(args):
trainer.fit(logger=logger, use_wandb=args.use_wandb)
# save model checkpoint after fitting on only rank0
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(trainer.optimizer,

View File

@ -82,7 +82,6 @@ def run_dist(rank, world_size, port, strategy):
run_test_checkpoint(strategy)
@pytest.mark.skip('temporarily skip until refactor strategy unwrap')
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])