ColossalAI/applications/Chat/coati/models/lora.py

150 lines
5.5 KiB
Python
Raw Normal View History

2023-03-28 12:25:36 +00:00
import math
from typing import Optional
import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
"""
def __init__(
self,
weight: nn.Parameter,
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights)
self.weight = weight
self.bias = bias
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self):
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.train(self, mode)
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
[chat] add distributed PPO trainer (#3740) * Detached ppo (#9) * run the base * working on dist ppo * sync * detached trainer * update detached trainer. no maker update function * facing init problem * 1 maker 1 trainer detached run. but no model update * facing cuda problem * fix save functions * verified maker update * nothing * add ignore * analyize loss issue * remove some debug codes * facing 2m1t stuck issue * 2m1t verified * do not use torchrun * working on 2m2t * working on 2m2t * initialize strategy in ray actor env * facing actor's init order issue * facing ddp model update issue (need unwarp ddp) * unwrap ddp actor * checking 1m2t stuck problem * nothing * set timeout for trainer choosing. It solves the stuck problem! * delete some debug output * rename to sync with upstream * rename to sync with upstream * coati rename * nothing * I am going to detach the replaybuffer from trainer and make it a Ray Actor. Two benefits: 1. support TP trainer. 2. asynchronized buffer operations * experience_maker_holder performs target-revolving _send_experience() instead of length comparison. * move code to ray subfolder * working on pipeline inference * apply comments * working on pipeline strategy. in progress. * remove pipeline code. clean this branch * update remote parameters by state_dict. no test * nothing * state_dict sharding transfer * merge debug branch * gemini _unwrap_model fix * simplify code * simplify code & fix LoRALinear AttributeError * critic unwrapped state_dict --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] add perfomance evaluator and fix bugs (#10) * [chat] add performance evaluator for ray * [chat] refactor debug arg * [chat] support hf config * [chat] fix generation * [chat] add 1mmt dummy example * [chat] fix gemini ckpt * split experience to send (#11) Co-authored-by: csric <richcsr256@gmail.com> * [chat] refactor trainer and maker (#12) * [chat] refactor experience maker holder * [chat] refactor model init * [chat] refactor trainer args * [chat] refactor model init * [chat] refactor trainer * [chat] refactor experience sending logic and training loop args (#13) * [chat] refactor experience send logic * [chat] refactor trainer * [chat] refactor trainer * [chat] refactor experience maker * [chat] refactor pbar * [chat] refactor example folder (#14) * [chat] support quant (#15) * [chat] add quant * [chat] add quant example * prompt example (#16) * prompt example * prompt load csv data * remove legacy try --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] add mmmt dummy example and refactor experience sending (#17) * [chat] add mmmt dummy example * [chat] refactor naive strategy * [chat] fix struck problem * [chat] fix naive strategy * [chat] optimize experience maker sending logic * [chat] refactor sending assignment * [chat] refactor performance evaluator (#18) * Prompt Example & requires_grad state_dict & sharding state_dict (#19) * prompt example * prompt load csv data * remove legacy try * maker models require_grad set to False * working on zero redundancy update * mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad. * remove legacy examples * remove legacy examples * remove replay buffer tp state. bad design --------- Co-authored-by: csric <richcsr256@gmail.com> * state_dict sending adapts to new unwrap function (#20) * prompt example * prompt load csv data * remove legacy try * maker models require_grad set to False * working on zero redundancy update * mmmt_prompt example; naive strategy requires_grad state_dict & sharding; maker model requires_no_grad. * remove legacy examples * remove legacy examples * remove replay buffer tp state. bad design * opt benchmark * better script * nothing * [chat] strategy refactor unwrap model * [chat] strategy refactor save model * [chat] add docstr * [chat] refactor trainer save model * [chat] fix strategy typing * [chat] refactor trainer save model * [chat] update readme * [chat] fix unit test * working on lora reconstruction * state_dict sending adapts to new unwrap function * remove comments --------- Co-authored-by: csric <richcsr256@gmail.com> Co-authored-by: ver217 <lhx0217@gmail.com> * [chat-ray] add readme (#21) * add readme * transparent graph * add note background --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] get images from url (#22) * Refactor/chat ray (#23) * [chat] lora add todo * [chat] remove unused pipeline strategy * [chat] refactor example structure * [chat] setup ci for ray * [chat-ray] Support LoRA trainer. LoRA weights reconstruction. (#24) * lora support prototype * lora support * 1mmt lora & remove useless code --------- Co-authored-by: csric <richcsr256@gmail.com> * [chat] fix test ci for ray * [chat] fix test ci requirements for ray * [chat] fix ray runtime env * [chat] fix ray runtime env * [chat] fix example ci docker args * [chat] add debug info in trainer * [chat] add nccl debug info * [chat] skip ray test * [doc] fix typo --------- Co-authored-by: csric <59389055+CsRic@users.noreply.github.com> Co-authored-by: csric <richcsr256@gmail.com>
2023-06-07 02:41:16 +00:00
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
# FIXME(csric): temporary fix
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.reset_parameters()
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
2023-03-28 12:25:36 +00:00
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
nn.Module.eval(self)
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A')
delattr(self, 'lora_B')
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, lora_linear_wrapper(child, lora_rank))
else:
convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
module (nn.Module): The module to convert.
lora_rank (int): LoRA rank.
Returns:
nn.Module: The converted module.
"""
if lora_rank <= 0:
return module
convert_to_lora_recursively(module, lora_rank)
lora.mark_only_lora_as_trainable(module, lora_train_bias)
return module
2023-03-28 12:25:36 +00:00
class LoRAModule(nn.Module):
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
2023-04-20 09:22:15 +00:00
This class will convert all torch.nn.Linear layer to LoraLinear layer.
2023-03-28 12:25:36 +00:00
Args:
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
lora_train_bias (str, optional): Whether LoRA train biases.
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
def convert_to_lora(self) -> None:
convert_to_lora_module(self, self.lora_rank, self.lora_train_bias)