fix save_model in naive and ddp strategy (#3436)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Authored by Yuanchen on 2023-04-04 15:30:01 +08:00, committed by GitHub
parent 1beb85cc25
commit 773955abfa
2 changed files with 49 additions and 12 deletions

File 1 of 2: ddp strategy (defines DDPStrategy)

@@ -1,3 +1,5 @@
+from typing import Optional
 import os
 import random
@@ -5,12 +7,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import Actor
+from coati.models.base import LM, Actor, RewardModel
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from .base import Strategy
 from .naive import NaiveStrategy
@@ -72,17 +75,32 @@ class DDPStrategy(NaiveStrategy):
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        if only_rank0 and dist.get_rank() != 0:
+            return None
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()

-        if only_rank0 and dist.get_rank() != 0:
-            return
-        model = model.model.module
-        state_dict = model.state_dict()
-        torch.save(state_dict, path)
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                if only_rank0 and dist.get_rank() != 0:
+                    return
+                torch.save(state_dict, path)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0 and dist.get_rank() != 0:
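With the new tokenizer argument, a training script can persist the tokenizer next to the model in a single call. A minimal usage sketch, assuming an already initialised DDP strategy and an actor built on a HuggingFace GPT-2 backbone; the strategy/actor objects and the checkpoint path below are illustrative, not taken from this commit:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Every rank calls save_model; with only_rank0=True the non-zero ranks return early,
# so only rank 0 writes the save_pretrained() checkpoint plus the tokenizer files.
strategy.save_model(actor, 'checkpoints/actor', only_rank0=True, tokenizer=tokenizer)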

File 2 of 2: naive strategy (defines NaiveStrategy)

@@ -1,11 +1,14 @@
-from typing import Any
+from typing import Any, Optional

 import torch
 import torch.nn as nn
 import torch.optim as optim
 from coati.replay_buffer import ReplayBuffer
+from coati.models.base import LM, RewardModel
+from coati.models.lora import LoraLinear
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 from .base import Strategy
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
-        unwrapped_model = self._unwrap_model(model)
-        torch.save(unwrapped_model.state_dict(), path)
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+        for module in model.modules():
+            if isinstance(module, LoraLinear):
+                module.merge_weights = True
+                module.eval()
+
+        if isinstance(model, RewardModel):
+            state_dict = model.state_dict()
+            torch.save(state_dict, path)
+        else:
+            try:
+                if isinstance(model, LM):
+                    model = model.model
+                model.save_pretrained(path)
+                if tokenizer is not None:
+                    tokenizer.save_pretrained(path)
+            except AttributeError:
+                state_dict = model.state_dict()
+                torch.save(state_dict, path)

     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
         unwrapped_model = self._unwrap_model(model)
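Both strategies now share the same dispatch: reward models are written as raw state dicts, while HuggingFace-backed language models go through save_pretrained together with an optional tokenizer. Stripped of the coati-specific classes, the idea looks roughly like the self-contained sketch below; the save_checkpoint helper is illustrative, not part of the commit, and it leans on the AttributeError fallback instead of the explicit isinstance(model, RewardModel) check used in the diff:

from typing import Optional

import torch
import torch.nn as nn
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def save_checkpoint(model: nn.Module, path: str,
                    tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    try:
        # Models that wrap a HuggingFace transformer expose save_pretrained(),
        # which writes a reloadable directory (config + weights).
        model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)
    except AttributeError:
        # Plain nn.Modules (e.g. a reward model head) have no save_pretrained(),
        # so fall back to saving a raw state dict.
        torch.save(model.state_dict(), path)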