Merge pull request #6096 from BurkeHulk/hotfix/lora_ckpt

[hotfix] fix lora ckpt saving format
commit dee63cc5ef by Hanks, 1 month ago (committed via GitHub)
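What changed: `save_lora_as_pretrained` in the LowLevelZero, TorchDDP, and HybridParallel checkpoint IO classes now passes an explicit `state_dict` to PEFT's `save_pretrained`, mapping every tensor to its `.data` first. This hands the serializer plain, autograd-free tensors instead of `nn.Parameter` (or otherwise wrapped) objects, which is presumably what broke safetensors serialization. A minimal sketch of the pattern, on a toy model rather than anything from this PR:

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

model = nn.Linear(4, 2)
state_dict = model.state_dict(keep_vars=True)  # values are nn.Parameter objects here

# The pattern from the PR: replace each tensor leaf with its .data
plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

for name, value in plain.items():
    # .data is a detached view: same storage, no grad_fn, plain Tensor type
    print(name, type(value).__name__, value.requires_grad)  # -> Tensor False

Note that `.data` returns a view of the same storage, so no copy is made; `.detach()` would behave equivalently here.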

@@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class LowLevelZeroPlugin(DPPluginBase):
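For context, a hedged usage sketch of how this code path is reached through the Booster API (method names such as `enable_lora` and `save_lora_as_pretrained` are assumed from recent ColossalAI versions, not taken from this diff):

import torch.nn as nn
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # run under torchrun; older releases need launch_from_torch(config={})
booster = Booster(plugin=LowLevelZeroPlugin())

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["0", "2"])  # target the two Linear submodules
model = booster.enable_lora(model, lora_config=lora_config)
model, *_ = booster.boost(model)

# Dispatches to LowLevelZeroCheckpointIO.save_lora_as_pretrained, patched above
booster.save_lora_as_pretrained(model, "lora_ckpt", use_safetensors=True)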

@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
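The newly imported `torch.utils._pytree.tree_map` recurses through dicts, lists, and tuples, applies the function to every leaf, and rebuilds the same structure; non-tensor leaves pass through untouched. A quick standalone demonstration:

import torch
from torch.utils._pytree import tree_map

nested = {"w": torch.ones(2, requires_grad=True), "meta": {"step": 3, "tags": ["lora"]}}
out = tree_map(lambda x: x.data if torch.is_tensor(x) else x, nested)
print(out["w"].requires_grad, out["meta"])  # False {'step': 3, 'tags': ['lora']}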
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class TorchDDPModel(ModelWrapper):
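Besides threading the `state_dict` through, this hunk also passes `checkpoint` positionally and returns `save_pretrained`'s result, matching the other two call sites. A round-trip check using the standard PEFT loading API (directory name illustrative):

import torch.nn as nn
from peft import PeftModel

# Fresh base model with the same architecture the adapter was trained on
base = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Reads adapter_config.json and adapter_model.safetensors from the directory
peft_model = PeftModel.from_pretrained(base, "lora_ckpt")
peft_model.eval()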

@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
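Finally, a quick sanity check on the written file (the `adapter_model.safetensors` name is PEFT's default for safetensors adapters; adjust if your PEFT version differs):

from safetensors.torch import load_file

adapter_sd = load_file("lora_ckpt/adapter_model.safetensors")
for name, tensor in adapter_sd.items():
    # Tensors deserialize as plain, autograd-free tensors
    print(name, tuple(tensor.shape), tensor.dtype)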
