mirror of https://github.com/hpcaitech/ColossalAI
[zero] hotfix master param sync (#4618)

* [zero] add method to update master params
* [zero] update zero plugin
* [plugin] update low level zero plugin

parent aaeb520ce3
commit 807e01a4ba
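Background for the diff below: with fp16/bf16 mixed precision, the low-level ZeRO optimizer keeps fp32 master copies of the model's working parameters, so loading a model checkpoint refreshes only the working weights and leaves the master copies stale until they are explicitly re-synced. The following is a minimal plain-PyTorch sketch of that failure mode and of the explicit copy this commit wires in; the names and shapes are illustrative and none of it is ColossalAI API.

# Sketch only: simulate fp16 working weights plus fp32 master copies.
import torch
import torch.nn as nn

model = nn.Linear(4, 4).half()                                        # working (fp16) weights
master = [p.detach().float().clone() for p in model.parameters()]     # fp32 master copies

# Simulate loading a checkpoint: the working weights change, the masters do not.
model.load_state_dict(nn.Linear(4, 4).half().state_dict())
assert not torch.equal(master[0], next(model.parameters()).float())   # masters are now stale

# The fix: copy the freshly loaded working weights back into the master copies.
for m, p in zip(master, model.parameters()):
    m.copy_(p.detach().float())
assert torch.equal(master[0], next(model.parameters()).float())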
@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
 
 
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
         sharded_optimizer_loading_epilogue(optimizer)
 
+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
 
-class LowLevelZeroModel(ModelWrapper):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
 
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
 
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()
 
 
 class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose
 
         # set class name with stage, for better error message
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)
 
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
 
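The configure hook above attaches the optimizer's update_master_params to the wrapped model with types.MethodType, so the checkpoint IO can call model.update_master_params() without holding a reference to the optimizer. Below is a small self-contained sketch of that binding pattern; the Toy* classes are stand-ins, not the real plugin types.

from types import MethodType


class ToyOptimizer:

    def update_master_params(self, model):
        print(f"syncing master params of {type(model).__name__}")


class ToyModel:

    def update_master_params(self):
        pass    # no-op default, analogous to AMPModelMixin


opt = ToyOptimizer()
model = ToyModel()
# Bind the optimizer's bound method to the model instance: calling
# model.update_master_params() now runs opt.update_master_params(model).
model.update_master_params = MethodType(opt.update_master_params, model)
model.update_master_params()    # prints: syncing master params of ToyModel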
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper
 
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
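The point of the no-op default above is that callers such as the checkpoint IO can invoke update_master_params() unconditionally; it only does real work once a plugin swaps in an optimizer-backed implementation. A toy sketch of that contract, with illustrative class names only:

class AMPMixinSketch:
    # stand-in for AMPModelMixin: the default is a harmless no-op
    def update_master_params(self):
        pass


class WrappedModelSketch(AMPMixinSketch):
    pass


w = WrappedModelSketch()
w.update_master_params()    # safe to call before any optimizer has been attached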
@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
+
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
+
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        for p in model.parameters():
+            p_id = id(p)
+            if p_id in self._param_store.working_to_master_param:
+                master_param = self._param_store.working_to_master_param[p_id]
+                padding_size = self._param_store.get_param_padding_size(p)
+                working_param = p.data.view(-1)
+                if padding_size > 0:
+                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
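For reference, the pad-and-chunk logic in update_master_params (and in the test hunk that follows) lays out each working parameter as a flat tensor padded to a multiple of the world size, with every rank owning one contiguous chunk as its master shard. Here is a self-contained sketch of that layout with simulated ranks; no torch.distributed is involved and the names are illustrative.

import torch

world_size = 4
p = torch.randn(10, 3)                                  # a working parameter
flat = p.data.view(-1)                                  # 30 elements
padding = (-flat.numel()) % world_size                  # pad to a multiple of world_size
padded = torch.nn.functional.pad(flat, [0, padding])    # 32 elements

# Each rank's master shard is one chunk of the padded flat tensor.
masters = [padded.chunk(world_size)[rank].clone() for rank in range(world_size)]

# After the working parameter changes (e.g. a checkpoint load), re-sync every shard.
p.data.add_(1.0)
padded = torch.nn.functional.pad(p.data.view(-1), [0, padding])
for rank, master in enumerate(masters):
    master.copy_(padded.chunk(world_size)[rank])

# Concatenating the shards and dropping the padding recovers the working parameter.
assert torch.equal(torch.cat(masters)[:flat.numel()], p.data.view(-1))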
@@ -14,6 +14,7 @@ from colossalai.testing import (
     rerun_if_address_is_in_use,
     spawn,
 )
+from colossalai.zero import LowLevelZeroOptimizer
 
 
 # stage 1 and 2 process the optimizer/mode the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
 
     booster.load_model(new_model, model_ckpt_path)
     check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+    # check master weight
+    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+    working_param_id_set = set(id(p) for p in new_model.parameters())
+    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+        assert p_id in working_param_id_set
+        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+        padding = new_optimizer._param_store.get_param_padding_size(working_param)
+        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+        assert torch.equal(working_shard,
+                           master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
 
     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)