|
|
@ -1,4 +1,6 @@
|
|
|
|
|
|
|
|
import dataclasses
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
|
|
|
|
import warnings
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
import loralib as lora
|
|
|
|
import loralib as lora
|
|
|
@ -7,6 +9,14 @@ import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
|
|
|
|
|
|
|
class LoRAManager:
|
|
|
|
|
|
|
|
merge_weights: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LORA_MANAGER = LoRAManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoraLinear(lora.LoRALayer, nn.Module):
|
|
|
|
class LoraLinear(lora.LoRALayer, nn.Module):
|
|
|
|
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
|
|
|
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
|
|
|
|
|
|
|
|
|
|
@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
|
|
|
r: int = 0,
|
|
|
|
r: int = 0,
|
|
|
|
lora_alpha: int = 1,
|
|
|
|
lora_alpha: int = 1,
|
|
|
|
lora_dropout: float = 0.0,
|
|
|
|
lora_dropout: float = 0.0,
|
|
|
|
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
|
|
|
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
|
|
|
merge_weights: bool = True,
|
|
|
|
fan_in_fan_out: bool = False,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
nn.Module.__init__(self)
|
|
|
|
nn.Module.__init__(self)
|
|
|
|
lora.LoRALayer.__init__(
|
|
|
|
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
|
|
|
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self.weight = weight
|
|
|
|
self.weight = weight
|
|
|
|
self.bias = bias
|
|
|
|
self.bias = bias
|
|
|
|
|
|
|
|
|
|
|
@ -53,8 +61,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
|
|
|
def T(w):
|
|
|
|
def T(w):
|
|
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
|
|
|
|
|
|
|
|
nn.Module.train(self, mode)
|
|
|
|
self.training = mode
|
|
|
|
if self.merge_weights and self.merged:
|
|
|
|
if LORA_MANAGER.merge_weights:
|
|
|
|
|
|
|
|
if mode and self.merged:
|
|
|
|
|
|
|
|
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
|
|
|
|
|
|
|
raise NotImplementedError("LoRA unmerge is not tested.")
|
|
|
|
# Make sure that the weights are not merged
|
|
|
|
# Make sure that the weights are not merged
|
|
|
|
if self.r > 0:
|
|
|
|
if self.r > 0:
|
|
|
|
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
|
|
|
|
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
|
|
|
@ -65,13 +76,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
|
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
|
self.merged = False
|
|
|
|
self.merged = False
|
|
|
|
|
|
|
|
elif not mode and not self.merged:
|
|
|
|
def eval(self):
|
|
|
|
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
|
|
|
def T(w):
|
|
|
|
|
|
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn.Module.eval(self)
|
|
|
|
|
|
|
|
if self.merge_weights and not self.merged:
|
|
|
|
|
|
|
|
# Merge the weights and mark it
|
|
|
|
# Merge the weights and mark it
|
|
|
|
if self.r > 0:
|
|
|
|
if self.r > 0:
|
|
|
|
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
|
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
|
@ -79,6 +85,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
|
|
|
delattr(self, "lora_B")
|
|
|
|
delattr(self, "lora_B")
|
|
|
|
self.merged = True
|
|
|
|
self.merged = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
|
def T(w):
|
|
|
|
def T(w):
|
|
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
|
|
|
assert (
|
|
|
|
assert (
|
|
|
|
lora_rank <= linear.in_features
|
|
|
|
lora_rank <= linear.in_features
|
|
|
|
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
|
|
|
|
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
|
|
|
|
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
|
|
|
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
|
|
|
|
return lora_linear
|
|
|
|
return lora_linear
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|