mirror of https://github.com/hpcaitech/ColossalAI

[lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin
* fix
* fix
* fix
* fix colossalchat

parent 19d1510ea2
commit 75c963686f
@@ -30,6 +30,7 @@ from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
@@ -1187,7 +1188,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return True

     def support_lora(self) -> bool:
-        return False
+        return True

     def control_checkpoint_io(self) -> bool:
         return True
@@ -1415,6 +1416,24 @@ class HybridParallelPlugin(PipelinePluginBase):
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

     def enable_lora(
-        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+        self,
+        model: Module,
+        pretrained_dir: Optional[str] = None,
+        lora_config: Optional[Dict] = None,
+        bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
     ) -> Module:
-        raise NotImplementedError
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
+        assert self.pp_size == 1 and self.tp_size == 1
+        self.lora_enabled = True
+        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+
+        if bnb_quantization_config is not None:
+            model = quantize_model(model, bnb_quantization_config)
+
+        if pretrained_dir is None:
+            peft_model = get_peft_model(model, lora_config)
+        else:
+            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+        return peft_model
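For context, a minimal usage sketch of the new path (illustrative only: the ToyNet model, the LoraConfig values, the optimizer settings, and the launch setup are assumptions, not part of this PR). LoRA is enabled through the booster before boosting, which is the order the assertion above requires:

# Minimal sketch, assuming a torchrun-launched process group; ToyNet and all
# hyperparameters here are illustrative, not taken from this PR.
import torch
import torch.nn as nn
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


colossalai.launch_from_torch()  # run with torchrun so the process group exists

plugin = HybridParallelPlugin(tp_size=1, pp_size=1)  # the LoRA path asserts tp_size == pp_size == 1
booster = Booster(plugin=plugin)

model = ToyNet()
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["fc1", "fc2"])

if plugin.support_lora():  # now returns True for HybridParallelPlugin
    model = booster.enable_lora(model, lora_config=lora_config)  # must precede booster.boost

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)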
@@ -947,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            state_[k] = v.detach().clone().to(device)

        return state_
+
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
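A hedged sketch of saving the adapter through the booster, which routes to the new HybridParallelCheckpointIO.save_lora_as_pretrained above (the checkpoint path and the reload step are illustrative assumptions, continuing the earlier example):

# Persist only the LoRA adapter weights to a directory; the method above logs an
# error and returns if given a file path. "./lora_ckpt" is an illustrative path.
booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)

# Assumed reload workflow: attach the saved adapter to a fresh base model before boosting.
reloaded = booster.enable_lora(ToyNet(), pretrained_dir="./lora_ckpt")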
@@ -243,6 +243,9 @@ def _fullname(obj):
     # patch custom models which are not in transformers
     # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
     # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+    if module.startswith("peft"):
+        klass = obj.base_model.model.__class__
+        module = klass.__module__
     if module.startswith("transformers_modules"):
         split_module = module.split(".")
         if len(split_module) >= 2:
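The intent of this branch, spelled out as a standalone sketch (not the repo's exact helper): a PeftModel's class lives in the peft package, so the name lookup unwraps to the wrapped model's class before the transformers_modules handling runs.

# Illustrative reconstruction of the control flow of the patched _fullname; this
# standalone helper only mirrors the idea, it is not the library function itself.
def fullname_of(obj) -> str:
    klass = obj.__class__
    module = klass.__module__
    if module.startswith("peft"):
        # obj is a peft wrapper; resolve the wrapped model's class so the returned
        # name reflects the real architecture rather than the peft wrapper.
        klass = obj.base_model.model.__class__
        module = klass.__module__
    return module + "." + klass.__qualname__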
@@ -9,7 +9,8 @@ from torch.optim import AdamW

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 from tests.test_checkpoint_io.utils import shared_tempdir
@@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
     model = model_fn()
     lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

-    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
+    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
     test_configs = [
         {
             "lora_config": lora_config,
@@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type

     # test fwd bwd correctness
     test_model = model_load
+    if isinstance(model_load, HybridParallelModule):
+        model_load = model_load.module.module
     model_copy = copy.deepcopy(model_load)

     data = data_gen_fn()
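The unwrapping above can be read as a small helper (a sketch that only mirrors the test's attribute chain; the exact role of each wrapper level is not asserted here):

# Sketch mirroring the test: a model boosted by HybridParallelPlugin comes back as a
# HybridParallelModule, and the test reaches the comparable inner module via .module.module.
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule

def unwrap_for_comparison(boosted):
    if isinstance(boosted, HybridParallelModule):
        return boosted.module.module  # same attribute chain as in the test above
    return boosted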