mirror of https://github.com/hpcaitech/ColossalAI
[LowLevelZero] low level zero support lora (#5153)
* low level zero support lora
* add checkpoint test
* update low_level_zero_plugin.py
* fix naming and assorted CI fixes
parent c5fd4aa6e8
commit cabc1286ca
@@ -1,12 +1,15 @@
 import logging
+import warnings
+import enum
 import os
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Dict
 
 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -41,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 
 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]
 
+
+class OptimizerParamCheckState(enum.Enum):
+    ORIGIN_PARAM_FINDED = 0
+    ORIGIN_PARAM_NOT_FIND = -1
+    LORA_PARM_EXISTED = -2
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(self, module: nn.Module, precision: str) -> None:
@@ -208,6 +216,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()
 
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+
 
 class LowLevelZeroPlugin(DPPluginBase):
     """
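For orientation (not part of the diff): saving delegates to peft, so only the adapter tensors are written, and reloading goes through the plugin's new enable_lora with pretrained_dir, as the checkpoint test at the end of this commit does. A hedged, self-contained sketch of that round trip at the peft level; the toy model, target_modules, and the "lora_ckpt" path are placeholders:

# Hedged sketch of the save/reload round trip behind save_lora_as_pretrained (not the real method).
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model

base = nn.Sequential(nn.Linear(8, 8))
peft_model = get_peft_model(base, LoraConfig(r=4, lora_alpha=16, target_modules=["0"]))

# Saving: only the adapter weights go to disk (safetensors when use_safetensors=True).
peft_model.save_pretrained("lora_ckpt", safe_serialization=True)

# Reloading: enable_lora(pretrained_dir=...) uses the same peft entry point,
# with is_trainable=True so the restored adapters remain trainable.
restored = PeftModel.from_pretrained(nn.Sequential(nn.Linear(8, 8)), "lora_ckpt", is_trainable=True)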
@@ -287,6 +307,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             cpu_offload=cpu_offload,
             master_weights=master_weights,
         )
+        self.lora_enabled = False
         self.verbose = verbose
 
         # set class name with stage, for better error message
@@ -310,6 +331,66 @@ class LowLevelZeroPlugin(DPPluginBase):
     def supported_devices(self) -> List[str]:
         return ["cuda"]
 
+    def support_lora(self) -> bool:
+        return True
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+        assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
+        self.lora_enabled = True
+        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+
+        if pretrained_dir is None:
+            peft_model = get_peft_model(model, lora_config)
+        else:
+            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+        return peft_model
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
+        origin_param_id = id(origin_param)
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == origin_param_id:
+                    return group_id
+        return -1
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
+        origin_param_id = id(origin_param)
+        lora_param_id = id(lora_param)
+        target_group_id = None
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == lora_param_id:
+                    # check if the lora parameter exists.
+                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
+                if id(p) == origin_param_id:
+                    target_group_id = group_id
+        if target_group_id is not None:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+        else:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
+
+    def add_lora_params_to_optimizer(self, model, optimizer):
+        """add lora parameters to optimizer"""
+        name2param = {}
+        for name, param in model.named_parameters():
+            name2param[name] = param
+
+        for name, param in name2param.items():
+            if "lora_A" in name or "lora_B" in name:
+                origin_key = name.replace("lora_A.", "")
+                origin_key = origin_key.replace("lora_B.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
+                origin_param = name2param[origin_key]
+                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
+                    warnings.warn(f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
+                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
+                    optimizer.param_groups[group_id]["params"].append(param)
+
     def configure(
         self,
         model: nn.Module,
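Taken together, the methods above define the plugin's LoRA workflow: enable_lora must wrap the model with peft before boost(), and configure() later calls add_lora_params_to_optimizer so each lora_A/lora_B tensor joins the param_group of its base layer and inherits that group's settings. A hedged, end-to-end usage sketch follows (not part of the diff; the toy model, hyperparameters, and the "lora_ckpt" path are placeholders; run under a distributed launcher such as torchrun):

# Hedged usage sketch of the LoRA path added in this commit.
import torch.nn as nn
from peft import LoraConfig

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})
booster = Booster(plugin=LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=2**5))

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
# target_modules names the Sequential children to adapt; required for a plain nn.Module.
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["0", "2"])

model = booster.enable_lora(model, lora_config=lora_config)  # before boost(), as asserted above
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)       # configure() routes LoRA params into the optimizer

# ... training steps ...
booster.save_lora_as_pretrained(model, "lora_ckpt")          # adapters only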
@@ -318,6 +399,13 @@ class LowLevelZeroPlugin(DPPluginBase):
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        if self.lora_enabled:
+            from peft import PeftModel
+
+            assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+            if optimizer is not None:
+                self.add_lora_params_to_optimizer(model, optimizer)
+
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
@@ -339,8 +427,3 @@
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
-
-    def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
-    ) -> nn.Module:
-        raise NotImplementedError
@@ -44,6 +44,20 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
 
     return unpickle
 
+
+def check_for_nccl_backend(group):
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just
+        # in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg
+
+    return (
+        c10d.is_nccl_available() and
+        pg.name() == c10d.Backend.NCCL
+    )
+
+
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
@@ -65,7 +79,7 @@ def _broadcast_object_list(
         c10d._warn_not_in_group("broadcast_object_list")
         return
 
-    is_nccl_backend = c10d._check_for_nccl_backend(group)
+    is_nccl_backend = check_for_nccl_backend(group)
     current_device = None
 
     if device is not None:
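The new local helper mirrors torch.distributed's private c10d._check_for_nccl_backend, so the broadcast path no longer depends on that internal. A minimal sketch of how the result is typically used (broadcast_device is a hypothetical wrapper, assuming the helper above is in scope in the same module):

# Hedged sketch: serialized objects only need to be staged on the CUDA device when the
# group's backend is NCCL, which is what _broadcast_object_list checks above.
import torch

def broadcast_device(group=None) -> torch.device:
    if check_for_nccl_backend(group):  # helper defined earlier in this module
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")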
@@ -82,6 +82,9 @@ class GradientStore(BaseStore):
         """
 
         grad_list = []
+        # When using LoRa and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients.
+        if group_id not in self._grads_of_params.keys():
+            return grad_list
         for param_grads in self._grads_of_params[group_id].values():
             grad_list.append(param_grads[self._working_index])
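This guard matters because, once LoRA parameters are appended into existing param_groups (or users define several groups), a group can own no parameters that received gradients in a step; returning an empty list avoids a KeyError. A hedged, self-contained sketch of the pattern; the class and method names here are illustrative, since the hunk does not show the enclosing method's signature:

# Illustrative stand-in for the gradient store: group_id -> {param: [grad buckets]}.
class TinyGradStore:
    def __init__(self):
        self._grads_of_params = {}
        self._working_index = 0

    def get_grads_of_group(self, group_id):
        grad_list = []
        # A LoRA-only or empty group may never register gradients, so its key is absent.
        if group_id not in self._grads_of_params:
            return grad_list
        for param_grads in self._grads_of_params[group_id].values():
            grad_list.append(param_grads[self._working_index])
        return grad_list

store = TinyGradStore()
assert store.get_grads_of_group(group_id=1) == []  # no KeyError for an unseen group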
@@ -18,5 +18,5 @@ SentencePiece
 ninja
 flash_attn==2.0.5
 datasets
-peft
+peft>=0.7.1
 #auto-gptq now not support torch1.12
@@ -14,3 +14,4 @@ einops
 sentencepiece
 google
 protobuf
+peft>=0.7.1
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union, Dict
 
 import torch
 import torch.distributed as dist
@@ -51,6 +51,12 @@ class DPPluginWrapper(DPPluginBase):
     def no_sync(self, model: nn.Module) -> Iterator[None]:
         pass
 
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        pass
+
+    def support_lora(self) -> bool:
+        pass
+
 
 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
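The stubs added to the test wrapper suggest the plugin base class now declares LoRA hooks that concrete plugins are expected to implement. A hedged sketch of a plugin that opts out (NoLoraPlugin is a placeholder name; only the two LoRA hooks are shown, and DPPluginBase is assumed to be imported as in the test above):

# Hedged sketch: a plugin that does not support LoRA reports so via support_lora()
# and never expects enable_lora() to be reached. Other abstract methods are omitted.
from typing import Dict, Optional

import torch.nn as nn

class NoLoraPlugin(DPPluginBase):
    def support_lora(self) -> bool:
        return False

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError("This plugin does not support LoRA")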
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 import torch.distributed as dist
+from peft import LoraConfig
 
 import colossalai
 from colossalai.booster import Booster
@@ -18,12 +19,16 @@ _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"]
 _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
 
 
-def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
     try:
         plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
         model = model_fn()
         optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+        if lora_config is not None:
+            model = booster.enable_lora(model, lora_config=lora_config)
+
         criterion = lambda x: x.mean()
         data = data_gen_fn()
@@ -43,6 +48,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
 
     except Exception as e:
         return repr(e)
+        # raise e
+
 
 @parameterize("stage", [2])
@@ -81,10 +88,41 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
 
 
+@parameterize("stage", [2])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
+    passed_models = []
+    failed_info = {}  # (model_name, error) pair
+
+    sub_model_zoo = model_zoo.get_sub_registry(model_name)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+        err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+        torch.cuda.empty_cache()
+
+        if err is None:
+            passed_models.append(name)
+        else:
+            failed_info[name] = err
+            if early_stop:
+                break
+
+    if dist.get_rank() == 0:
+        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
 def run_dist(rank, world_size, port, early_stop: bool = True):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
     check_low_level_zero_plugin(early_stop=early_stop)
+    check_low_level_zero_lora(early_stop=early_stop)
 
 
 @rerun_if_address_is_in_use()
@@ -2,6 +2,9 @@ import torch
 import torch.distributed as dist
 from torchvision.models import resnet18
 from utils import shared_tempdir
+from typing import Optional
+from peft import LoraConfig
+from copy import deepcopy
 
 import colossalai
 from colossalai.booster import Booster
@@ -15,6 +18,7 @@ from colossalai.testing import (
     spawn,
 )
 from colossalai.zero import LowLevelZeroOptimizer
+from tests.kit.model_zoo import model_zoo
 
 
 # stage 1 and 2 process the optimizer/mode the same way
@@ -69,9 +73,103 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
     torch.cuda.empty_cache()
 
 
+def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
+    try:
+        plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
+        new_plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5, cpu_offload=offload)
+        booster = Booster(plugin=plugin)
+        new_booster = Booster(plugin=new_plugin)
+        model = model_fn()
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+        new_model = deepcopy(model)
+        new_optimizer = HybridAdam(new_model.parameters(), lr=1e-3)
+        model = booster.enable_lora(model, lora_config=lora_config)
+        criterion = lambda x: x.mean()
+        data = data_gen_fn()
+
+        data = {
+            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+        }
+
+        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+        output = model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+
+        booster.backward(loss, optimizer)
+        optimizer.step()
+
+        with shared_tempdir() as tempdir:
+            model_ckpt_path = f"{tempdir}/model"
+            optimizer_ckpt_path = f"{tempdir}/optimizer"
+
+            booster.save_lora_as_pretrained(model, model_ckpt_path)
+            booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
+            new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
+            new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
+            check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+
+            # check master weight
+            assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+            working_param_id_set = set(id(p) for p in new_model.parameters())
+            for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+                assert p_id in working_param_id_set
+                working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+                padding = new_optimizer._param_store.get_param_padding_size(working_param)
+                padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+                working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+                assert torch.equal(
+                    working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
+                )
+
+            new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+            check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+
+    except Exception as e:
+        # return repr(e)
+        raise e
+
+
+@clear_cache_before_run()
+@parameterize("stage", [2])
+@parameterize("shard", [True, False])
+@parameterize("offload", [False, True])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora_checkpointIO(stage: int, shard: bool, offload: bool, model_name: str, early_stop: bool = True):
+    passed_models = []
+    failed_info = {}  # (model_name, error) pair
+
+    sub_model_zoo = model_zoo.get_sub_registry(model_name)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name != "transformers_llama":
+            continue
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+        err = run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+        torch.cuda.empty_cache()
+
+        if err is None:
+            passed_models.append(name)
+        else:
+            failed_info[name] = err
+            if early_stop:
+                break
+
+    if dist.get_rank() == 0:
+        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost")
     check_low_level_zero_checkpointIO()
+    check_low_level_zero_lora_checkpointIO()
     torch.cuda.empty_cache()