mirror of https://github.com/hpcaitech/ColossalAI
fix torch version (#3225)
parent fa97a9cab4
commit bbac6760e5
@@ -9,6 +9,10 @@ from chatgpt.models.base import Actor
 from chatgpt.models.lora import LoraLinear
 from torch.optim import Optimizer
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase

 import colossalai
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
@@ -143,7 +147,7 @@ class ColossalAIStrategy(DDPStrategy):
             return model.module
         return model

-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         unwrapped_model = self._unwrap_model(model)
         # TODO : better way to get torch model from gemini model
         # to get torch model from gemini model
@@ -159,10 +163,16 @@
                 module.merge_weights=True
                 module.eval()
         # get state_dict and save
-        state_dict = unwrapped_model.state_dict()
-        if only_rank0 and dist.get_rank() != 0:
-            return
-        torch.save(state_dict, path)
+        if not isinstance(self.model, PreTrainedModel):
+            state_dict = unwrapped_model.state_dict()
+            if only_rank0 and dist.get_rank() != 0:
+                return
+            torch.save(state_dict, path)
+        else:
+            self.model.save_pretrained(path)
+            if tokenizer is not None:
+                tokenizer.save_pretrained(path)

     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         if only_rank0:
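For context, a minimal sketch of how the extended save_model signature might be called from a training script. It is not part of the commit: the setup loosely follows the chatgpt example scripts, and the checkpoint name "gpt2", the save path, and stage=2 are illustrative assumptions.

    # Sketch only; assumes the chatgpt package from this repo, a CUDA device,
    # and a distributed launcher (e.g. torchrun). Names are illustrative.
    import torch
    from chatgpt.models.gpt import GPTActor
    from chatgpt.trainer.strategies import ColossalAIStrategy
    from transformers import AutoTokenizer

    strategy = ColossalAIStrategy(stage=2)
    with strategy.model_init_context():
        actor = GPTActor(pretrained='gpt2').to(torch.cuda.current_device())

    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    # With the extended signature, a tokenizer can be passed so that it is
    # saved alongside the model whenever the Hugging Face save_pretrained
    # path is used; otherwise the state_dict is written with torch.save.
    strategy.save_model(actor, 'actor_checkpoint', only_rank0=True, tokenizer=tokenizer)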
@@ -3,5 +3,5 @@ tqdm
 datasets
 loralib
 colossalai>=0.2.4
-torch
+torch==1.12.1
 langchain
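The requirements change pins torch to exactly 1.12.1 instead of accepting any version. A purely illustrative sanity check, not part of the commit, that an environment matches the pin:

    # Illustrative check that the installed torch matches the pinned version.
    import torch

    expected = "1.12.1"
    installed = torch.__version__.split("+")[0]  # drop local build tags such as "+cu113"
    assert installed == expected, f"requirements pin torch=={expected}, found {torch.__version__}"
    print(f"torch {torch.__version__} matches the pin")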