From b10339df7cfac1eee6945df37a80fd1c38f42289 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Mon, 21 Oct 2024 13:55:43 +0800 Subject: [PATCH] fix lora ckpt save format (ColoTensor to Tensor) --- colossalai/booster/plugin/low_level_zero_plugin.py | 3 ++- colossalai/booster/plugin/torch_ddp_plugin.py | 6 +++++- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 5 ++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index b167b5c7a..97fabe63a 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -290,7 +290,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): assert isinstance( peft_model, PeftModel ), "The model doesn't have lora adapters, please enable lora before saving." - return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors, + state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())) class LowLevelZeroPlugin(DPPluginBase): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index ec7ce7f9a..aa4d35cd4 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,10 +1,12 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union +import torch import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from torch.utils._pytree import tree_map from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator @@ -134,7 +136,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): assert isinstance( peft_model, PeftModel ), "The model doesn't have lora adapters, please enable lora before saving." - peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors) + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors, + state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, + peft_model.state_dict())) class TorchDDPModel(ModelWrapper): diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 3b6917d32..4ca1353d8 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -11,6 +11,7 @@ import torch.distributed as dist import torch.nn as nn from torch.distributed import ProcessGroup from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper @@ -956,4 +957,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): assert isinstance( peft_model, PeftModel ), "The model doesn't have lora adapters, please enable lora before saving." - return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors, + state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, + peft_model.state_dict()))