@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
 
 
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
         sharded_optimizer_loading_epilogue(optimizer)
 
-
-class LowLevelZeroModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
+
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
+
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
+
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()
 
 
 class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
-
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose
 
         # set class name with stage, for better error message
@@ -294,15 +331,15 @@
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)
 
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
 
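
Usage sketch (not part of the patch): after this change, LowLevelZeroPlugin wraps the optimizer directly with LowLevelZeroOptimizer and routes model checkpointing through the new LowLevelZeroCheckpointIO methods, which refresh the master params on load via the injected update_master_params. A minimal sketch of how the boosted objects are typically exercised, assuming the Booster API of the same ColossalAI release and a torchrun launch; the toy model, loss cast, and checkpoint path are illustrative assumptions, not taken from this diff.

# Minimal usage sketch (illustrative only, not part of the patch).
# Run under torchrun so launch_from_torch can read the distributed env vars.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)

model = nn.Linear(16, 8)    # toy model, assumption for demonstration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# boost() wraps the model into LowLevelZeroModel and the optimizer into LowLevelZeroOptimizer
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

x = torch.rand(4, 16).cuda()
y = torch.rand(4, 8).cuda()
loss = criterion(model(x).float(), y)    # cast fp16 output back to fp32 for the fp32 target
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()

# checkpointing goes through LowLevelZeroCheckpointIO; loading also calls
# model.update_master_params() to resync the fp32 master weights
booster.save_model(model, 'model.pt')    # checkpoint path is an assumption
booster.load_model(model, 'model.pt')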