[LowLevelZero] low level zero support lora (#5153)

* low level zero support lora

low level zero support lora

* add checkpoint test

* add checkpoint test

* fix

* fix

* fix

* fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* test ci

* git # This is a combination of 3 commits.

Update low_level_zero_plugin.py

Update low_level_zero_plugin.py

fix

fix

fix

* fix naming

fix naming

fix naming

fix
pull/5001/merge
flybird11111 11 months ago committed by GitHub
parent c5fd4aa6e8
commit cabc1286ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,12 +1,15 @@
import logging
import warnings
import enum
import os
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@ -41,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
class OptimizerParamCheckState(enum.Enum):
ORIGIN_PARAM_FINDED = 0
ORIGIN_PARAM_NOT_FIND = -1
LORA_PARM_EXISTED = -2
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(self, module: nn.Module, precision: str) -> None:
@ -208,6 +216,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
class LowLevelZeroPlugin(DPPluginBase):
"""
@ -287,6 +307,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.lora_enabled = False
self.verbose = verbose
# set class name with stage, for better error message
@ -310,6 +331,66 @@ class LowLevelZeroPlugin(DPPluginBase):
def supported_devices(self) -> List[str]:
return ["cuda"]
def support_lora(self) -> bool:
return True
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
from peft import PeftModel, get_peft_model
assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
origin_param_id = id(origin_param)
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group['params']:
if id(p) == origin_param_id:
return group_id
return -1
def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
origin_param_id = id(origin_param)
lora_param_id = id(lora_param)
target_group_id = None
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group['params']:
if id(p) == lora_param_id:
# check if the lora parameter exists.
return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
if id(p) == origin_param_id:
target_group_id = group_id
if target_group_id is not None:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
else:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
def add_lora_params_to_optimizer(self, model, optimizer):
""" add lora parameters to optimizer """
name2param= {}
for name, param in model.named_parameters():
name2param[name] = param
for name, param in name2param.items():
if 'lora_A' in name or 'lora_B' in name:
origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn("Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
optimizer.param_groups[group_id]['params'].append(param)
def configure(
self,
model: nn.Module,
@ -318,6 +399,13 @@ class LowLevelZeroPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
if self.lora_enabled:
from peft import PeftModel
assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
if optimizer is not None:
self.add_lora_params_to_optimizer(model, optimizer)
if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(model, self.precision)
@ -339,8 +427,3 @@ class LowLevelZeroPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.no_sync()
def enable_lora(
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> nn.Module:
raise NotImplementedError

@ -44,6 +44,20 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
return unpickle
def check_for_nccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return (
c10d.is_nccl_available() and
pg.name() == c10d.Backend.NCCL
)
def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
@ -65,7 +79,7 @@ def _broadcast_object_list(
c10d._warn_not_in_group("broadcast_object_list")
return
is_nccl_backend = c10d._check_for_nccl_backend(group)
is_nccl_backend = check_for_nccl_backend(group)
current_device = None
if device is not None:

@ -82,6 +82,9 @@ class GradientStore(BaseStore):
"""
grad_list = []
# When using LoRa and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients.
if group_id not in self._grads_of_params.keys():
return grad_list
for param_grads in self._grads_of_params[group_id].values():
grad_list.append(param_grads[self._working_index])

@ -18,5 +18,5 @@ SentencePiece
ninja
flash_attn==2.0.5
datasets
peft
peft>=0.7.1
#auto-gptq now not support torch1.12

@ -14,3 +14,4 @@ einops
sentencepiece
google
protobuf
peft>=0.7.1

@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Tuple, Union
from typing import Callable, Iterator, List, Tuple, Union, Dict
import torch
import torch.distributed as dist
@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase):
def no_sync(self, model: nn.Module) -> Iterator[None]:
pass
def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
pass
def support_lora(self) -> bool:
pass
def check_dataloader_sharding():
plugin = DPPluginWrapper()

@ -2,6 +2,7 @@ from typing import Optional
import torch
import torch.distributed as dist
from peft import LoraConfig
import colossalai
from colossalai.booster import Booster
@ -18,12 +19,16 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
_STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
if lora_config is not None:
model = booster.enable_lora(model, lora_config=lora_config)
criterion = lambda x: x.mean()
data = data_gen_fn()
@ -43,6 +48,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
except Exception as e:
return repr(e)
# raise e
@parameterize("stage", [2])
@ -81,10 +88,41 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
@parameterize("stage", [2])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair
sub_model_zoo = model_zoo.get_sub_registry(model_name)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break
if dist.get_rank() == 0:
print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_plugin(early_stop=early_stop)
check_low_level_zero_lora(early_stop=early_stop)
@rerun_if_address_is_in_use()

@ -2,6 +2,9 @@ import torch
import torch.distributed as dist
from torchvision.models import resnet18
from utils import shared_tempdir
from typing import Optional
from peft import LoraConfig
from copy import deepcopy
import colossalai
from colossalai.booster import Booster
@ -15,6 +18,7 @@ from colossalai.testing import (
spawn,
)
from colossalai.zero import LowLevelZeroOptimizer
from tests.kit.model_zoo import model_zoo
# stage 1 and 2 process the optimizer/mode the same way
@ -69,9 +73,103 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
torch.cuda.empty_cache()
def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
new_plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
booster = Booster(plugin=plugin)
new_booster = Booster(plugin=new_plugin)
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
new_model = deepcopy(model)
new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
model = booster.enable_lora(model, lora_config=lora_config)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_lora_as_pretrained(model, model_ckpt_path)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
assert p_id in working_param_id_set
working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
padding = new_optimizer._param_store.get_param_padding_size(working_param)
padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
assert torch.equal(
working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
)
new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
except Exception as e:
# return repr(e)
raise e
@clear_cache_before_run()
@parameterize("stage", [2])
@parameterize("shard", [True, False])
@parameterize("offload", [False, True])
@parameterize("model_name", ["transformers_llama"])
def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair
sub_model_zoo = model_zoo.get_sub_registry(model_name)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name != "transformers_llama":
continue
task_type = None
if name == "transformers_llama_for_casual_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
err = run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
else:
failed_info[name] = err
if early_stop:
break
if dist.get_rank() == 0:
print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
def run_dist(rank, world_size, port):
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
check_low_level_zero_checkpointIO()
check_low_level_zero_lora_checkpointIO()
torch.cuda.empty_cache()

Loading…
Cancel
Save