Browse Source

[chat]: add lora merge weights config (#4766)

* feat: modify lora merge weights fn

* feat: add lora merge weights config
pull/4773/head^2
Wenhao Chen 1 year ago committed by GitHub
parent
commit
901ab1eedd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 38
      applications/Chat/coati/models/lora.py
  2. 7
      applications/Chat/examples/train_prompts.py
  3. 8
      applications/Chat/examples/train_reward_model.py
  4. 7
      applications/Chat/examples/train_sft.py

38
applications/Chat/coati/models/lora.py

@ -1,4 +1,6 @@
import dataclasses
import math
import warnings
from typing import Optional
import loralib as lora
@ -7,6 +9,14 @@ import torch.nn as nn
import torch.nn.functional as F
@dataclasses.dataclass
class LoRAManager:
merge_weights: bool = False
LORA_MANAGER = LoRAManager()
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
fan_in_fan_out: bool = False,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
)
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
self.weight = weight
self.bias = bias
@ -53,8 +61,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.train(self, mode)
if self.merge_weights and self.merged:
self.training = mode
if LORA_MANAGER.merge_weights:
if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.")
# Make sure that the weights are not merged
if self.r > 0:
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
@ -65,13 +76,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.eval(self)
if self.merge_weights and not self.merged:
elif not mode and not self.merged:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
@ -79,6 +85,8 @@ class LoraLinear(lora.LoRALayer, nn.Module):
delattr(self, "lora_B")
self.merged = True
return self
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
return lora_linear

7
applications/Chat/examples/train_prompts.py

@ -192,6 +192,12 @@ def main(args):
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
actor.eval()
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
@ -227,6 +233,7 @@ if __name__ == "__main__":
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=1e-7)
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)

8
applications/Chat/examples/train_reward_model.py

@ -157,6 +157,13 @@ def train(args):
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
@ -186,6 +193,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
parser.add_argument("--log_dir", default="logs", type=str)

7
applications/Chat/examples/train_sft.py

@ -177,6 +177,12 @@ def train(args):
use_wandb=args.use_wandb,
)
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
@ -204,6 +210,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default="logs", type=str)

Loading…
Cancel
Save