[hotfix] fix lora load (#6231)

* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
Hongxin Liu 2025-03-01 19:04:14 +08:00 committed by GitHub
parent f32861ccc5
commit 56fe130b15
10 changed files with 146 additions and 38 deletions


@@ -257,7 +257,7 @@ def train(args) -> None:
     )
     torch.set_default_dtype(torch.float)
-    booster.load_model(model, args.pretrained)
+    booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)
     coordinator.print_on_master(
         f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"


@@ -85,11 +85,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save
-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
                 for k, v in state_dict.items():
-                    self.pinned_state_dicts[id(model)][k].copy_(v)
-                    state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                    self.pinned_state_dicts[hash(model)][k].copy_(v)
+                    state_dict[k] = self.pinned_state_dicts[hash(model)][k]
                 writer = save(checkpoint, state_dict)
                 self.async_writers.append(writer)
             else:

@@ -172,9 +172,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = model.state_dict_shard(
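The switch from id(model) to hash(model) throughout the checkpoint IO classes works together with the PeftUnwrapMixin added later in this diff: unwrap() now builds a fresh wrapper object on every call, so an id()-keyed cache of pinned CPU buffers would never hit again, while PeftUnwrapMixin.__hash__ forwards to the underlying base model and keeps the cache key stable. For ordinary nn.Modules, hash() falls back to the default identity-based hash, so nothing changes there. A toy illustration of the caching behavior, with stand-in classes rather than the ColossalAI ones:

    import torch.nn as nn


    class FakeUnwrapped:
        """Stand-in for PeftUnwrapMixin: a new object per unwrap(), hash tied to the base model."""

        def __init__(self, base_model: nn.Module):
            self.base_model = base_model

        def __hash__(self) -> int:
            return hash(self.base_model)


    pinned_state_dicts = {}


    def get_pinned_buffers(model):
        # hash(): two wrappers around the same base model share one cache entry.
        # id(): every freshly built wrapper would allocate new pinned buffers.
        key = hash(model)
        if key not in pinned_state_dicts:
            pinned_state_dicts[key] = {}  # the real code fills this with pinned CPU tensors
        return pinned_state_dicts[key]


    base = nn.Linear(4, 4)
    assert get_pinned_buffers(FakeUnwrapped(base)) is get_pinned_buffers(FakeUnwrapped(base))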


@@ -26,6 +26,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed

@@ -225,7 +226,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if isinstance(model, DDP):
             model = model.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model

     def _force_wait_all_gather(self):


@@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device

@@ -201,7 +202,7 @@ class TorchDDPModel(ModelWrapper):
     def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
         model = self.module.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model


@@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save
-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
                 for k, v in full_model_state.items():
-                    self.pinned_state_dicts[id(model)][k].copy_(v)
-                    full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+                    self.pinned_state_dicts[hash(model)][k].copy_(v)
+                    full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
                 writer = save(checkpoint, full_model_state)
                 self.async_writers.append(writer)
             else:

@@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         state_dict = model.unwrap().state_dict()
         if use_async and self.coordinator.is_master():
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = utils.shard_model_checkpoint(


@@ -60,9 +60,9 @@ class GeneralCheckpointIO(CheckpointIO):
         if use_async:
             from colossalai.utils.safetensors import move_and_save
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
             self.async_writers.append(writer)
         else:
             # save the checkpoint

@@ -234,7 +234,7 @@ class GeneralCheckpointIO(CheckpointIO):
         index_file = CheckpointIndexFile(checkpoint_path)
         if use_async:
-            pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+            pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
             total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
                 checkpoint=checkpoint_path,

@@ -243,7 +243,7 @@ class GeneralCheckpointIO(CheckpointIO):
                 is_master=True,
                 pinned_state_dict=pinned_state_dict,
             )
-            self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+            self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
             self.async_writers.extend(writers)
         else:
             # Save shards of optimizer states.


@@ -249,9 +249,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Only devices with tp_rank == 0 are responsible for model saving.
         control_saving = self.tp_rank == 0 and self.sp_rank == 0
         if control_saving and use_async:
-            if id(model) not in self.pinned_state_dicts:
-                self.pinned_state_dicts[id(model)] = {}
-            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
         else:
             pinned_state_dicts = None
         state_dict_shard = HybridParallelCheckpointIO._model_sharder(

@@ -789,11 +789,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save
-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
                 for name, param in state_dict.items():
-                    self.pinned_state_dicts[id(model)][name].copy_(param)
-                    state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                    self.pinned_state_dicts[hash(model)][name].copy_(param)
+                    state_dict[name] = self.pinned_state_dicts[hash(model)][name]
                 writer = save(path=checkpoint, state_dict=state_dict)
                 self.async_writers.append(writer)
             else:

@@ -811,11 +811,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             if use_async:
                 from colossalai.utils.safetensors import save
-                if id(model) not in self.pinned_state_dicts:
-                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
                 for name, param in complete_state_dict.items():
-                    self.pinned_state_dicts[id(model)][name].copy_(param)
-                    complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                    self.pinned_state_dicts[hash(model)][name].copy_(param)
+                    complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
                 writer = save(path=checkpoint, state_dict=complete_state_dict)
                 self.async_writers.append(writer)
             else:


@@ -701,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
                         all_param = None
                     # gather param from every ep rank
                     # dist.all_gather(all_param, param, group=ep_group)
-                    dist.gather(param, all_param, group=ep_group)
+                    dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
                     if ep_rank == 0:
                         all_param = torch.cat(all_param, dim=0)
                         state_dict[name] = all_param.cpu()
         if self.pp_size > 1:
             if self.dp_rank == 0:
-                out = [None for _ in range(self.pp_size)]
-                dist.gather_object(state_dict, out, group=self.pp_group)
+                if self.pp_rank == 0:
+                    out = [None for _ in range(self.pp_size)]
+                else:
+                    out = None
+                dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
                 if self.pp_rank == 0:
                     new_state_dict = {}
                     for o in out:
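The MoE change is about rank addressing: dist.gather and dist.gather_object interpret dst as a global rank even when a group is passed, and only the destination rank may supply the output list (all other ranks must pass None). dist.get_global_rank(group, 0) converts "rank 0 within this subgroup" into the global rank the collective expects. A hedged sketch of the corrected pattern, assuming torch.distributed is already initialized; the two-rank subgroup is only illustrative:

    import torch
    import torch.distributed as dist

    # assumes: launched with torchrun and dist.init_process_group() already called
    ep_group = dist.new_group(ranks=[0, 1])  # illustrative expert-parallel subgroup

    if dist.get_rank() in (0, 1):
        ep_rank = dist.get_rank(ep_group)
        shard = torch.full((2,), float(dist.get_rank()))
        # Output list exists only on the destination rank; everyone else passes None.
        gather_list = [torch.empty(2) for _ in range(dist.get_world_size(ep_group))] if ep_rank == 0 else None
        # dst is the destination's *global* rank, not its rank inside ep_group.
        dist.gather(shard, gather_list, dst=dist.get_global_rank(ep_group, 0), group=ep_group)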


@@ -20,6 +20,7 @@ from torch.optim import Optimizer
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from colossalai.accelerator import get_accelerator
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,

@@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
         from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
     except ImportError:
         return
+    if isinstance(model, PeftUnwrapMixin):
+        model = model.base_model
     if not isinstance(model, PreTrainedModel):
         return

@@ -692,6 +695,9 @@ def load_state_dict_into_model(
         state_dict (dict): a dict containing parameters and
             persistent buffers.
     """
+    if isinstance(model, PeftUnwrapMixin):
+        state_dict = model.patch_state_dict(state_dict)
+        model = model.base_model
     if not isinstance(state_dict, Mapping):
         raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
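load_state_dict_into_model now rewrites checkpoint keys before touching the module: a checkpoint saved with pretrained-style names has to be mapped onto the live module tree, where every LoRA-adapted layer keeps its frozen weight under an extra base_layer submodule. patch_state_dict is just a dictionary lookup over the mapping built in PeftUnwrapMixin.__init__; the key names below are hypothetical examples of that rename:

    # hypothetical mapping for one LoRA-adapted attention projection
    origin_param_to_lora_param = {
        "model.layers.0.self_attn.q_proj.weight": "model.layers.0.self_attn.q_proj.base_layer.weight",
    }

    checkpoint = {
        "model.layers.0.self_attn.q_proj.weight": "tensor A",  # adapted layer: gets renamed
        "model.layers.0.self_attn.o_proj.weight": "tensor B",  # untouched layer: passes through
    }

    patched = {origin_param_to_lora_param.get(k, k): v for k, v in checkpoint.items()}
    assert "model.layers.0.self_attn.q_proj.base_layer.weight" in patched
    assert "model.layers.0.self_attn.o_proj.weight" in patched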


@@ -1,5 +1,102 @@
+import re
+from typing import Dict, Set
+
+import torch
 import torch.nn as nn
-from peft import PeftModel
+from peft import PeftModel, PeftType
+
+
+def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
+    config = model.peft_config[adapter_name]
+    if config.peft_type != PeftType.LORA:
+        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
+    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
+    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
+    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
+    bias = config.bias
+    if bias == "none":
+        to_return = {k for k in names if "lora_" in k}
+    elif bias == "all":
+        to_return = {k for k in names if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
+        to_return = set()
+        for k in names:
+            if "lora_" in k:
+                to_return.add(k)
+                bias_name = k.split("lora_")[0] + "bias"
+                if bias_name in names:
+                    to_return.add(bias_name)
+    else:
+        raise NotImplementedError
+    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
+    if config.use_dora:
+        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
+        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
+        # we want the state_dict format not to change, we remove the "weight" part.
+        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
+
+        def renamed_dora_weights(k):
+            if k.endswith(new_dora_suffix):
+                k = k[:-7]  # remove ".weight"
+            return k
+
+        to_return = {renamed_dora_weights(k) for k in to_return}
+
+    to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
+    return to_return
+
+
+class PeftUnwrapMixin:
+    def __init__(self, peft_model: PeftModel):
+        self.base_model = peft_model.get_base_model()
+        # peft does not affect buffers
+        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
+        potential_lora_weights = set()
+        for n in self.lora_layers:
+            potential_lora_weights.add(f"{n}.weight")
+            potential_lora_weights.add(f"{n}.bias")
+        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
+        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
+
+    def named_parameters(self):
+        for n, p in self.base_model.named_parameters():
+            if n in self.lora_param_to_origin_param:
+                n = self.lora_param_to_origin_param[n]
+            yield n, p
+
+    def named_buffers(self):
+        return self.base_model.named_buffers()
+
+    @property
+    def _modules(self):
+        return self.base_model._modules
+
+    @property
+    def _non_persistent_buffers_set(self):
+        return self.base_model._non_persistent_buffers_set
+
+    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
+        new_state_dict = {}
+        for k, v in state_dict.items():
+            if k in self.origin_param_to_lora_param:
+                k = self.origin_param_to_lora_param[k]
+            new_state_dict[k] = v
+        return new_state_dict
+
+    def state_dict(self):
+        state_dict = {}
+        for k, v in self.base_model.state_dict().items():
+            if k in self.lora_param_to_origin_param:
+                k = self.lora_param_to_origin_param[k]
+            state_dict[k] = v
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+        state_dict = self.patch_state_dict(state_dict)
+        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    def __hash__(self):
+        return hash(self.base_model)
+
+
 class ModelWrapper(nn.Module):

@@ -23,7 +120,7 @@ class ModelWrapper(nn.Module):
         else:
             model = self.module
         if unwrap_peft and isinstance(model, PeftModel):
-            model = model.get_base_model()
+            model = PeftUnwrapMixin(model)
         return model

     def forward(self, *args, **kwargs):
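End to end, PeftUnwrapMixin lets the checkpoint machinery read and write a LoRA model with the same parameter names as the plain pretrained model: state_dict() strips the base_layer. segment from frozen weights, and load_state_dict() reinserts it via patch_state_dict. A rough usage sketch, assuming a small Hugging Face causal LM and the peft LoraConfig / get_peft_model API; the model path and target_modules are placeholders that depend on the architecture:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    from colossalai.interface.model import PeftUnwrapMixin

    base = AutoModelForCausalLM.from_pretrained("path/to/small-model")  # placeholder
    peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

    unwrapped = PeftUnwrapMixin(peft_model)

    # Saved keys look like the pretrained model's keys: no ".base_layer." segment,
    # so the checkpoint stays loadable into a non-LoRA copy of the model.
    assert not any(".base_layer." in name for name in unwrapped.state_dict())

    # Loading accepts pretrained-style keys and patches them back to the LoRA layout.
    unwrapped.load_state_dict(unwrapped.state_dict(), strict=True)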