Merge pull request #6096 from BurkeHulk/hotfix/lora_ckpt

[hotfix] fix lora ckpt saving format
Hanks authored 1 month ago · committed by GitHub
commit dee63cc5ef

@@ -290,7 +290,11 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class LowLevelZeroPlugin(DPPluginBase):

@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class TorchDDPModel(ModelWrapper):

@@ -11,6 +11,7 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
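
For reference, a minimal sketch of the tree_map pattern this fix applies in all three checkpoint IO classes: every tensor in the state dict (including nn.Parameter wrappers) is replaced by its plain .data, non-tensor entries pass through unchanged, and the resulting dict is handed to peft_model.save_pretrained(..., state_dict=...) as shown in the hunks above. The toy dict and key names below are illustrative assumptions, not ColossalAI code.

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map

# Toy "state dict" mixing a Parameter, a plain tensor, and a non-tensor entry;
# the keys are made up for illustration.
state = {
    "base_model.lora_A.weight": nn.Parameter(torch.randn(4, 4)),
    "scaling": torch.tensor(2.0),
    "adapter_name": "default",
}

# Same lambda as in the fix: tensors are mapped to their underlying .data,
# everything else is returned unchanged.
plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state)

print({k: type(v).__name__ for k, v in plain.items()})
# {'base_model.lora_A.weight': 'Tensor', 'scaling': 'Tensor', 'adapter_name': 'str'}

Passing the pre-materialized dict through the state_dict= keyword means PEFT serializes these plain tensors directly instead of re-collecting possibly wrapped parameters from the model, which is presumably what fixes the saved LoRA checkpoint format.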
