@@ -1,9 +1,11 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
@@ -134,7 +136,11 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
             assert isinstance(
                 peft_model, PeftModel
             ), "The model doesn't have lora adapters, please enable lora before saving."
-            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+            return peft_model.save_pretrained(
+                checkpoint,
+                safe_serialization=use_safetensors,
+                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+            )
 
 
 class TorchDDPModel(ModelWrapper):
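
Note (not part of the diff): the added state_dict=tree_map(...) argument passes plain tensors rather than Parameter objects to save_pretrained. A minimal sketch of that pattern, with illustrative names (the stand-in state dict below is hypothetical; the real code uses peft_model.state_dict()):

import torch
from torch.utils._pytree import tree_map

# Stand-in for a PEFT/LoRA state dict; key name is illustrative only.
state_dict = {"lora_A.weight": torch.nn.Parameter(torch.randn(4, 2))}

# .data strips the Parameter/autograd wrapper, leaving a plain tensor
# that serializers such as safetensors can handle directly.
plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

print(type(state_dict["lora_A.weight"]))  # <class 'torch.nn.parameter.Parameter'>
print(type(plain["lora_A.weight"]))       # <class 'torch.Tensor'>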