mirror of https://github.com/hpcaitech/ColossalAI
pre-commit fix
parent
b10339df7c
commit
6d6cafabe2
|
@ -290,8 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
|
||||
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
|
||||
return peft_model.save_pretrained(
|
||||
checkpoint,
|
||||
safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
|
||||
)
|
||||
|
||||
|
||||
class LowLevelZeroPlugin(DPPluginBase):
|
||||
|
|
|
@ -5,8 +5,8 @@ import torch.nn as nn
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
@ -136,9 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
|
||||
peft_model.state_dict()))
|
||||
return peft_model.save_pretrained(
|
||||
checkpoint,
|
||||
safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
|
||||
)
|
||||
|
||||
|
||||
class TorchDDPModel(ModelWrapper):
|
||||
|
|
|
@ -957,6 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
|
||||
peft_model.state_dict()))
|
||||
return peft_model.save_pretrained(
|
||||
checkpoint,
|
||||
safe_serialization=use_safetensors,
|
||||
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue