diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 57e445735..b2087af68 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,12 +1,15 @@ import logging +import warnings +import enum import os from functools import partial from pathlib import Path from types import MethodType -from typing import Callable, Dict, Iterator, List, Optional, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict import torch import torch.nn as nn +from torch.nn import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map @@ -41,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] +class OptimizerParamCheckState(enum.Enum): + ORIGIN_PARAM_FINDED = 0 + ORIGIN_PARAM_NOT_FIND = -1 + LORA_PARM_EXISTED = -2 + class LowLevelZeroModel(ModelWrapper, AMPModelMixin): def __init__(self, module: nn.Module, precision: str) -> None: @@ -208,6 +216,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) + class LowLevelZeroPlugin(DPPluginBase): """ @@ -287,6 +307,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload=cpu_offload, master_weights=master_weights, ) + self.lora_enabled = False self.verbose = verbose # set class name with stage, for better error message @@ -310,6 +331,66 @@ class LowLevelZeroPlugin(DPPluginBase): def supported_devices(self) -> List[str]: return ["cuda"] + + def support_lora(self) -> bool: + return True + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + from peft import PeftModel, get_peft_model + assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." + self.lora_enabled = True + warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model + + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter): + origin_param_id = id(origin_param) + for group_id, param_group in enumerate(optimizer.param_groups): + for p in param_group['params']: + if id(p) == origin_param_id: + return group_id + return -1 + + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter): + origin_param_id = id(origin_param) + lora_param_id = id(lora_param) + target_group_id = None + for group_id, param_group in enumerate(optimizer.param_groups): + for p in param_group['params']: + if id(p) == lora_param_id: + # check if the lora parameter exists. + return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED + if id(p) == origin_param_id: + target_group_id = group_id + if target_group_id is not None: + return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED + else: + return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND + + def add_lora_params_to_optimizer(self, model, optimizer): + """ add lora parameters to optimizer """ + name2param= {} + for name, param in model.named_parameters(): + name2param[name] = param + + for name, param in name2param.items(): + if 'lora_A' in name or 'lora_B' in name: + origin_key = name.replace("lora_A.", "") + origin_key = origin_key.replace("lora_B.", "") + origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer") + origin_param = name2param[origin_key] + group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) + if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: + warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.") + elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0: + optimizer.param_groups[group_id]['params'].append(param) + def configure( self, model: nn.Module, @@ -318,6 +399,13 @@ class LowLevelZeroPlugin(DPPluginBase): dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + if self.lora_enabled: + from peft import PeftModel + assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True" + if optimizer is not None: + self.add_lora_params_to_optimizer(model, optimizer) + + if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.precision) @@ -339,8 +427,3 @@ class LowLevelZeroPlugin(DPPluginBase): def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(optimizer, LowLevelZeroOptimizer) return optimizer.no_sync() - - def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None - ) -> nn.Module: - raise NotImplementedError diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index f822c1819..29a102be0 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -44,6 +44,20 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle +def check_for_nccl_backend(group): + + pg = group or c10d._get_default_group() + # Gate PG wrapper check on Gloo availability. + if c10d._GLOO_AVAILABLE: + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, c10d._ProcessGroupWrapper): + pg = pg.wrapped_pg + + return ( + c10d.is_nccl_available() and + pg.name() == c10d.Backend.NCCL + ) def _broadcast_object_list( object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None @@ -65,7 +79,7 @@ def _broadcast_object_list( c10d._warn_not_in_group("broadcast_object_list") return - is_nccl_backend = c10d._check_for_nccl_backend(group) + is_nccl_backend = check_for_nccl_backend(group) current_device = None if device is not None: diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 1164532fa..631242a43 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -82,6 +82,9 @@ class GradientStore(BaseStore): """ grad_list = [] + # When using LoRa and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients. + if group_id not in self._grads_of_params.keys(): + return grad_list for param_grads in self._grads_of_params[group_id].values(): grad_list.append(param_grads[self._working_index]) diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 5af311770..29a17ce7f 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,5 +18,5 @@ SentencePiece ninja flash_attn==2.0.5 datasets -peft +peft>=0.7.1 #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 19cb7a154..db9c9908c 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -14,3 +14,4 @@ einops sentencepiece google protobuf +peft>=0.7.1 \ No newline at end of file diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py index 0ac9d0f6d..eabe69ed3 100644 --- a/tests/test_booster/test_plugin/test_dp_plugin_base.py +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Tuple, Union +from typing import Callable, Iterator, List, Tuple, Union, Dict import torch import torch.distributed as dist @@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase): def no_sync(self, model: nn.Module) -> Iterator[None]: pass + def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module: + pass + + def support_lora(self) -> bool: + pass + def check_dataloader_sharding(): plugin = DPPluginWrapper() diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 104ca254c..9ad39d089 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -2,6 +2,7 @@ from typing import Optional import torch import torch.distributed as dist +from peft import LoraConfig import colossalai from colossalai.booster import Booster @@ -18,12 +19,16 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] -def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: +def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]: try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) model = model_fn() optimizer = HybridAdam(model.parameters(), lr=1e-3) + + if lora_config is not None: + model = booster.enable_lora(model, lora_config=lora_config) + criterion = lambda x: x.mean() data = data_gen_fn() @@ -43,6 +48,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: except Exception as e: return repr(e) + # raise e + @parameterize("stage", [2]) @@ -81,10 +88,41 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) +@parameterize("stage", [2]) +@parameterize("model_name", ["transformers_llama"]) +def check_low_level_zero_lora(stage, model_name, early_stop: bool = True): + passed_models = [] + failed_info = {} # (model_name, error) pair + + sub_model_zoo = model_zoo.get_sub_registry(model_name) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config) + + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_plugin(early_stop=early_stop) + check_low_level_zero_lora(early_stop=early_stop) @rerun_if_address_is_in_use() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index e7f44f97e..ed5aa7dbd 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -2,6 +2,9 @@ import torch import torch.distributed as dist from torchvision.models import resnet18 from utils import shared_tempdir +from typing import Optional +from peft import LoraConfig +from copy import deepcopy import colossalai from colossalai.booster import Booster @@ -15,6 +18,7 @@ from colossalai.testing import ( spawn, ) from colossalai.zero import LowLevelZeroOptimizer +from tests.kit.model_zoo import model_zoo # stage 1 and 2 process the optimizer/mode the same way @@ -69,9 +73,103 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): torch.cuda.empty_cache() +def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]: + try: + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload) + new_plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload) + booster = Booster(plugin=plugin) + new_booster = Booster(plugin=new_plugin) + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + new_model = deepcopy(model) + new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3) + model = booster.enable_lora(model, lora_config=lora_config) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + } + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_lora_as_pretrained(model, model_ckpt_path) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False) + new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + + # check master weight + assert isinstance(new_optimizer, LowLevelZeroOptimizer) + working_param_id_set = set(id(p) for p in new_model.parameters()) + for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): + assert p_id in working_param_id_set + working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] + padding = new_optimizer._param_store.get_param_padding_size(working_param) + padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) + working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] + assert torch.equal( + working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) + ) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) + + except Exception as e: + # return repr(e) + raise e + +@clear_cache_before_run() +@parameterize("stage", [2]) +@parameterize("shard", [True, False]) +@parameterize("offload", [False, True]) +@parameterize("model_name", ["transformers_llama"]) +def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True): + passed_models = [] + failed_info = {} # (model_name, error) pair + + sub_model_zoo = model_zoo.get_sub_registry(model_name) + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + if name != "transformers_llama": + continue + task_type = None + if name == "transformers_llama_for_casual_lm": + task_type = "CAUSAL_LM" + if name == "transformers_llama_for_sequence_classification": + task_type = "SEQ_CLS" + lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1) + err = run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config) + + torch.cuda.empty_cache() + + if err is None: + passed_models.append(name) + else: + failed_info[name] = err + if early_stop: + break + + if dist.get_rank() == 0: + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) + def run_dist(rank, world_size, port): colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_checkpointIO() + check_low_level_zero_lora_checkpointIO() torch.cuda.empty_cache()