[gemini] gemini supports lazy init (#3379)

* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
Hongxin Liu 2023-04-12 16:03:25 +08:00 committed by GitHub
parent 366a035552
commit 152239bbfa
7 changed files with 80 additions and 72 deletions

View File

@ -2,8 +2,6 @@ import torch
import torch.distributed as dist
from packaging import version
aten = torch.ops.aten
__all__ = [
"_TorchFactoryMethod",
"_TorchOverrideableFactoryMethod",
@ -51,6 +49,7 @@ _DistCommMethod = [
]
if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
_AliasATen = [

View File

@ -16,10 +16,8 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
from colossalai.zero.gemini.memory_tracer import MemStats
from .plugin_base import Plugin
@ -27,50 +25,6 @@ from .plugin_base import Plugin
__all__ = ['GeminiPlugin']
def convert_to_colo_param(module: nn.Module) -> None:
"""Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
Args:
module (nn.Module): Module to be converted.
"""
converted_modules = set() # handle shared modules
converted_params = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
def convert_recursively(m: nn.Module):
for child in m.children():
if child not in converted_modules:
converted_modules.add(child)
convert_recursively(child)
for name, p in m.named_parameters(recurse=False):
assert not isinstance(p, ColoParameter)
if p in converted_params:
target = converted_params[p]
else:
target = _convert_to_coloparam(p, p.device, p.dtype)
converted_params[p] = target
setattr(m, name, target)
target.shared_param_modules.append(m)
convert_recursively(module)
# optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
module._converted_params = converted_params
def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
"""Replace param in optimizer's group with converted ColoParameter.
Args:
optimizer (Optimizer): Optimizer to be replaced.
converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
"""
for group in optimizer.param_groups:
for i, p in enumerate(group['params']):
if p in converted_params:
group['params'][i] = converted_params[p]
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper):
def __init__(self, module: nn.Module, gemini_config: dict) -> None:
super().__init__(module)
# TODO(ver217): only support Gemini now
convert_to_colo_param(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
def unwrap(self):
@ -125,8 +77,6 @@ class GeminiModel(ModelWrapper):
class GeminiOptimizer(OptimizerWrapper):
def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
replace_param_in_group(optimizer, module.module._converted_params)
del module.module._converted_params
optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
super().__init__(optimizer)
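With the conversion workaround removed, the plugin's wrappers reduce to the two calls shown above. ZeroDDP now turns plain or lazily initialized parameters into ColoParameter in place (see _preprocess_param further down), so an optimizer built on the original parameters keeps valid references and its param groups no longer need patching. A minimal sketch of the resulting wrapping order; the helper name is illustrative and only the two wrapper calls are taken from the diff:

import torch.nn as nn
from torch.optim import Optimizer

from colossalai.zero import zero_model_wrapper, zero_optim_wrapper


def wrap_for_gemini(model: nn.Module, optimizer: Optimizer,
                    gemini_config: dict, zero_optim_config: dict, **optim_kwargs):
    """Illustrative helper: wrap the model first, then the optimizer."""
    # The model wrapper converts parameters to ColoParameter in place, so the
    # optimizer's existing param groups keep referring to the same objects.
    model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, optim_config=zero_optim_config, **optim_kwargs)
    return model, optimizer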

View File

@ -1,9 +1,10 @@
import torch
import math
import os
import tempfile
import math
from typing import Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from typing import Optional, List, Dict, Callable
class NVMeOptimizer(torch.optim.Optimizer):
@ -42,8 +43,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
self.offloader = None
self.is_on_nvme: Dict[Parameter, bool] = {}
self.offloaded_numel: int = 0
self.total_numel: int = self._get_numel()
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
# As params may not be materialized here, these attributes are initialized at the first step
self.total_numel: Optional[int] = None
self.can_offload_numel: Optional[int] = None
self.prefetch_params: List[Parameter] = []
self.param_to_prefetch_idx: Dict[Parameter, int] = {}
@ -77,6 +79,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
self.prefetch_params.append(p)
def _pre_step(self, *state_keys: str) -> None:
if self.total_numel is None:
self.total_numel = self._get_numel()
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
self._setup_prefetch_params()
if self.offloader is None or len(self.prefetch_params) == 0:
return
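Under lazy init the parameters may not be materialized when the optimizer is constructed, so the element counts that drive NVMe offloading are now computed on the first step rather than in __init__. A minimal, self-contained sketch of the same deferred-sizing pattern (hypothetical optimizer class, not the actual NVMeOptimizer):

import math

import torch


class DeferredSizingSGD(torch.optim.SGD):
    """Illustrative optimizer that defers parameter sizing to the first step."""

    def __init__(self, params, nvme_offload_fraction: float = 0.5, **kwargs):
        super().__init__(params, **kwargs)
        self.nvme_offload_fraction = nvme_offload_fraction
        self.total_numel = None        # unknown until parameters are materialized
        self.can_offload_numel = None

    def step(self, closure=None):
        if self.total_numel is None:
            # First step: parameters are materialized by now, so sizing is safe.
            self.total_numel = sum(p.numel() for group in self.param_groups
                                   for p in group['params'])
            self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
        return super().step(closure)


# usage: opt = DeferredSizingSGD(model.parameters(), nvme_offload_fraction=0.5, lr=0.1)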

View File

@ -7,7 +7,7 @@ import torch.nn as nn
from torch import Tensor
from torch.utils._pytree import tree_map
from colossalai.fx.profiler.tensor import MetaTensor
from colossalai._analyzer._subclasses import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout
@ -37,7 +37,7 @@ _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']
_LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float,
@ -75,6 +75,12 @@ class _MyTensor(Tensor):
return super().__torch_function__(func, types, args, kwargs)
def _data_tolist(tensor: torch.Tensor) -> list:
"""tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
"""
return tensor.data.tolist()
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.
@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(torch.Tensor.tolist, target)
tensor.tolist = MethodType(_data_tolist, tensor)
return tensor
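As the new helper's docstring notes, tolist() cannot be called on this tensor subclass while Tensor.data gives back a plain Tensor, so the materialized tensor's tolist is rebound per instance to a helper that goes through .data. A standalone sketch of the same instance-level rebinding (WrappedTensor is purely illustrative and stands in for LazyTensor):

from types import MethodType

import torch


class WrappedTensor(torch.Tensor):
    """Illustrative subclass standing in for LazyTensor."""


def _data_tolist(tensor: torch.Tensor) -> list:
    # Go through .data, which returns a plain torch.Tensor.
    return tensor.data.tolist()


t = torch.Tensor._make_subclass(WrappedTensor, torch.arange(4.0))
t.tolist = MethodType(_data_tolist, t)   # rebinds tolist for this instance only
print(t.tolist())                        # [0.0, 1.0, 2.0, 3.0]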
@ -144,7 +150,7 @@ class LazyTensor(torch.Tensor):
if meta_data is None:
device = kwargs.get('device', 'cpu')
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
meta_data = MetaTensor(elem, device=device)
elem = meta_data._tensor
# As a meta tensor's __class__ cannot be changed to torch.Tensor, we use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
@ -255,7 +261,7 @@ class LazyTensor(torch.Tensor):
tree_map(cls._replace_with_materialized, args)
tree_map(cls._replace_with_materialized, kwargs)
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ == "__setitem__")
or func.__name__ in ('__setitem__', '__set__'))
is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS

View File

@ -46,9 +46,10 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
if strict_ddp_flag:
if strict_ddp_flag and type(local_param) is ColoParameter:
return local_param.numel_global()
else:
# if local_param is not ColoParameter, we assume it's replicated
return local_param.numel()
@ -67,11 +68,13 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
# assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
if is_ddp_ignored(param):
continue
if strict_ddp_flag:
if strict_ddp_flag or type(param) is not ColoParameter:
# if model is not initialized with ColoInitContext, we assume it's replicated
# TODO(ver217): integrate DTensor
param_key = dist.get_world_size()
else:
param_key = param.process_group.dp_world_size()
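The relaxed check means parameters that were not created under ColoInitContext (for example, ones coming from LazyInitContext) are simply assumed to be replicated across the whole data-parallel world and keyed by the global world size. A compact sketch of that grouping rule (an illustrative helper, not the library function; it assumes torch.distributed is already initialized):

from collections import defaultdict
from typing import Dict, List

import torch.distributed as dist

from colossalai.tensor.colo_parameter import ColoParameter


def group_params_by_dp_degree(params, strict_ddp: bool = False) -> Dict[int, List]:
    """Group parameters by the size of the data-parallel group they live in."""
    groups: Dict[int, List] = defaultdict(list)
    for p in params:
        if strict_ddp or type(p) is not ColoParameter:
            # Plain or lazily initialized parameters are treated as replicated world-wide.
            key = dist.get_world_size()
        else:
            key = p.process_group.dp_world_size()
        groups[key].append(p)
    return groups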

View File

@ -1,7 +1,7 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import torch
import torch.distributed as dist
@ -14,6 +14,7 @@ from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.utils.model.experimental import LazyTensor
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
@ -55,7 +56,6 @@ class ZeroDDP(ColoDDP):
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
@ -67,7 +67,6 @@ class ZeroDDP(ColoDDP):
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self._cast_buffers()
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@ -91,6 +90,8 @@ class ZeroDDP(ColoDDP):
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._cast_buffers()
def _post_forward(self):
"""This function is only triggered for inference.
@ -478,7 +479,8 @@ class ZeroDDP(ColoDDP):
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
self._preprocess_param(p)
assert type(p) is ColoParameter
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
@ -531,10 +533,27 @@ class ZeroDDP(ColoDDP):
def _cast_buffers(self):
for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor):
buffer.materialize()
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
"""Convert parameter to ColoParameter in-place.
Args:
p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
"""
if type(p) is ColoParameter:
# model is initialized with ColoInitContext
return
requires_grad = p.requires_grad
if isinstance(p, LazyTensor):
# model is initialized with LazyInitContext
p.materialize()
p.__class__ = ColoParameter
p.__init__(p, requires_grad=requires_grad)
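The conversion is deliberately done in place: a LazyTensor is materialized first, then its __class__ is swapped to ColoParameter and __init__ re-run on the same object, so existing references (such as entries in an optimizer's param groups) keep pointing at a now-valid ColoParameter. A stripped-down illustration of the class-swap idea on plain Python objects (purely illustrative, not the Gemini code):

class Plain:
    def __init__(self, value):
        self.value = value


class Upgraded(Plain):
    def __init__(self, value, tag=None):
        super().__init__(value)
        self.tag = tag


obj = Plain(3)
ref = obj                       # e.g. an entry in an optimizer's param group
obj.__class__ = Upgraded        # swap the class in place
obj.__init__(obj.value, tag='converted')
assert ref is obj and isinstance(ref, Upgraded)   # old references stay valid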
class GeminiDDP(ZeroDDP):

View File

@ -1,21 +1,31 @@
from contextlib import nullcontext
import torch
import torch.distributed as dist
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.fx import is_compatible_with_meta
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.zero import ColoInitContext
from tests.kit.model_zoo import model_zoo
def check_gemini_plugin(early_stop: bool = True):
@parameterize('init_method', ['lazy', 'none', 'colo'])
def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
init_method (str, optional): How the model is initialized: 'lazy', 'colo' or 'none'. Defaults to 'none'.
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
is_support_meta = is_compatible_with_meta()
if not is_support_meta and init_method == 'lazy':
return
from colossalai.utils.model.experimental import LazyInitContext
passed_models = []
failed_info = {} # (model_name, error) pair
@ -40,10 +50,25 @@ def check_gemini_plugin(early_stop: bool = True):
]:
continue
if init_method == 'lazy' and name in [
'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3',
'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0',
'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf',
'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s'
]:
continue
try:
if init_method == 'colo':
ctx = ColoInitContext()
elif init_method == 'lazy':
ctx = LazyInitContext()
else:
ctx = nullcontext()
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
with ctx:
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
@ -76,6 +101,7 @@ def check_gemini_plugin(early_stop: bool = True):
torch.cuda.empty_cache()
if dist.get_rank() == 0:
print(f'Init method: {init_method}')
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
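Putting it together, the test exercises the same path a user would take: build the model under LazyInitContext, then let GeminiPlugin and Booster materialize and shard it. A condensed usage sketch based on the test above (it assumes a CUDA device and a distributed environment launched via torchrun; the toy model is illustrative):

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils.model.experimental import LazyInitContext

colossalai.launch_from_torch(config={})

with LazyInitContext():
    # parameters stay unmaterialized until Gemini takes over
    model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))

optimizer = HybridAdam(model.parameters(), lr=1e-3)
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

output = model(torch.rand(4, 8).cuda())
loss = output.mean()
booster.backward(loss, optimizer)
optimizer.step()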