diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index b167b5c7a..f3a6901ad 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class LowLevelZeroPlugin(DPPluginBase):
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index ec7ce7f9a..156a4acf9 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
 
 
 class TorchDDPModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 3b6917d32..e6abf59e3 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
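
For context, all three hunks apply the same pattern: before `save_pretrained` serializes the LoRA weights, `tree_map` walks the PEFT state dict and replaces every tensor leaf with its plain `.data` view, so wrapped parameter/tensor subclasses are saved as ordinary `torch.Tensor` objects. A minimal standalone sketch of that transformation (the keys, shapes, and non-tensor entry below are made up for illustration and are not taken from the ColossalAI code):

```python
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

# Illustrative state dict; in the PR it comes from peft_model.state_dict()
# and may contain tensor subclasses wrapped by ColossalAI.
state_dict = {
    "lora_A.weight": nn.Parameter(torch.randn(4, 2)),
    "lora_B.weight": nn.Parameter(torch.randn(2, 4)),
    "rank": 2,  # non-tensor leaf, passed through unchanged
}

# tree_map visits every leaf of the (possibly nested) container: tensor leaves
# are replaced by their plain .data view, everything else is returned as-is.
plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

# Tensor leaves are now plain torch.Tensor objects, which safetensors can serialize.
assert all(type(v) is torch.Tensor for k, v in plain.items() if k != "rank")
print({k: type(v).__name__ for k, v in plain.items()})
```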