"""Distributed test: LoRA checkpoint round-trip and a forward/backward step under Booster plugins."""

import copy
import os
from itertools import product

import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir


@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    """Save/load a LoRA adapter and run one optimizer step for every plugin/config pair.

    For each combination of plugin (``TorchDDPPlugin``, ``LowLevelZeroPlugin``) and
    LoRA config (``quantize`` False/True) this:

    1. enables LoRA on a fresh model, boosts it, and saves the adapter;
    2. checks the saved ``adapter_model.bin`` is smaller than 1 MiB;
    3. loads the adapter into a deep copy of the un-adapted model and checks both
       state dicts are equal;
    4. runs forward, backward, gradient clipping and one optimizer step on the
       loaded model and compares its parameters against a pre-step snapshot.

    Args:
        model_fn: Zero-argument callable returning a new model instance.
        data_gen_fn: Zero-argument callable returning a dict of model keyword inputs.
        output_transform_fn: Callable applied to the raw model output before the loss.
        loss_fn: Callable mapping the transformed output to a loss; used as the criterion.
        task_type: Value forwarded to ``LoraConfig(task_type=...)``; may be ``None``.

    Raises:
        AssertionError: If any of the checks described above fails.
    """
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
    test_configs = [
        {
            "lora_config": lora_config,
            "quantize": False,
        },
        {
            "lora_config": lora_config,
            "quantize": True,
        },
    ]
    for plugin, test_config in product(test_plugins, test_configs):
        # checkpoint loaded model
        model_save = model_fn()
        model_load = copy.deepcopy(model_save)

        # NOTE(review): the optimizer is built from `model`, an instance that is never
        # boosted or run forward below, rather than from `model_save`/`model_load`.
        # Left unchanged because altering it changes what this test exercises - confirm intent.
        optimizer = AdamW(model.parameters(), lr=0.001)
        criterion = loss_fn

        booster = Booster(plugin=plugin)
        model_save = booster.enable_lora(model_save, **test_config)
        model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)

        with shared_tempdir() as tempdir:
            lora_ckpt_path = os.path.join(tempdir, "ckpt")
            booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
            # Synchronize ranks before the checkpoint file is inspected and loaded.
            dist.barrier()

            # The Lora checkpoint should be small in size
            checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
            assert checkpoint_size_mb < 1

            model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
            model_load, _, _, _, _ = booster.boost(model_load)

            check_state_dict_equal(model_save.state_dict(), model_load.state_dict())

        # test fwd bwd correctness
        test_model = model_load
        # Snapshot taken before the optimizer step; used as the comparison baseline below.
        model_copy = copy.deepcopy(model_load)

        data = data_gen_fn()
        # Move tensor-like inputs to the GPU; everything else is passed through untouched.
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        output = test_model(**data)
        output = output_transform_fn(output)
        loss = criterion(output)

        booster.backward(loss, optimizer)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()

        for (n1, p1), (_, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
            if "lora_" in n1:
                # lora modules require gradients
                assert p1.requires_grad
                # NOTE(review): this call used to be wrapped as `assert not assert_close(...)`.
                # `assert_close` returns None when the tensors match and raises otherwise, so
                # the wrapper never inverted anything and would be stripped under `python -O`.
                # The check that actually ran - LoRA params stay within tolerance of the
                # pre-step snapshot - is preserved here. If the intent is to verify that the
                # LoRA weights were *updated*, an explicit inequality check is needed instead.
                torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
            elif not p1.requires_grad:
                # Frozen parameters must be unchanged by the step.
                torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)


def run_lora_test():
    """Run ``check_fwd_bwd`` on every model in the ``transformers_llama`` sub-registry."""
    # Registry names with a known PEFT task type; any other name gets `task_type=None`.
    task_types = {
        "transformers_llama_for_casual_lm": "CAUSAL_LM",
        "transformers_llama_for_sequence_classification": "SEQ_CLS",
    }
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = task_types.get(name)
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI on localhost with NCCL, then run the LoRA test.

    Args:
        rank: Rank passed through to ``colossalai.launch``.
        world_size: World size passed through to ``colossalai.launch``.
        port: Port passed through to ``colossalai.launch``.
    """
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    """Pytest entry point: run ``run_dist`` via ``spawn`` with 2 (presumably the process count)."""
    spawn(run_dist, 2)