diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
index b20b02d3d..64ebf12f1 100644
--- a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
@@ -9,6 +9,10 @@ from chatgpt.models.base import Actor
 from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
 
@@ -143,7 +147,7 @@ class ColossalAIStrategy(DDPStrategy):
             return model.module
         return model
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         unwrapped_model = self._unwrap_model(model)
         # TODO : better way to get torch model from gemini model
         # to get torch model from gemini model
@@ -159,10 +163,16 @@ class ColossalAIStrategy(DDPStrategy):
                 module.merge_weights=True
                 module.eval()
         # get state_dict and save
-        state_dict = unwrapped_model.state_dict()
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        torch.save(state_dict, path)
+
+        if not isinstance(self.model, PreTrainedModel):
+            state_dict = unwrapped_model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            self.model.save_pretrained(path)
+            if tokenizer is not None:
+                tokenizer.save_pretrained(path)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0:
diff --git a/applications/ChatGPT/requirements.txt b/applications/ChatGPT/requirements.txt
index 15a960c2c..346911192 100644
--- a/applications/ChatGPT/requirements.txt
+++ b/applications/ChatGPT/requirements.txt
@@ -3,5 +3,5 @@ tqdm
 datasets
 loralib
 colossalai>=0.2.4
-torch
+torch==1.12.1
 langchain
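
Below is a minimal usage sketch of the patched save_model, added for illustration only; it is not part of the commit. The actor import path, the strategy constructor arguments, the model_init_context helper, the tokenizer checkpoint, and the output path are all assumptions, not values taken from this diff.

# Usage sketch (illustrative assumption, not part of the patch): exercising
# the new `tokenizer` argument of ColossalAIStrategy.save_model.
from transformers import AutoTokenizer

from chatgpt.models.gpt import GPTActor            # assumed actor import path
from chatgpt.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(stage=3)             # assumed constructor arguments
tokenizer = AutoTokenizer.from_pretrained('gpt2')  # assumed base checkpoint

with strategy.model_init_context():                # assumed init helper
    actor = GPTActor(pretrained='gpt2')

# ... fine-tuning under the strategy elided ...

# With this patch, save_model keeps the old torch.save path for plain
# nn.Modules, but when self.model is a transformers PreTrainedModel it
# delegates to save_pretrained(), and it also writes the tokenizer files
# alongside the weights when a tokenizer is passed.
strategy.save_model(actor, 'outputs/actor', only_rank0=True, tokenizer=tokenizer)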