[Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style
pull/5958/head
YeAnbang 2024-07-31 14:10:17 +08:00 committed by GitHub
parent 09c5f72595
commit 30f4e31a33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 552 additions and 252 deletions

View File

@ -1,7 +1,7 @@
from .base import BaseModel from .base import BaseModel
from .critic import Critic from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import convert_to_lora_module from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel from .reward_model import RewardModel
from .utils import disable_dropout from .utils import disable_dropout
@ -14,6 +14,8 @@ __all__ = [
"ValueLoss", "ValueLoss",
"LogSigLoss", "LogSigLoss",
"LogExpLoss", "LogExpLoss",
"LoraConfig",
"lora_manager",
"convert_to_lora_module", "convert_to_lora_module",
"DpoLoss", "DpoLoss",
"KTOLoss" "generate", "KTOLoss" "generate",

View File

@ -5,10 +5,11 @@ LORA utils
import dataclasses import dataclasses
import math import math
import warnings import warnings
from typing import Optional from typing import List, Optional, Union
import loralib as lora import loralib as lora
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -18,148 +19,349 @@ logger = get_dist_logger()
@dataclasses.dataclass @dataclasses.dataclass
class LoRAManager: class LoraManager:
merge_weights: bool = False able_to_merge: bool = True
LORA_MANAGER = LoRAManager() lora_manager = LoraManager()
class LoraLinear(lora.LoRALayer, nn.Module): @dataclasses.dataclass
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.""" class LoraConfig:
r: int = 0
lora_alpha: int = 32
linear_lora_dropout: float = 0.1
embedding_lora_dropout: float = 0.0
lora_train_bias: str = "none"
lora_initialization_method: str = "kaiming_uniform"
target_modules: List = None
@classmethod
def from_file(cls, config_file: str):
import json
with open(config_file, "r") as f:
config = json.load(f)
return cls(**config)
class LoraBase(lora.LoRALayer, nn.Module):
def __init__( def __init__(
self, self,
weight: nn.Parameter,
bias: Optional[nn.Parameter],
r: int = 0, r: int = 0,
lora_alpha: int = 1, lora_alpha: int = 32,
lora_dropout: float = 0.0, lora_dropout: float = 0.1,
# Set this to True if the layer to replace stores weight like (fan_in, fan_out) lora_initialization_method: str = "kaiming_uniform",
fan_in_fan_out: bool = False,
): ):
nn.Module.__init__(self) nn.Module.__init__(self)
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False) lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
self.weight = weight self.r = r
self.bias = bias self.lora_alpha = lora_alpha
self.lora_dropout = nn.Dropout(lora_dropout)
out_features, in_features = weight.shape self.merged = False
self.in_features = in_features self.lora_initialization_method = lora_initialization_method
self.out_features = out_features self.weight = None
self.bias = None
self.fan_in_fan_out = fan_in_fan_out self.lora_A = None
# Actual trainable parameters self.lora_B = None
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=False)
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.T
def reset_parameters(self): def reset_parameters(self):
if hasattr(self, "lora_A"): if hasattr(self, "lora_A"):
if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
self.out_features,
self.in_features,
):
# Initialize A with the default values for nn.Linear and set B to zero. # Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B) nn.init.zeros_(self.lora_B)
elif self.lora_initialization_method == "PiSSA":
# PiSSA method in this paper: https://arxiv.org/abs/2404.02948
# Assume the SVD of the original weights is W = USV^T
# Initialize a frozen weight to U[:,r:]S[r:,r:]V^T[:,r:] to store less significent part of W
# Only A, B are trainable, which are initialized to S[r:,:r]^0.5V^T[:,:r] and U[:,:r]S[r:,:r] respectively
# self.scaling = 1.
# SVD
U, S, Vh = torch.svd_lowrank(
self.weight.to(torch.float32).data, self.r, niter=4
) # U: [out_features, in_features], S: [in_features], V: [in_features, in_features]
# weight_backup = self.weight.clone()
# Initialize A, B
S = S / self.scaling
self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
# Initialize weight
# To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :]
self.weight.data = (
((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype)
)
self.lora_A.requires_grad = True
self.lora_B.requires_grad = True
else:
raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")
def train(self, mode: bool = True): def train(self, mode: bool = True):
""" """
This function runs when model.train() is invoked. It is used to prepare the linear layer for training This function runs when model.train() is invoked. It is used to prepare the linear layer for training
""" """
def T(w):
return w.T if self.fan_in_fan_out else w
self.training = mode self.training = mode
if LORA_MANAGER.merge_weights:
if mode and self.merged: if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.") warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.") raise NotImplementedError("LoRA unmerge is not tested.")
# Make sure that the weights are not merged elif not mode and not self.merged and lora_manager.able_to_merge:
if self.r > 0:
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
# FIXME(csric): temporary fix
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
self.reset_parameters()
else:
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
self.merged = False
elif not mode and not self.merged:
warnings.warn("Invoke module.eval() would merge LoRA weights.") warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it # Merge the weights and mark it
if self.r > 0: if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling self.weight.data += self.lora_B @ self.lora_A * self.scaling
delattr(self, "lora_A") delattr(self, "lora_A")
delattr(self, "lora_B") delattr(self, "lora_B")
self.merged = True self.merged = True
return self return self
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
class LoraLinear(LoraBase):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
weight: nn.Parameter,
bias: Union[nn.Parameter, bool],
r: int = 0,
lora_alpha: int = 32,
lora_dropout: float = 0.0,
lora_initialization_method: str = "kaiming_uniform",
):
super().__init__(
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
)
self.weight = weight
self.bias = bias
if bias is True:
self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
if bias is not None:
self.bias.requires_grad = True
out_features, in_features = weight.shape
self.in_features = in_features
self.out_features = out_features
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
self.lora_initialization_method = lora_initialization_method
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged: if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias) result = F.linear(x, self.weight, bias=self.bias)
if self.r > 0:
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
return result return result
else: else:
return F.linear(x, T(self.weight), bias=self.bias) return F.linear(x, self.weight, bias=self.bias)
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: class LoraEmbedding(LoraBase):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
weight: nn.Parameter,
r: int = 0,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
num_embeddings: int = None,
embedding_dim: int = None,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
lora_initialization_method: str = "kaiming_uniform",
):
super().__init__(
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
)
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.weight = weight
in_features, out_features = num_embeddings, embedding_dim
self.in_features = in_features
self.out_features = out_features
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
self.lora_initialization_method = lora_initialization_method
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
# reset parameters
nn.init.zeros_(self.lora_A)
nn.init.normal_(self.lora_B)
def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
return F.embedding(
x,
weight,
padding_idx=self.padding_idx,
max_norm=self.max_norm,
norm_type=self.norm_type,
scale_grad_by_freq=self.scale_grad_by_freq,
sparse=self.sparse,
)
def forward(self, x: torch.Tensor):
base_embedding = self._embed(x, self.weight)
# base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
if self.r > 0 and not self.merged:
lora_A_embedding = self._embed(x, self.lora_A.t())
embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
return embedding
else:
return base_embedding
def train(self, mode: bool = True):
"""
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
"""
self.training = mode
if mode and self.merged:
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
raise NotImplementedError("LoRA unmerge is not tested.")
elif not mode and not self.merged and lora_manager.able_to_merge:
warnings.warn("Invoke module.eval() would merge LoRA weights.")
# Merge the weights and mark it
if self.r > 0:
self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
return self
def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
""" """
Wraps a linear layer with LoRA functionality. Wraps a linear layer with LoRA functionality.
Args: Args:
linear (nn.Linear): The linear layer to be wrapped. linear (nn.Linear): The linear layer to be wrapped.
lora_rank (int): The rank of the LoRA decomposition. lora_rank (int): The rank of the LoRA decomposition.
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
Returns: Returns:
LoraLinear: The wrapped linear layer with LoRA functionality. LoraLinear: The wrapped linear layer with LoRA functionality.
""" """
assert ( assert (
lora_rank <= linear.in_features lora_config.r <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" ), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank) bias = None
if lora_config.lora_train_bias in ["all", "lora"]:
bias = linear.bias
if bias is None:
bias = True
lora_linear = LoraLinear(
linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
)
return lora_linear return lora_linear
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
""" """
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form. Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
Args: Args:
module (nn.Module): The module to convert to LoRA form. module (nn.Module): The module to convert to LoRA form.
lora_rank (int): The rank of the LoRA approximation. lora_rank (int): The rank of the LoRA approximation.
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
parent_name (str): The name of the parent module.
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
Returns: Returns:
None None
""" """
for name, child in module.named_children(): for name, child in module.named_children():
if isinstance(child, nn.Linear): if isinstance(child, nn.Linear):
setattr(module, name, _lora_linear_wrapper(child, lora_rank)) if lora_config.target_modules is None or any(
[name in target_module for target_module in lora_config.target_modules]
):
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(f"Converting {parent_name}.{name} to LoRA")
setattr(module, name, _lora_linear_wrapper(child, lora_config))
elif isinstance(child, nn.Embedding):
if lora_config.target_modules is None or any(
[name in target_module for target_module in lora_config.target_modules]
):
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(f"Converting {parent_name}.{name} to LoRA")
setattr(
module,
name,
LoraEmbedding(
child.weight,
r=lora_config.r,
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.embedding_lora_dropout,
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
lora_initialization_method=lora_config.lora_initialization_method,
),
)
else: else:
_convert_to_lora_recursively(child, lora_rank) _convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module: def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module. """Convert a torch.nn.Module to a LoRA module.
Args: Args:
module (nn.Module): The module to convert. module (nn.Module): The module to convert.
lora_rank (int): LoRA rank. lora_rank (int): LoRA rank.
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
Returns: Returns:
nn.Module: The converted module. nn.Module: The converted module.
""" """
if lora_rank <= 0: if lora_config.r <= 0:
return module return module
_convert_to_lora_recursively(module, lora_rank) # make all parameter not trainable, if lora_train_bias is "all", set bias to trainable
lora.mark_only_lora_as_trainable(module, lora_train_bias) total_parameter_size = 0
for name, p in module.named_parameters():
p.requires_grad = False
if "bias" in name and lora_config.lora_train_bias == "all":
p.requires_grad = True
total_parameter_size += p.numel()
_convert_to_lora_recursively(module, "", lora_config)
trainable_parameter_size = 0
for name, p in module.named_parameters():
if p.requires_grad == True:
trainable_parameter_size += p.numel()
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(
f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%"
)
return module return module

View File

@ -9,6 +9,7 @@
- [Install Requirements](#install-requirements) - [Install Requirements](#install-requirements)
- [Get Start with ColossalRun](#get-start-with-colossalrun) - [Get Start with ColossalRun](#get-start-with-colossalrun)
- [Training Configuration](#training-configuration) - [Training Configuration](#training-configuration)
- [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft)
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning) - [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
- [Step 1: Data Collection](#step-1-data-collection) - [Step 1: Data Collection](#step-1-data-collection)
- [Step 2: Preprocessing](#step-2-preprocessing) - [Step 2: Preprocessing](#step-2-preprocessing)
@ -377,35 +378,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
</details> </details>
<details><summary><b>Low Rank Adaption</b></summary>
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--lora_rank 32 \ # This enables LoRA
--use_wandb
```
</details>
<details><summary><b>Other Training Arguments</b></summary> <details><summary><b>Other Training Arguments</b></summary>
@ -430,6 +402,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- use_wandb: if this flag is up, you can view logs on wandb. - use_wandb: if this flag is up, you can view logs on wandb.
</details>
### Parameter Efficient Finetuning (PEFT)
Currently, we have support LoRA (low-rank adaptation) and PiSSA (principal singular values and singular vectors adaptation). Both help to reduce the running-time VRAM consumption as well as timing at the cost of overall model performance.
<details><summary><b>Low Rank Adaption and PiSSA</b></summary>
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both help to reduce the running-time VRAM consumption as well as timing at the cost of overall model performance. It is suitable for training LLM with constrained resources.
To use LoRA/PiSSA in training, please create a config file as in the following example and set the `--lora_config` to that configuration file.
```json
{
"r": 128,
"embedding_lora_dropout": 0.0,
"linear_lora_dropout": 0.1,
"lora_alpha": 32,
"lora_train_bias": "all",
"lora_initialization_method": "PiSSA",
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
}
```
#### Lora Parameters
- r: lora rank
- embedding_lora_dropout: dropout probability for embedding layer
- linear_lora_dropout: dropout probability for linear layer
- lora_alpha: lora alpha, controls how much the adaptor can deviate from the pretrained model.
- lora_train_bias: whether to add trainable bias to lora layers, choose from "all" (all layers (including but not limited to lora layers) will have trainable biases), "none" (no trainable biases), "lora" (only lora layers will have trainable biases)
- lora_initialization_method: how to initialize lora weights, choose one from ["kaiming_uniform", "PiSSA"], default to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA.
- target_modules: which module(s) should be converted to lora layers, if the module's name contain the keywords in target modules and the module is a linear or embedding layer, the module will be converted. Otherwise, the module will be frozen. Setting this field to None will automatically convert all linear and embedding layer to their LoRA counterparts. Note that this example only works for LLaMA, for other models, you need to modify it.
```
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--save_interval 5000 \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--plugin zero2_cpu \
--batch_size 4 \
--max_epochs 1 \
--accumulation_steps 4 \
--lr 2e-5 \
--max_len 2048 \
--lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA
--use_wandb
```
</details> </details>

View File

@ -0,0 +1,9 @@
{
"r": 128,
"embedding_lora_dropout": 0.0,
"linear_lora_dropout": 0.1,
"lora_alpha": 32,
"lora_train_bias": "all",
"lora_initialization_method": "PiSSA",
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
}

View File

@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import DPOTrainer from coati.trainer import DPOTrainer
from coati.utils import load_checkpoint from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args): def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1: if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -115,7 +118,7 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully") coordinator.print_on_master(msg="Flash-attention enabled successfully")
else: else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain) model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(model)
if not args.disable_reference_model: if not args.disable_reference_model:
if args.use_flash_attn: if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained( ref_model = AutoModelForCausalLM.from_pretrained(
@ -125,15 +128,19 @@ def train(args):
) )
else: else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(ref_model)
else: else:
ref_model = None ref_model = None
if args.lora_rank > 0: if args.lora_config is not None:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
disable_dropout(model)
disable_dropout(ref_model)
if args.grad_checkpoint: if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing # Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer # configure tokenizer
@ -280,11 +287,8 @@ def train(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights: if lora_config is not None and lora_config.r > 0:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights # NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval() model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
if args.save_dir is not None: if args.save_dir is not None:
@ -343,15 +347,8 @@ if __name__ == "__main__":
help="Disable the reference model (enabled by default)", help="Disable the reference model (enabled by default)",
) )
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--log_dir", default=None, type=str)

View File

@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch import torch
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import KTOTrainer from coati.trainer import KTOTrainer
from coati.utils import load_checkpoint from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args): def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1: if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -115,7 +118,7 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully") coordinator.print_on_master(msg="Flash-attention enabled successfully")
else: else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain) model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(model)
if args.use_flash_attn: if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained( ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, args.pretrain,
@ -124,13 +127,17 @@ def train(args):
) )
else: else:
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain) ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
disable_dropout(ref_model) disable_dropout(ref_model)
if args.lora_rank > 0: disable_dropout(model)
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint: if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing # Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer # configure tokenizer
@ -299,11 +306,8 @@ def train(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights: if lora_config is not None and lora_config.r > 0:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights # NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval() model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
if args.save_dir is not None: if args.save_dir is not None:
@ -355,15 +359,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--auto_weight", default=False, action="store_true") parser.add_argument("--auto_weight", default=False, action="store_true")
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)

View File

@ -6,7 +6,7 @@ from contextlib import nullcontext
import torch import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module, disable_dropout from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
from coati.trainer import ORPOTrainer from coati.trainer import ORPOTrainer
from coati.utils import load_checkpoint from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@ -23,8 +23,11 @@ logger = get_dist_logger()
def train(args): def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1: if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -114,13 +117,16 @@ def train(args):
coordinator.print_on_master(msg="Flash-attention enabled successfully") coordinator.print_on_master(msg="Flash-attention enabled successfully")
else: else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain) model = AutoModelForCausalLM.from_pretrained(args.pretrain)
if args.lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
disable_dropout(model) disable_dropout(model)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint: if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing # Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer # configure tokenizer
@ -262,11 +268,8 @@ def train(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights: if lora_config is not None and lora_config.r > 0:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights # NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval() model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
if args.save_dir is not None: if args.save_dir is not None:
@ -322,15 +325,8 @@ if __name__ == "__main__":
help="Disable the reference model (enabled by default)", help="Disable the reference model (enabled by default)",
) )
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--log_dir", default=None, type=str)

View File

@ -13,7 +13,7 @@ from coati.dataset import (
load_tokenized_dataset, load_tokenized_dataset,
setup_conversation_template, setup_conversation_template,
) )
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import PPOTrainer from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@ -31,8 +31,11 @@ logger = get_dist_logger()
def train(args): def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1: if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -81,20 +84,26 @@ def train(args):
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True) ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain) reward_model = RewardModel(args.rm_pretrain)
critic = Critic(args.rm_pretrain) critic = Critic(args.rm_pretrain)
if args.lora_config is not None:
actor = convert_to_lora_module(actor, lora_config=lora_config)
critic = convert_to_lora_module(critic, lora_config=lora_config)
for name, module in actor.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
for name, module in critic.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
lora_manager.able_to_merge = False
# Disable dropout # Disable dropout
disable_dropout(actor) disable_dropout(actor)
disable_dropout(critic) disable_dropout(critic)
if args.lora_rank > 0: if args.grad_checkpoint:
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias) actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias) critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
if args.grad_checkpoint and args.lora_rank == 0:
actor.gradient_checkpointing_enable()
critic.model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer # configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@ -421,11 +430,9 @@ def train(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights: if lora_config is not None and lora_config.r > 0:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights # NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True lora_manager.able_to_merge = True
actor.eval() actor.eval()
critic.eval() critic.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
@ -484,11 +491,9 @@ if __name__ == "__main__":
parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--experience_batch_size", type=int, default=16) parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=4) parser.add_argument("--ptx_batch_size", type=int, default=4)
parser.add_argument("--lora_train_bias", type=str, default="none") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=9e-6) parser.add_argument("--lr", type=float, default=9e-6)
parser.add_argument("--critic_lr", type=float, default=9e-6) parser.add_argument("--critic_lr", type=float, default=9e-6)
parser.add_argument("--kl_coef", type=float, default=0.1) parser.add_argument("--kl_coef", type=float, default=0.1)

View File

@ -7,7 +7,7 @@ from contextlib import nullcontext
import torch import torch
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint from coati.utils import load_checkpoint
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -25,8 +25,11 @@ logger = get_dist_logger()
def train(args): def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1: if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -58,9 +61,11 @@ def train(args):
args.pretrain, args.pretrain,
) )
if args.lora_rank > 0: if lora_config is not None:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
# ============================== # ==============================
# Initialize Booster # Initialize Booster
# ============================== # ==============================
@ -122,11 +127,9 @@ def train(args):
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
if args.grad_checkpoint and args.lora_rank == 0: if args.grad_checkpoint:
model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer # configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@ -272,16 +275,13 @@ def train(args):
trainer.fit( trainer.fit(
train_preference_dataloader=train_dataloader, train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None, eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir, log_dir=args.log_dir,
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights: if lora_config is not None and lora_config.r > 0:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights # NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval() model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
if args.save_dir is not None: if args.save_dir is not None:
@ -330,15 +330,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function") parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--log_dir", default=None, type=str) parser.add_argument("--log_dir", default=None, type=str)

View File

@ -7,7 +7,7 @@ from contextlib import nullcontext
import torch import torch
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
from coati.models import convert_to_lora_module from coati.models import LoraConfig, convert_to_lora_module
from coati.trainer import SFTTrainer from coati.trainer import SFTTrainer
from coati.utils import load_checkpoint from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
@ -24,8 +24,11 @@ logger = get_dist_logger()
def train(args): def train(args):
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility # check lora compatibility
if "gemini" in args.plugin and args.lora_rank > 0: if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin") raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1: if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin") raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@ -53,8 +56,12 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True, trust_remote_code=True,
) )
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias) if lora_config is not None:
model = convert_to_lora_module(model, lora_config=lora_config)
for name, module in model.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
if args.plugin == "ddp": if args.plugin == "ddp":
""" """
@ -114,6 +121,15 @@ def train(args):
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# ====================================================== # ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler # Initialize Model, Objective, Optimizer and LR Scheduler
# ====================================================== # ======================================================
@ -124,7 +140,7 @@ def train(args):
if args.grad_checkpoint: if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing # Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer # configure tokenizer
@ -149,15 +165,6 @@ def train(args):
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}") coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}") coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
# configure optimizer
optim = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# configure dataset # configure dataset
coordinator.print_on_master( coordinator.print_on_master(
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB" f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@ -217,6 +224,7 @@ def train(args):
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
dataloader=train_dataloader, dataloader=train_dataloader,
) )
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB") coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@ -277,11 +285,8 @@ def train(args):
use_wandb=args.use_wandb, use_wandb=args.use_wandb,
) )
if args.lora_rank > 0 and args.merge_lora_weights: if lora_config is not None and lora_config.r > 0:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights # NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval() model.eval()
# save model checkpoint after fitting on only rank0 # save model checkpoint after fitting on only rank0
if args.save_path is not None: if args.save_path is not None:
@ -328,15 +333,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512) parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision") parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument(
"--lora_train_bias",
type=str,
default="none",
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
)
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints") parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
parser.add_argument("--merge_lora_weights", type=bool, default=True)
parser.add_argument("--lr", type=float, default=5e-6) parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--config_file", type=str, default=None, help="Config file") parser.add_argument("--config_file", type=str, default=None, help="Config file")
parser.add_argument("--accumulation_steps", type=int, default=8) parser.add_argument("--accumulation_steps", type=int, default=8)

View File

@ -21,16 +21,16 @@ PARENT_LOG_DIR="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
declare -a dataset=( declare -a dataset=(
/Your/SFT/Data/arrow/part-00000 YOUR/SFT/DATA/DIR/arrow/part-00000
/Your/SFT/Data/arrow/part-00001 YOUR/SFT/DATA/DIR/arrow/part-00001
/Your/SFT/Data/arrow/part-00002 YOUR/SFT/DATA/DIR/arrow/part-00002
/Your/SFT/Data/arrow/part-00003 YOUR/SFT/DATA/DIR/arrow/part-00003
/Your/SFT/Data/arrow/part-00004 YOUR/SFT/DATA/DIR/arrow/part-00004
/Your/SFT/Data/arrow/part-00005 YOUR/SFT/DATA/DIR/arrow/part-00005
/Your/SFT/Data/arrow/part-00006 YOUR/SFT/DATA/DIR/arrow/part-00006
/Your/SFT/Data/arrow/part-00007 YOUR/SFT/DATA/DIR/arrow/part-00007
/Your/SFT/Data/arrow/part-00008 YOUR/SFT/DATA/DIR/arrow/part-00008
/Your/SFT/Data/arrow/part-00009 YOUR/SFT/DATA/DIR/arrow/part-00009
) )
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S) TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@ -47,15 +47,14 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \ --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--save_interval 2000 \ --save_interval 2000 \
--dataset ${dataset[@]} \ --dataset ${dataset[@]} \
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \
--lora_rank 0 \
--plugin zero2 \ --plugin zero2 \
--batch_size 8 \ --batch_size 8 \
--max_epochs 1 \ --max_epochs 1 \
--accumulation_steps 2 \ --accumulation_steps 1 \
--lr 5e-5 \ --lr 5e-5 \
--max_len 4096 \ --max_len 4096 \
--use_flash_attn \
--grad_checkpoint \ --grad_checkpoint \
--use_flash_attn --save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--log_dir $LOG_DIR \

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from coati.models import convert_to_lora_module from coati.models import convert_to_lora_module
from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader, TensorDataset
@ -38,7 +39,7 @@ def test_overfit():
# Build and convert model # Build and convert model
model = SimpleNN(input_size, hidden_size, num_classes) model = SimpleNN(input_size, hidden_size, num_classes)
weight_to_compare = model.fc1.weight.detach().clone() weight_to_compare = model.fc1.weight.detach().clone()
model = convert_to_lora_module(model, lora_rank=30) model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
# Loss and optimizer # Loss and optimizer
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
@ -50,7 +51,6 @@ def test_overfit():
# Forward pass # Forward pass
outputs = model(inputs) outputs = model(inputs)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
print(loss)
# Backward and optimize # Backward and optimize
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
@ -65,5 +65,50 @@ def test_overfit():
assert (weight_to_compare - model.fc1.weight).sum() < 0.01 assert (weight_to_compare - model.fc1.weight).sum() < 0.01
def test_lora_linear_accuracy():
weight = torch.randn(10, 5)
linear = nn.Linear(5, 10)
linear.weight.data = weight
x = torch.randn(10, 5)
out_linear = linear(x)
# lora linear Pissa
linear.weight.data = weight
lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
out_lora = lora_linear(x)
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
# lora linear
linear.weight.data = weight
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
out_lora = lora_linear(x)
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
def test_lora_embedding_accuracy():
weight = torch.randn(10, 5)
embedding = nn.Embedding(10, 5)
embedding.weight.data = weight
x = torch.randint(0, 10, (10,))
out_embedding = embedding(x)
# lora embedding Pissa
embedding.weight.data = weight
lora_embedding = LoraEmbedding(
embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
)
out_lora = lora_embedding(x)
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
# lora embedding
embedding.weight.data = weight
lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
out_lora = lora_embedding(x)
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
if __name__ == "__main__": if __name__ == "__main__":
test_overfit() test_overfit()
test_lora_linear_accuracy()
test_lora_embedding_accuracy()

View File

@ -30,9 +30,10 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout # Skip those tests due to CI tests timeout
MODELS=('llama') MODELS=('llama')
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
export OMP_NUM_THREADS=8 export OMP_NUM_THREADS=8
@ -112,6 +113,11 @@ for lora_rank in ${LORA_RANK[@]}; do
sp='1' sp='1'
sp_mode='split_gather' sp_mode='split_gather'
enable_sequence_parallelism='' enable_sequence_parallelism=''
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then if [[ $plugin == "3d" ]]; then
tp='4' tp='4'
bs='8' bs='8'
@ -176,7 +182,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \ --eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \ --save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \ --config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \ $lora_config \
--plugin $plugin \ --plugin $plugin \
--batch_size $bs \ --batch_size $bs \
--max_epochs 1 \ --max_epochs 1 \
@ -230,6 +236,11 @@ for lora_rank in ${LORA_RANK[@]}; do
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}") grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1' tp='1'
bs='2' bs='2'
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then if [[ $plugin == "3d" ]]; then
tp='4' tp='4'
bs='8' bs='8'
@ -252,7 +263,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \ --eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \ --save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \ --config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \ $lora_config \
--plugin $plugin \ --plugin $plugin \
--batch_size $bs \ --batch_size $bs \
--max_epochs 1 \ --max_epochs 1 \
@ -308,6 +319,11 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='4' bs='4'
ebs='8' ebs='8'
conversation_template=$(get_conversation_template_config $model) conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then if [[ $plugin == "3d" ]]; then
tp='4' tp='4'
bs='16' bs='16'
@ -344,7 +360,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--ptx_batch_size 1 \ --ptx_batch_size 1 \
--ptx_coef 0.2 \ --ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \ --save_path $MODEL_SAVE_PATH \
--lora_rank $lora_rank \ $lora_config \
--plugin $plugin \ --plugin $plugin \
--num_episodes 5 \ --num_episodes 5 \
--num_collect_steps 1 \ --num_collect_steps 1 \
@ -404,6 +420,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4' tp='4'
bs='8' bs='8'
fi fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2' grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation # gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then if [[ $plugin == "gemini_auto" ]]; then
@ -428,7 +449,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \ --eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \ --save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \ --config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \ $lora_config \
--plugin $plugin \ --plugin $plugin \
--batch_size $bs \ --batch_size $bs \
--max_epochs 1 \ --max_epochs 1 \
@ -482,6 +503,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4' tp='4'
bs='8' bs='8'
fi fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2' grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation # gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then if [[ $plugin == "gemini_auto" ]]; then
@ -506,7 +532,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \ --eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \ --save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \ --config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \ $lora_config \
--plugin $plugin \ --plugin $plugin \
--batch_size $bs \ --batch_size $bs \
--max_epochs 1 \ --max_epochs 1 \
@ -560,6 +586,11 @@ for lora_rank in ${LORA_RANK[@]}; do
tp='4' tp='4'
bs='8' bs='8'
fi fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
grad_accu='2' grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation # gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then if [[ $plugin == "gemini_auto" ]]; then
@ -584,7 +615,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--eval_dataset ${dataset[@]} \ --eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \ --save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \ --config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \ $lora_config \
--plugin $plugin \ --plugin $plugin \
--batch_size $bs \ --batch_size $bs \
--max_epochs 1 \ --max_epochs 1 \