fix lora ckpt save format (ColoTensor to Tensor)

pull/6096/head
BurkeHulk 2024-10-21 13:55:43 +08:00
parent 5ddad486ca
commit b10339df7c
3 changed files with 11 additions and 3 deletions
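
The change is the same in all three files: before handing the LoRA state dict to peft's save_pretrained, every ColoTensor parameter is replaced with its plain torch.Tensor payload via tree_map, since the serializer (safetensors in particular) only accepts plain tensors. A minimal, self-contained sketch of that conversion follows; the example key names and shapes are made up for illustration and are not part of the commit.

import torch
from torch.utils._pytree import tree_map

# Stand-in for peft_model.state_dict(); in the real code the values may be
# ColoTensor instances, whose .data attribute is the underlying torch.Tensor.
state_dict = {"lora_A.weight": torch.randn(4, 8), "lora_B.weight": torch.randn(8, 4)}

# The expression used in the diff: map every tensor leaf to its .data,
# leave non-tensor leaves untouched.
plain_state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

assert all(type(v) is torch.Tensor for v in plain_state_dict.values())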


@@ -290,7 +290,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
 
 
 class LowLevelZeroPlugin(DPPluginBase):


@@ -1,10 +1,12 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.utils._pytree import tree_map
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
@@ -134,7 +136,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))
 
 
 class TorchDDPModel(ModelWrapper):


@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))
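
All three checkpoint IO classes (LowLevelZeroCheckpointIO, TorchDDPCheckpointIO, HybridParallelCheckpointIO) now pass the converted state dict, so whichever plugin is in use, the adapter directory written by peft contains only plain tensors. A hedged way to sanity-check a saved adapter, assuming peft's default safetensors layout and an adapter saved to a hypothetical ./lora_ckpt directory (the file name adapter_model.safetensors is the peft convention, not something introduced by this commit):

from safetensors.torch import load_file

# safetensors can only store plain torch.Tensor data, so this load succeeds
# only if the ColoTensor wrappers were stripped before saving.
weights = load_file("lora_ckpt/adapter_model.safetensors")
for name, tensor in weights.items():
    print(name, tuple(tensor.shape), tensor.dtype)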