From 14b0d4c7e5340b475d75319a43bbdb77b7fcc7a5 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Tue, 31 Oct 2023 15:19:37 +0800
Subject: [PATCH] [lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add license and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
---
 colossalai/booster/booster.py                 |  57 +++++++++
 colossalai/booster/plugin/gemini_plugin.py    |  10 +-
 .../booster/plugin/hybrid_parallel_plugin.py  |  10 +-
 .../booster/plugin/low_level_zero_plugin.py   |  10 +-
 colossalai/booster/plugin/plugin_base.py      |  12 +-
 colossalai/booster/plugin/torch_ddp_plugin.py |  32 +++++-
 .../booster/plugin/torch_fsdp_plugin.py       |  10 +-
 .../checkpoint_io/checkpoint_io_base.py       |  17 +++
 .../checkpoint_io/general_checkpoint_io.py    |   3 +
 requirements/requirements-test.txt            |   3 +-
 tests/test_lora/test_torch_ddp_lora.py        | 108 ++++++++++++++++++
 11 files changed, 265 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_lora/test_torch_ddp_lora.py

diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index d73bc5bab..c2a724084 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -8,6 +8,14 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ class Booster:
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)

+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+    ) -> nn.Module:
+        """
+        Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, the LoRA configs and weights are loaded from that directory.
+        LoRA in ColossalAI is implemented with the Hugging Face PEFT library, so the arguments for the LoRA configuration are the same as those of PEFT.
+
+        Args:
+            model (nn.Module): The model to be appended with LoRA modules.
+            pretrained_dir (str, optional): The path to the pretrained directory; can be a local directory
+                or the model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
+                When set to None, new LoRA configs and weights are created for the model from the passed-in lora_config. Defaults to None.
+            lora_config (peft.LoraConfig, optional): The LoraConfig passed to PEFT. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install the Hugging Face PEFT library to enable LoRA features in ColossalAI!")
+
+        assert self.plugin is not None, "LoRA can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support LoRA."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide a configuration for LoRA when a pretrained directory path isn't passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed-in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide a pretrained directory path if not passing in a LoRA configuration."
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.

@@ -323,3 +363,20 @@ class Booster:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install the Hugging Face PEFT library to enable LoRA features in ColossalAI!")
+        assert self.plugin is not None, "LoRA can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support LoRA."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index a67ca18a3..964cd302a 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -3,7 +3,7 @@ import logging
 import os
 import random
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import numpy as np
 import torch
@@ -444,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return False

+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True

@@ -573,3 +576,8 @@ class GeminiPlugin(DPPluginBase):

     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 5237734f0..97057481e 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -4,7 +4,7 @@ import warnings
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union

 import numpy as np
 import torch
@@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
     def support_no_sync(self) -> bool:
         return True

+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True

@@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.zero_stage != 2
         ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index d21496f0b..243051895 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -3,7 +3,7 @@ import os
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -296,6 +296,9 @@ class LowLevelZeroPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return self.stage == 1

+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True

@@ -337,3 +340,8 @@ class LowLevelZeroPlugin(DPPluginBase):
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 4e570cbe8..6dc0c560d 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -33,6 +33,10 @@ class Plugin(ABC):
     def support_no_sync(self) -> bool:
         pass

+    @abstractmethod
+    def support_lora(self) -> bool:
+        pass
+
     @abstractmethod
     def configure(
         self,
@@ -63,6 +67,12 @@ class Plugin(ABC):
         Context manager to disable gradient synchronization.
         """

+    @abstractmethod
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        """
+        Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
+        """
+
     @abstractmethod
     def prepare_dataloader(
         self,
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index 738634473..9ba520de2 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -116,6 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)

+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and adapter configuration file to a checkpoint directory.
+        """
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        if self.coordinator.is_master():
+            peft_model = model.unwrap()
+            assert isinstance(
+                peft_model, PeftModel
+            ), "The model doesn't have LoRA adapters; please enable LoRA before saving."
+            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+

 class TorchDDPModel(ModelWrapper):
     def __init__(self, module: nn.Module, *args, **kwargs) -> None:
@@ -173,6 +189,9 @@ class TorchDDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return True

+    def support_lora(self) -> bool:
+        return True
+
     def control_precision(self) -> bool:
         return False

@@ -216,3 +235,14 @@ class TorchDDPPlugin(DPPluginBase):
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
         return model.module.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, TorchDDPModel), "LoRA should be enabled before boosting the model."
+        if pretrained_dir is None:
+            return get_peft_model(model, lora_config)
+        else:
+            return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py
index 0aa0caa9a..cd2f9e840 100644
--- a/colossalai/booster/plugin/torch_fsdp_plugin.py
+++ b/colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return False

+    def support_lora(self) -> bool:
+        return False
+
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

@@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):

     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchFSDPCheckpointIO()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 712324215..949ba4d44 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -335,3 +335,20 @@ class CheckpointIO(ABC):
         """
         state_dict = torch.load(checkpoint)
         lr_scheduler.load_state_dict(state_dict)
+
+    # ================================================================================
+    # Abstract method for LoRA saving implementation.
+    # ================================================================================
+
+    @abstractmethod
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index a652d9b45..b9253a56d 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
                     self.__class__.__name__, "\n\t".join(error_msgs)
                 )
             )
+
+    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+        raise NotImplementedError
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 0b15b9311..de7fe8a21 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -5,7 +5,7 @@ git+https://github.com/hpcaitech/pytest-testmon
 torchvision
 timm
 titans
-torchaudio
+torchaudio>=0.13.1
 torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
 torchrec==0.2.0
 contexttimer
@@ -18,4 +18,5 @@ flash_attn
 datasets
 pydantic
 ray
+peft
 #auto-gptq now not support torch1.12
diff --git a/tests/test_lora/test_torch_ddp_lora.py b/tests/test_lora/test_torch_ddp_lora.py
new file mode 100644
index 000000000..b3169bf86
--- /dev/null
+++ b/tests/test_lora/test_torch_ddp_lora.py
@@ -0,0 +1,108 @@
+import copy
+import os
+
+import torch
+from peft import LoraConfig
+from torch import distributed as dist
+from torch.optim import AdamW
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.testing import (
+    assert_equal,
+    assert_not_equal,
+    check_state_dict_equal,
+    clear_cache_before_run,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+from tests.test_checkpoint_io.utils import shared_tempdir
+
+
+@clear_cache_before_run()
+def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
+    model = model_fn()
+    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = booster.enable_lora(model, lora_config=lora_config)
+    model_copy = copy.deepcopy(model)
+
+    optimizer = AdamW(model.parameters(), lr=0.001)
+    criterion = loss_fn
+
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
+
+    output = model(**data)
+    output = output_transform_fn(output)
+    loss = criterion(output)
+
+    booster.backward(loss, optimizer)
+    optimizer.clip_grad_by_norm(1.0)
+    optimizer.step()
+
+    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
+        if "lora_" in n1:
+            # lora modules require gradients, thus updated
+            assert p1.requires_grad
+            assert_not_equal(p1.to(p2.device), p2)
+        else:
+            if not p1.requires_grad:
+                assert_equal(p1.to(p2.device), p2)
+
+
+@clear_cache_before_run()
+def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
+    plugin = TorchDDPPlugin()
+    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+
+    model_save = model_fn()
+    model_load = copy.deepcopy(model_save)
+
+    booster = Booster(plugin=plugin)
+    model_save = booster.enable_lora(model_save, lora_config=lora_config)
+    model_save, _, _, _, _ = booster.boost(model_save)
+
+    with shared_tempdir() as tempdir:
+        lora_ckpt_path = os.path.join(tempdir, "ckpt")
+        booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
+        dist.barrier()
+
+        # The Lora checkpoint should be small in size
+        checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
+        assert checkpoint_size_mb < 1
+
+        model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
+        model_load, _, _, _, _ = booster.boost(model_load)
+
+        check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
+
+
+def run_lora_test():
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
+        check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_lora_test()
+
+
+@rerun_if_address_is_in_use()
+def test_torch_ddp_lora():
+    spawn(run_dist, 2)
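
Usage sketch: a minimal end-to-end example of the booster LoRA APIs introduced by this patch, assuming the TorchDDPPlugin and a script launched with torchrun; the Hugging Face model name, hyperparameters, and checkpoint path below are illustrative, not prescribed by the patch.

import torch
from peft import LoraConfig
from transformers import LlamaForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes the process group is set up by torchrun.
colossalai.launch_from_torch(config={})

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

# 1. Wrap the model with LoRA modules *before* boosting (TorchDDPPlugin.enable_lora asserts this order).
model = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# 2. Train as usual: forward pass, then booster.backward(loss, optimizer), then optimizer.step().

# 3. Save only the LoRA adapters and the adapter config to a local directory.
booster.save_lora_as_pretrained(model, "./lora_ckpt")

# 4. Reload the adapters into a fresh copy of the base model.
new_model = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
new_model = booster.enable_lora(new_model, pretrained_dir="./lora_ckpt")
new_model, *_ = booster.boost(new_model)

Because save_lora_as_pretrained writes only the adapter weights and adapter_config, the resulting checkpoint stays small, which is exactly what the size assertion in the test above checks.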