|
|
|
@ -9,6 +9,10 @@ from chatgpt.models.base import Actor
|
|
|
|
|
from chatgpt.models.lora import LoraLinear |
|
|
|
|
from torch.optim import Optimizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
from transformers.tokenization_utils_base import PreTrainedTokenizerBase |
|
|
|
|
|
|
|
|
|
import colossalai |
|
|
|
|
from colossalai.nn.optimizer import CPUAdam, HybridAdam |
|
|
|
|
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper |
|
|
|
@ -143,7 +147,7 @@ class ColossalAIStrategy(DDPStrategy):
|
|
|
|
|
return model.module |
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None: |
|
|
|
|
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: |
|
|
|
|
unwrapped_model = self._unwrap_model(model) |
|
|
|
|
# TODO : better way to get torch model from gemini model |
|
|
|
|
# to get torch model from gemini model |
|
|
|
@ -159,10 +163,16 @@ class ColossalAIStrategy(DDPStrategy):
|
|
|
|
|
module.merge_weights=True |
|
|
|
|
module.eval() |
|
|
|
|
# get state_dict and save |
|
|
|
|
state_dict = unwrapped_model.state_dict() |
|
|
|
|
if only_rank0 and dist.get_rank() != 0: |
|
|
|
|
return |
|
|
|
|
torch.save(state_dict, path) |
|
|
|
|
|
|
|
|
|
if not isinstance(self.model, PreTrainedModel): |
|
|
|
|
state_dict = unwrapped_model.state_dict() |
|
|
|
|
if only_rank0 and dist.get_rank() != 0: |
|
|
|
|
return |
|
|
|
|
torch.save(state_dict, path) |
|
|
|
|
else: |
|
|
|
|
self.model.save_pretrained(path) |
|
|
|
|
if tokenizer is not None: |
|
|
|
|
tokenizer.save_pretrained(path) |
|
|
|
|
|
|
|
|
|
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: |
|
|
|
|
if only_rank0: |
|
|
|
|