You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/booster/plugin/torch_ddp_plugin.py

285 lines
10 KiB

from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
from .dp_plugin_base import DPPluginBase
__all__ = ["TorchDDPPlugin"]
class TorchDDPCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def save_sharded_model(
self,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(
model.unwrap(),
checkpoint_path,
gather_dtensor,
prefix,
max_shard_size,
use_safetensors,
use_async=use_async,
)
def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: str,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
) -> None:
"""
Save the lora adapters and adapter configuration file to checkpoint directory.
"""
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
class TorchDDPModel(ModelWrapper):
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
def unwrap(self):
return self.module.module
class TorchDDPPlugin(DPPluginBase):
"""
Plugin for PyTorch DDP.
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
model, train_dataset, optimizer, criterion = ...
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```
Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False.
check_reduction (bool, optional): Whether to check reduction. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
static_graph (bool, optional): Whether to use static graph. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
"""
def __init__(
self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
fp8_communication: bool = False,
) -> None:
super().__init__()
self.ddp_kwargs = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
self.fp8_communication = fp8_communication
def support_no_sync(self) -> bool:
return True
def support_lora(self) -> bool:
return True
def control_precision(self) -> bool:
return False
def supported_precisions(self) -> List[str]:
return ["fp16", "fp16_apex", "bf16", "fp8"]
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ["cuda", "npu"]
def configure(
self,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
# cast model to cuda
model = model.to(get_current_device())
# convert model to sync bn
model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with PyTorch DDP
model = TorchDDPModel(model, **self.ddp_kwargs)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
if self.fp8_communication:
from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async
model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async)
return model, optimizer, criterion, dataloader, lr_scheduler
def control_checkpoint_io(self) -> bool:
return True
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
return model.module.no_sync()
def enable_lora(
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model
if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
if pretrained_dir is None:
return get_peft_model(model, lora_config)
else:
return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)