mirror of https://github.com/hpcaitech/ColossalAI
[chat]: add lora merge weights config (#4766)
* feat: modify lora merge weights fn
* feat: add lora merge weights config
parent 493a5efeab
commit 901ab1eedd
@@ -1,4 +1,6 @@
+import dataclasses
 import math
+import warnings
 from typing import Optional
 
 import loralib as lora
@@ -7,6 +9,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
 
@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
 
@@ -53,31 +61,31 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         def T(w):
             return w.T if self.fan_in_fan_out else w
 
-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
-                    # FIXME(csric): temporary fix
-                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
-                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
-                    self.reset_parameters()
-                else:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                delattr(self, "lora_A")
-                delattr(self, "lora_B")
-            self.merged = True
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                        # FIXME(csric): temporary fix
+                        self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                        self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                        self.reset_parameters()
+                    else:
+                        self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                    delattr(self, "lora_A")
+                    delattr(self, "lora_B")
+                self.merged = True
 
         return self
 
     def forward(self, x: torch.Tensor):
         def T(w):
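The elif branch above folds the low-rank update into the dense weight, W' = W + scaling * (B @ A), which is why lora_A and lora_B can be deleted afterwards without changing the layer's outputs. A standalone sanity check of that identity in plain torch (hypothetical sizes, no coati imports; shapes follow loralib's convention of lora_A: (r, in_features), lora_B: (out_features, r)):

import torch

in_features, out_features, r, lora_alpha = 8, 4, 2, 1
scaling = lora_alpha / r

weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)
x = torch.randn(3, in_features)

# Unmerged forward: dense path plus scaled low-rank path.
y_unmerged = x @ weight.T + (x @ lora_A.T @ lora_B.T) * scaling
# Merged forward: the update is folded into the dense weight, as in train() above.
y_merged = x @ (weight + (lora_B @ lora_A) * scaling).T
assert torch.allclose(y_unmerged, y_merged, atol=1e-5)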
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
     return lora_linear
 
 
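Net effect of the changes in this file: the merge decision moves from a per-layer constructor argument (merge_weights, now removed) to the process-global LORA_MANAGER that train() consults. A minimal self-contained sketch of the resulting flow; TinyLoraLinear is a hypothetical stand-in for illustration, not the coati class:

import dataclasses

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclasses.dataclass
class LoRAManager:
    merge_weights: bool = False


LORA_MANAGER = LoRAManager()


class TinyLoraLinear(nn.Module):
    def __init__(self, in_features, out_features, r=2, lora_alpha=1):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.lora_A = nn.Parameter(torch.randn(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = lora_alpha / r
        self.merged = False

    def train(self, mode: bool = True):
        self.training = mode
        # Merge only when entering eval with the global flag set
        # (the unmerge path on re-entering train is omitted here).
        if LORA_MANAGER.merge_weights and not mode and not self.merged:
            self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            del self.lora_A, self.lora_B
            self.merged = True
        return self

    def forward(self, x):
        out = F.linear(x, self.weight)
        if not self.merged:
            out = out + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return out


layer = TinyLoraLinear(8, 4)
LORA_MANAGER.merge_weights = True
layer.eval()  # nn.Module.eval() calls train(False), which performs the merge
print(list(layer.state_dict()))  # ['weight'] -- the LoRA factors are gone

The training scripts below rely on exactly this: they flip the global flag and then call .eval() right before saving.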
@@ -192,6 +192,12 @@ def main(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        actor.eval()
     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -227,6 +233,7 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_batch_size", type=int, default=1)
     parser.add_argument("--experience_batch_size", type=int, default=8)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=1e-7)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.9)
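One caveat with the new --merge_lora_weights flag (repeated in the other two scripts below): argparse's type=bool does not parse booleans, it just calls bool() on the raw string, so any non-empty value, including the string "False", yields True. A quick reproduction:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--merge_lora_weights", type=bool, default=True)

print(parser.parse_args([]).merge_lora_weights)                                 # True
print(parser.parse_args(["--merge_lora_weights", "False"]).merge_lora_weights)  # True (!)
print(parser.parse_args(["--merge_lora_weights", ""]).merge_lora_weights)       # False

As written, the flag effectively reads as always-on unless an empty string is passed; a possible fix is sketched after the last hunk.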
@@ -157,6 +157,13 @@ def train(args):
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
+
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -186,6 +193,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
     parser.add_argument("--log_dir", default="logs", type=str)
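The point of merging before strategy.save_model: once the LoRA factors are folded in and deleted, the state_dict carries exactly the base model's keys, so the checkpoint loads into a model built without any LoRA wrapping. A toy illustration (plain torch, hypothetical weights):

import torch
import torch.nn as nn

base = nn.Linear(8, 4, bias=False)
merged_state = {"weight": torch.randn(4, 8)}  # stands in for W + scaling * (B @ A)
base.load_state_dict(merged_state)  # succeeds: no stray lora_A/lora_B keys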
@@ -177,6 +177,12 @@ def train(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
@@ -204,6 +210,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default="logs", type=str)
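If the merge should ever be disabled from the command line, a more robust pattern than type=bool is an explicit opt-out switch; a sketch of one option, not part of this commit:

import argparse

parser = argparse.ArgumentParser()
# Default on; pass --no_merge_lora_weights to skip merging at save time.
parser.add_argument("--no_merge_lora_weights", dest="merge_lora_weights",
                    action="store_false", default=True)

assert parser.parse_args([]).merge_lora_weights is True
assert parser.parse_args(["--no_merge_lora_weights"]).merge_lora_weights is False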