mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix lora load (#6231)
* [hotfix] fix lora load * [hotfix] fix hp load * accelerate deepseek loadingpull/6236/head
parent
f32861ccc5
commit
56fe130b15
|
@ -257,7 +257,7 @@ def train(args) -> None:
|
|||
)
|
||||
|
||||
torch.set_default_dtype(torch.float)
|
||||
booster.load_model(model, args.pretrained)
|
||||
booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)
|
||||
|
||||
coordinator.print_on_master(
|
||||
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||
|
|
|
@ -85,11 +85,11 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
|
||||
for k, v in state_dict.items():
|
||||
self.pinned_state_dicts[id(model)][k].copy_(v)
|
||||
state_dict[k] = self.pinned_state_dicts[id(model)][k]
|
||||
self.pinned_state_dicts[hash(model)][k].copy_(v)
|
||||
state_dict[k] = self.pinned_state_dicts[hash(model)][k]
|
||||
writer = save(checkpoint, state_dict)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
|
@ -172,9 +172,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if use_async and self.coordinator.is_master():
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = model.state_dict_shard(
|
||||
|
|
|
@ -26,6 +26,7 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
|
|||
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface.model import PeftUnwrapMixin
|
||||
from colossalai.interface.optimizer import DistributedOptim
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
|
@ -225,7 +226,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
|
|||
if isinstance(model, DDP):
|
||||
model = model.module
|
||||
if unwrap_peft and isinstance(model, PeftModel):
|
||||
model = model.get_base_model()
|
||||
model = PeftUnwrapMixin(model)
|
||||
return model
|
||||
|
||||
def _force_wait_all_gather(self):
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch.utils.data import DataLoader
|
|||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.interface.model import PeftUnwrapMixin
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -201,7 +202,7 @@ class TorchDDPModel(ModelWrapper):
|
|||
def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
|
||||
model = self.module.module
|
||||
if unwrap_peft and isinstance(model, PeftModel):
|
||||
model = model.get_base_model()
|
||||
model = PeftUnwrapMixin(model)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -103,11 +103,11 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
|
||||
for k, v in full_model_state.items():
|
||||
self.pinned_state_dicts[id(model)][k].copy_(v)
|
||||
full_model_state[k] = self.pinned_state_dicts[id(model)][k]
|
||||
self.pinned_state_dicts[hash(model)][k].copy_(v)
|
||||
full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
|
||||
writer = save(checkpoint, full_model_state)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
|
@ -186,9 +186,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
|
|||
state_dict = model.unwrap().state_dict()
|
||||
|
||||
if use_async and self.coordinator.is_master():
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = utils.shard_model_checkpoint(
|
||||
|
|
|
@ -60,9 +60,9 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
if use_async:
|
||||
from colossalai.utils.safetensors import move_and_save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
|
||||
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
# save the checkpoint
|
||||
|
@ -234,7 +234,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
index_file = CheckpointIndexFile(checkpoint_path)
|
||||
|
||||
if use_async:
|
||||
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
|
||||
pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
|
||||
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint_path,
|
||||
|
@ -243,7 +243,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
is_master=True,
|
||||
pinned_state_dict=pinned_state_dict,
|
||||
)
|
||||
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
|
||||
self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
|
||||
self.async_writers.extend(writers)
|
||||
else:
|
||||
# Save shards of optimizer states.
|
||||
|
|
|
@ -249,9 +249,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
control_saving = self.tp_rank == 0 and self.sp_rank == 0
|
||||
if control_saving and use_async:
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[id(model)]
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = {}
|
||||
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
|
||||
else:
|
||||
pinned_state_dicts = None
|
||||
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
|
||||
|
@ -789,11 +789,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
|
||||
for name, param in state_dict.items():
|
||||
self.pinned_state_dicts[id(model)][name].copy_(param)
|
||||
state_dict[name] = self.pinned_state_dicts[id(model)][name]
|
||||
self.pinned_state_dicts[hash(model)][name].copy_(param)
|
||||
state_dict[name] = self.pinned_state_dicts[hash(model)][name]
|
||||
writer = save(path=checkpoint, state_dict=state_dict)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
|
@ -811,11 +811,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if use_async:
|
||||
from colossalai.utils.safetensors import save
|
||||
|
||||
if id(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
|
||||
if hash(model) not in self.pinned_state_dicts:
|
||||
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
|
||||
for name, param in complete_state_dict.items():
|
||||
self.pinned_state_dicts[id(model)][name].copy_(param)
|
||||
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
|
||||
self.pinned_state_dicts[hash(model)][name].copy_(param)
|
||||
complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
|
||||
writer = save(path=checkpoint, state_dict=complete_state_dict)
|
||||
self.async_writers.append(writer)
|
||||
else:
|
||||
|
|
|
@ -701,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
|||
all_param = None
|
||||
# gather param from every ep rank
|
||||
# dist.all_gather(all_param, param, group=ep_group)
|
||||
dist.gather(param, all_param, group=ep_group)
|
||||
dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
state_dict[name] = all_param.cpu()
|
||||
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.gather_object(state_dict, out, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
else:
|
||||
out = None
|
||||
dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
new_state_dict = {}
|
||||
for o in out:
|
||||
|
|
|
@ -20,6 +20,7 @@ from torch.optim import Optimizer
|
|||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface.model import PeftUnwrapMixin
|
||||
from colossalai.tensor.d_tensor import (
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
|
@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
|
|||
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
|
||||
except ImportError:
|
||||
return
|
||||
if isinstance(model, PeftUnwrapMixin):
|
||||
model = model.base_model
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
return
|
||||
|
||||
|
@ -692,6 +695,9 @@ def load_state_dict_into_model(
|
|||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
"""
|
||||
if isinstance(model, PeftUnwrapMixin):
|
||||
state_dict = model.patch_state_dict(state_dict)
|
||||
model = model.base_model
|
||||
if not isinstance(state_dict, Mapping):
|
||||
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
|
||||
|
||||
|
|
|
@ -1,5 +1,102 @@
|
|||
import re
|
||||
from typing import Dict, Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from peft import PeftModel
|
||||
from peft import PeftModel, PeftType
|
||||
|
||||
|
||||
def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
|
||||
config = model.peft_config[adapter_name]
|
||||
if config.peft_type != PeftType.LORA:
|
||||
raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
|
||||
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
|
||||
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
|
||||
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
|
||||
bias = config.bias
|
||||
if bias == "none":
|
||||
to_return = {k for k in names if "lora_" in k}
|
||||
elif bias == "all":
|
||||
to_return = {k for k in names if "lora_" in k or "bias" in k}
|
||||
elif bias == "lora_only":
|
||||
to_return = set()
|
||||
for k in names:
|
||||
if "lora_" in k:
|
||||
to_return.add(k)
|
||||
bias_name = k.split("lora_")[0] + "bias"
|
||||
if bias_name in names:
|
||||
to_return.add(bias_name)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
|
||||
if config.use_dora:
|
||||
# Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
|
||||
# ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
|
||||
# we want the state_dict format not to change, we remove the "weight" part.
|
||||
new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
|
||||
|
||||
def renamed_dora_weights(k):
|
||||
if k.endswith(new_dora_suffix):
|
||||
k = k[:-7] # remove ".weight"
|
||||
return k
|
||||
|
||||
to_return = {renamed_dora_weights(k) for k in to_return}
|
||||
|
||||
to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
|
||||
return to_return
|
||||
|
||||
|
||||
class PeftUnwrapMixin:
|
||||
def __init__(self, peft_model: PeftModel):
|
||||
self.base_model = peft_model.get_base_model()
|
||||
# peft does not affect buffers
|
||||
self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
|
||||
potential_lora_weights = set()
|
||||
for n in self.lora_layers:
|
||||
potential_lora_weights.add(f"{n}.weight")
|
||||
potential_lora_weights.add(f"{n}.bias")
|
||||
self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
|
||||
self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
|
||||
|
||||
def named_parameters(self):
|
||||
for n, p in self.base_model.named_parameters():
|
||||
if n in self.lora_param_to_origin_param:
|
||||
n = self.lora_param_to_origin_param[n]
|
||||
yield n, p
|
||||
|
||||
def named_buffers(self):
|
||||
return self.base_model.named_buffers()
|
||||
|
||||
@property
|
||||
def _modules(self):
|
||||
return self.base_model._modules
|
||||
|
||||
@property
|
||||
def _non_persistent_buffers_set(self):
|
||||
return self.base_model._non_persistent_buffers_set
|
||||
|
||||
def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if k in self.origin_param_to_lora_param:
|
||||
k = self.origin_param_to_lora_param[k]
|
||||
new_state_dict[k] = v
|
||||
return new_state_dict
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {}
|
||||
for k, v in self.base_model.state_dict().items():
|
||||
if k in self.lora_param_to_origin_param:
|
||||
k = self.lora_param_to_origin_param[k]
|
||||
state_dict[k] = v
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
|
||||
state_dict = self.patch_state_dict(state_dict)
|
||||
self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.base_model)
|
||||
|
||||
|
||||
class ModelWrapper(nn.Module):
|
||||
|
@ -23,7 +120,7 @@ class ModelWrapper(nn.Module):
|
|||
else:
|
||||
model = self.module
|
||||
if unwrap_peft and isinstance(model, PeftModel):
|
||||
model = model.get_base_model()
|
||||
model = PeftUnwrapMixin(model)
|
||||
return model
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue