@@ -1,4 +1,6 @@
+import dataclasses
 import math
+import warnings
 from typing import Optional
 
 import loralib as lora
@@ -7,6 +9,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
 
@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
 
@@ -53,8 +61,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         def T(w):
             return w.T if self.fan_in_fan_out else w
 
-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
                 # Make sure that the weights are not merged
                 if self.r > 0:
                     if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
@@ -65,13 +76,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
                     else:
                         self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                 self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
                 # Merge the weights and mark it
                 if self.r > 0:
                     self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
@@ -79,6 +85,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
                     delattr(self, "lora_B")
                 self.merged = True
 
+        return self
+
     def forward(self, x: torch.Tensor):
         def T(w):
             return w.T if self.fan_in_fan_out else w
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
    return lora_linear
 
 
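
Note: after this patch, weight merging is no longer a per-layer constructor argument; it is toggled globally through LORA_MANAGER.merge_weights, and LoraLinear.train(mode) handles both directions (module.eval() reaches it via nn.Module.eval(), which calls train(False)). Below is a minimal sketch of the patched behavior; the layer sizes, rank, and input tensor are made up for illustration, while LORA_MANAGER, LoraLinear, and _lora_linear_wrapper come from this file:

    import torch
    import torch.nn as nn

    # Hypothetical setup: wrap a plain linear layer with rank-8 LoRA adapters.
    linear = nn.Linear(16, 16)
    lora_linear = _lora_linear_wrapper(linear, lora_rank=8)

    x = torch.randn(2, 16)

    # Default: LORA_MANAGER.merge_weights is False, so train()/eval() leave
    # weight.data untouched and forward() applies the low-rank update
    # through lora_A/lora_B on the fly.
    lora_linear.train()
    y = lora_linear(x)

    # Opt in to merged inference: eval() now folds
    # (lora_B @ lora_A) * scaling into weight.data, deletes lora_A/lora_B,
    # and sets self.merged = True. Calling train(True) afterwards would
    # raise NotImplementedError, since unmerging is untested.
    LORA_MANAGER.merge_weights = True
    lora_linear.eval()
    y = lora_linear(x)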