mirror of https://github.com/hpcaitech/ColossalAI
fix save_model in naive and ddp strategy (#3436)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Branch: pull/3445/head
parent 1beb85cc25
commit 773955abfa
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import os
 import random
 
@@ -5,12 +7,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import LM, Actor, RewardModel
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
 from .naive import NaiveStrategy
@@ -72,16 +75,31 @@ class DDPStrategy(NaiveStrategy):
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         if only_rank0 and dist.get_rank() != 0:
             return None
 
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
 
-        model = model.model.module
-        state_dict = model.state_dict()
-        torch.save(state_dict, path)
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                if only_rank0 and dist.get_rank() != 0:
+                    return
+                torch.save(state_dict, path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
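The rewritten save_model first tries HuggingFace's save_pretrained (writing the tokenizer next to the weights when one is passed) and only falls back to a raw state dict when the model does not expose that method. A minimal, self-contained sketch of that fallback pattern; the helper name save_like_strategy and the output path are made up for illustration and are not part of this commit:

# Illustrative sketch only: `save_like_strategy` and the path are assumed names.
# It mirrors the try/except fallback used in the diff above.
from typing import Optional

import torch
import torch.nn as nn
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def save_like_strategy(model: nn.Module, path: str,
                       tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    try:
        # HuggingFace models expose save_pretrained; keep the tokenizer next to
        # the weights so the checkpoint directory can be reloaded as a whole.
        model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)
    except AttributeError:
        # Plain nn.Module without save_pretrained: store the raw state dict instead.
        torch.save(model.state_dict(), path)


save_like_strategy(nn.Linear(4, 4), "linear.pt")    # falls back to torch.save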
@@ -1,11 +1,14 @@
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import LM, RewardModel
+from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 from .base import Strategy
 
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
-        unwrapped_model = self._unwrap_model(model)
-        torch.save(unwrapped_model.state_dict(), path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                torch.save(state_dict, path)
 
     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
         unwrapped_model = self._unwrap_model(model)
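Both rewritten save_model bodies start by setting merge_weights on every LoraLinear and calling eval(); in loralib-style layers that switch is where the low-rank update is assumed to be folded into the dense weight, so the saved checkpoint already contains merged parameters. A toy illustration of that merge-on-eval pattern follows; LoraToy is a made-up stand-in, not coati's actual LoraLinear:

# Toy stand-in for the merge-on-eval behaviour assumed by the loop above.
import torch
import torch.nn as nn


class LoraToy(nn.Linear):
    def __init__(self, in_features: int, out_features: int, r: int = 4):
        super().__init__(in_features, out_features)
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.merge_weights = False
        self.merged = False

    def eval(self):
        super().eval()
        if self.merge_weights and not self.merged:
            # Fold the low-rank update into the dense weight so that a plain
            # state_dict (or save_pretrained) captures the adapted parameters.
            self.weight.data += self.lora_B @ self.lora_A
            self.merged = True
        return self


layer = LoraToy(8, 8)
layer.merge_weights = True
layer.eval()                                    # merge happens here
torch.save(layer.state_dict(), "layer_toy.pt")  # saved weights include the LoRA delta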