"""
|
|
LORA utils
|
|
"""

import dataclasses
import math
import warnings
from typing import List, Optional, Union

import loralib as lora
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from colossalai.logging import get_dist_logger

logger = get_dist_logger()

@dataclasses.dataclass
class LoraManager:
    able_to_merge: bool = True


lora_manager = LoraManager()
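# able_to_merge is consulted by the train()/eval() hooks below: setting it to False prevents
# eval() from folding the LoRA weights into the frozen base weights.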


@dataclasses.dataclass
class LoraConfig:
    r: int = 0
    lora_alpha: int = 32
    linear_lora_dropout: float = 0.1
    embedding_lora_dropout: float = 0.0
    lora_train_bias: str = "none"
    lora_initialization_method: str = "kaiming_uniform"
    target_modules: Optional[List[str]] = None

    @classmethod
    def from_file(cls, config_file: str):
        import json

        with open(config_file, "r") as f:
            config = json.load(f)
        return cls(**config)
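
# Example (illustrative): a LoraConfig can be built directly or loaded from a JSON file whose keys
# match the dataclass fields above. The target module names below are hypothetical and depend on
# the model being adapted.
#
#   lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"])
#   # or, from a file containing e.g. {"r": 8, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"]}
#   lora_config = LoraConfig.from_file("lora_config.json")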


class LoraBase(lora.LoRALayer, nn.Module):
    def __init__(
        self,
        r: int = 0,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        lora_initialization_method: str = "kaiming_uniform",
    ):
        nn.Module.__init__(self)
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        self.merged = False
        self.lora_initialization_method = lora_initialization_method
        self.weight = None
        self.bias = None
        self.lora_A = None
        self.lora_B = None

    def reset_parameters(self):
        if hasattr(self, "lora_A"):
            if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
                self.out_features,
                self.in_features,
            ):
                # Initialize A with the default values for nn.Linear and set B to zero.
                nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
                nn.init.zeros_(self.lora_B)
            elif self.lora_initialization_method == "PiSSA":
                # PiSSA initialization, see https://arxiv.org/abs/2404.02948
                # Assume the SVD of the original weight is W = U S V^T.
                # The frozen weight is re-initialized to the residual U[:, r:] S[r:, r:] V[:, r:]^T,
                # i.e. the less significant part of W.
                # Only A and B are trainable; they are initialized to
                # sqrt(S[:r, :r]) V[:, :r]^T and U[:, :r] sqrt(S[:r, :r]) respectively.
                # self.scaling = 1.
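                # Note: right after this initialization the effective weight seen by forward() is
                # unchanged, since W_res + scaling * B @ A = (W - scaling * B @ A) + scaling * B @ A = W,
                # where scaling * B @ A approximates the rank-r truncation of the SVD of W.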
                # SVD
                U, S, Vh = torch.svd_lowrank(
                    self.weight.to(torch.float32).data, self.r, niter=4
                )  # U: [out_features, r], S: [r], Vh: [in_features, r] (svd_lowrank returns V, not V^T)
                # weight_backup = self.weight.clone()

                # Initialize A, B
                S = S / self.scaling
                self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
                self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
                # Initialize weight
                # To reduce floating point error, set the frozen weight to the residual
                # W - scaling * B @ A instead of rebuilding it from the truncated SVD factors.
                self.weight.data = (
                    (self.weight - self.scaling * self.lora_B @ self.lora_A).contiguous().to(self.weight.dtype)
                )
                self.lora_A.requires_grad = True
                self.lora_B.requires_grad = True
            else:
                raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")

    def train(self, mode: bool = True):
        """
        This function runs when model.train() is invoked. It is used to prepare the layer for training.
        """

        self.training = mode
        if mode and self.merged:
            warnings.warn("Invoking module.train() would unmerge LoRA weights.")
            raise NotImplementedError("LoRA unmerge is not tested.")
        elif not mode and not self.merged and lora_manager.able_to_merge:
            warnings.warn("Invoking module.eval() would merge LoRA weights.")
            # Merge the weights and mark it
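            # W <- W + (lora_alpha / r) * B @ A; the rank-r factors are then dropped and only
            # the merged dense weight is used in forward().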
            if self.r > 0:
                self.weight.data += self.lora_B @ self.lora_A * self.scaling
                delattr(self, "lora_A")
                delattr(self, "lora_B")
            self.merged = True

        return self


class LoraLinear(LoraBase):
    """Replaces in-place ops with out-of-place ops to fit Gemini. Converts a torch.nn.Linear into a LoraLinear."""

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Union[nn.Parameter, bool],
        r: int = 0,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        lora_initialization_method: str = "kaiming_uniform",
    ):
        super().__init__(
            r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
        )
        self.weight = weight
        self.bias = bias
        if bias is True:
            self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
        elif bias is False:
            self.bias = None
        if self.bias is not None:
            self.bias.requires_grad = True

        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features
        assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
        self.lora_initialization_method = lora_initialization_method
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn((r, in_features)))
            self.lora_B = nn.Parameter(torch.randn((out_features, r)))
            self.scaling = self.lora_alpha / self.r
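            # Standard LoRA scaling: the low-rank update B @ A is multiplied by lora_alpha / r.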
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            self.reset_parameters()

    def forward(self, x: torch.Tensor):
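        # y = x @ W^T + bias, plus the low-rank update (dropout(x) @ A^T @ B^T) * scaling while unmerged.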
        if self.r > 0 and not self.merged:
            result = F.linear(x, self.weight, bias=self.bias)
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
            return result
        else:
            return F.linear(x, self.weight, bias=self.bias)


class LoraEmbedding(LoraBase):
    """Replaces in-place ops with out-of-place ops to fit Gemini. Converts a torch.nn.Embedding into a LoraEmbedding."""

    def __init__(
        self,
        weight: nn.Parameter,
        r: int = 0,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        num_embeddings: Optional[int] = None,
        embedding_dim: Optional[int] = None,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        lora_initialization_method: str = "kaiming_uniform",
    ):
        super().__init__(
            r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
        )
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        self.weight = weight

        in_features, out_features = num_embeddings, embedding_dim
        self.in_features = in_features
        self.out_features = out_features
        assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
        self.lora_initialization_method = lora_initialization_method

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn((r, in_features)))
            self.lora_B = nn.Parameter(torch.randn((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False

            # reset parameters
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)
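            # A starts at zero, so the low-rank update embedding(x, A^T) @ B^T is initially zero,
            # while B gets a normal init (the opposite convention from LoraLinear's reset_parameters).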

    def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
        return F.embedding(
            x,
            weight,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

    def forward(self, x: torch.Tensor):
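        # Embedding lookup on the frozen weight, plus (embedding(x, A^T) @ B^T) * scaling while unmerged.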
        base_embedding = self._embed(x, self.weight)
        # base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
        if self.r > 0 and not self.merged:
            lora_A_embedding = self._embed(x, self.lora_A.t())
            embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
            return embedding
        else:
            return base_embedding

    def train(self, mode: bool = True):
        """
        This function runs when model.train() is invoked. It is used to prepare the embedding layer for training.
        """

        self.training = mode
        if mode and self.merged:
            warnings.warn("Invoking module.train() would unmerge LoRA weights.")
            raise NotImplementedError("LoRA unmerge is not tested.")
        elif not mode and not self.merged and lora_manager.able_to_merge:
            warnings.warn("Invoking module.eval() would merge LoRA weights.")
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
                delattr(self, "lora_A")
                delattr(self, "lora_B")
            self.merged = True

        return self


def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
    """
    Wraps a linear layer with LoRA functionality.

    Args:
        linear (nn.Linear): The linear layer to be wrapped.
        lora_config (LoraConfig): The LoRA configuration (rank, bias training mode, initialization method, etc.).

    Returns:
        LoraLinear: The wrapped linear layer with LoRA functionality.
    """
    assert (
        lora_config.r <= linear.in_features
    ), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
    bias = None
    if lora_config.lora_train_bias in ["all", "lora"]:
        bias = linear.bias
        if bias is None:
            bias = True
    lora_linear = LoraLinear(
        linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
    )
    return lora_linear


def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
    """
    Recursively converts the given module and its children to LoRA (Low-Rank Adaptation) form.

    Args:
        module (nn.Module): The module to convert to LoRA form.
        parent_name (str): The name of the parent module.
        lora_config (LoraConfig): The LoRA configuration (rank, target modules, initialization method, etc.).

    Returns:
        None
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            if lora_config.target_modules is None or any(
                name in target_module for target_module in lora_config.target_modules
            ):
                if dist.is_initialized() and dist.get_rank() == 0:
                    logger.info(f"Converting {parent_name}.{name} to LoRA")
                setattr(module, name, _lora_linear_wrapper(child, lora_config))
        elif isinstance(child, nn.Embedding):
            if lora_config.target_modules is None or any(
                name in target_module for target_module in lora_config.target_modules
            ):
                if dist.is_initialized() and dist.get_rank() == 0:
                    logger.info(f"Converting {parent_name}.{name} to LoRA")
                setattr(
                    module,
                    name,
                    LoraEmbedding(
                        child.weight,
                        r=lora_config.r,
                        lora_alpha=lora_config.lora_alpha,
                        lora_dropout=lora_config.embedding_lora_dropout,
                        num_embeddings=child.num_embeddings,
                        embedding_dim=child.embedding_dim,
                        padding_idx=child.padding_idx,
                        max_norm=child.max_norm,
                        norm_type=child.norm_type,
                        scale_grad_by_freq=child.scale_grad_by_freq,
                        sparse=child.sparse,
                        lora_initialization_method=lora_config.lora_initialization_method,
                    ),
                )
        else:
            _convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)


def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module.

    Args:
        module (nn.Module): The module to convert.
        lora_config (LoraConfig): The LoRA configuration (rank, bias training mode, target modules, etc.).

    Returns:
        nn.Module: The converted module.
    """
    if lora_config.r <= 0:
        return module
    # Make all parameters non-trainable; if lora_train_bias is "all", keep biases trainable.
    total_parameter_size = 0
    for name, p in module.named_parameters():
        p.requires_grad = False
        if "bias" in name and lora_config.lora_train_bias == "all":
            p.requires_grad = True
        total_parameter_size += p.numel()
    _convert_to_lora_recursively(module, "", lora_config)
    trainable_parameter_size = 0
    for name, p in module.named_parameters():
        if p.requires_grad:
            trainable_parameter_size += p.numel()
    if dist.is_initialized() and dist.get_rank() == 0:
        logger.info(
            f"Trainable parameter size: {trainable_parameter_size / 1024 / 1024:.2f}M\n"
            f"Original trainable parameter size: {total_parameter_size / 1024 / 1024:.2f}M\n"
            f"Percentage: {trainable_parameter_size / total_parameter_size * 100:.2f}%"
        )
    return module
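

# Example usage (illustrative sketch; `MyTransformer` and the target module names are hypothetical
# and depend on the model being adapted):
#
#   model = MyTransformer()
#   lora_config = LoraConfig(r=8, lora_alpha=32, lora_train_bias="none", target_modules=["q_proj", "v_proj"])
#   model = convert_to_lora_module(model, lora_config)
#   ...  # fine-tune: only the LoRA A/B factors (and optionally biases) receive gradients
#   model.eval()  # merges LoRA weights into the base weights unless lora_manager.able_to_merge is False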