diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py
index 2114913e1..e9bd7b2ed 100644
--- a/applications/Chat/coati/models/lora.py
+++ b/applications/Chat/coati/models/lora.py
@@ -1,4 +1,6 @@
+import dataclasses
 import math
+import warnings
 from typing import Optional
 
 import loralib as lora
@@ -7,6 +9,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
+@dataclasses.dataclass
+class LoRAManager:
+    merge_weights: bool = False
+
+
+LORA_MANAGER = LoRAManager()
+
+
 class LoraLinear(lora.LoRALayer, nn.Module):
     """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
 
@@ -17,13 +27,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
-        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
-        merge_weights: bool = True,
+        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        fan_in_fan_out: bool = False,
     ):
         nn.Module.__init__(self)
-        lora.LoRALayer.__init__(
-            self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
-        )
+        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
         self.weight = weight
         self.bias = bias
 
@@ -53,31 +61,31 @@ class LoraLinear(lora.LoRALayer, nn.Module):
         def T(w):
             return w.T if self.fan_in_fan_out else w
 
-        nn.Module.train(self, mode)
-        if self.merge_weights and self.merged:
-            # Make sure that the weights are not merged
-            if self.r > 0:
-                if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
-                    # FIXME(csric): temporary fix
-                    self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
-                    self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
-                    self.reset_parameters()
-                else:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-            self.merged = False
-
-    def eval(self):
-        def T(w):
-            return w.T if self.fan_in_fan_out else w
-
-        nn.Module.eval(self)
-        if self.merge_weights and not self.merged:
-            # Merge the weights and mark it
-            if self.r > 0:
-                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                delattr(self, "lora_A")
-                delattr(self, "lora_B")
-            self.merged = True
+        self.training = mode
+        if LORA_MANAGER.merge_weights:
+            if mode and self.merged:
+                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
+                raise NotImplementedError("LoRA unmerge is not tested.")
+                # Make sure that the weights are not merged
+                if self.r > 0:
+                    if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
+                        # FIXME(csric): temporary fix
+                        self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
+                        self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
+                        self.reset_parameters()
+                    else:
+                        self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+                self.merged = False
+            elif not mode and not self.merged:
+                warnings.warn("Invoke module.eval() would merge LoRA weights.")
+                # Merge the weights and mark it
+                if self.r > 0:
+                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+                    delattr(self, "lora_A")
+                    delattr(self, "lora_B")
+                self.merged = True
+
+        return self
 
     def forward(self, x: torch.Tensor):
         def T(w):
@@ -96,7 +104,7 @@ def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
     assert (
         lora_rank <= linear.in_features
     ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
-    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
     return lora_linear
 
 
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index de2a33263..a8ab15eeb 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -192,6 +192,12 @@ def main(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        actor.eval()
     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -227,6 +233,7 @@ if __name__ == "__main__":
     parser.add_argument("--ptx_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=8)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=1e-7)
     parser.add_argument("--kl_coef", type=float, default=0.1)
     parser.add_argument("--ptx_coef", type=float, default=0.9)
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index c9095b365..c1be51f2f 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -157,6 +157,13 @@ def train(args):
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
+
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks
@@ -186,6 +193,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=1)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
     parser.add_argument("--log_dir", default="logs", type=str)
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index a34661762..4f36791be 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -177,6 +177,12 @@ def train(args):
         use_wandb=args.use_wandb,
     )
 
+    if args.lora_rank > 0 and args.merge_lora_weights:
+        from coati.models.lora import LORA_MANAGER
+
+        # NOTE: set model to eval to merge LoRA weights
+        LORA_MANAGER.merge_weights = True
+        model.eval()
     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
     # save optimizer checkpoint on all ranks
@@ -204,6 +210,7 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
+    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default="logs", type=str)
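
Usage sketch (not part of the patch): the snippet below illustrates how the new global LORA_MANAGER switch is intended to be used, mirroring the save path added to the training scripts above. It assumes loralib and the patched coati.models.lora are importable, and that LoraLinear.__init__ (not shown in these hunks) creates lora_A, lora_B and scaling when r > 0, as in upstream loralib.

    import torch
    import torch.nn as nn

    from coati.models.lora import LORA_MANAGER, LoraLinear

    # Wrap an existing nn.Linear, as _lora_linear_wrapper does above.
    base = nn.Linear(16, 16)
    lora_linear = LoraLinear(base.weight, base.bias, r=4)

    x = torch.randn(2, 16)
    with torch.no_grad():
        y_unmerged = lora_linear(x)

    # With merge_weights left at False (the default), eval() keeps the
    # LoRA factors separate from the base weight.
    lora_linear.eval()
    assert not lora_linear.merged

    # Flip the global switch and re-enter eval mode: B @ A is folded into
    # .weight and lora_A / lora_B are dropped, which is what the training
    # scripts do before strategy.save_model / save_pretrained.
    LORA_MANAGER.merge_weights = True
    lora_linear.eval()
    assert lora_linear.merged

    # The merged module should produce the same outputs as before merging.
    with torch.no_grad():
        y_merged = lora_linear(x)
    assert torch.allclose(y_unmerged, y_merged, atol=1e-5)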