mirror of https://github.com/hpcaitech/ColossalAI
[gemini] gemini supports lazy init (#3379)
* [gemini] fix nvme optimizer init
* [gemini] gemini supports lazy init
* [gemini] add init example
* [gemini] add fool model
* [zero] update gemini ddp
* [zero] update init example
* add chunk method
* add chunk method
* [lazyinit] fix lazy tensor tolist
* [gemini] fix buffer materialization
* [misc] remove useless file
* [booster] update gemini plugin
* [test] update gemini plugin test
* [test] fix gemini plugin test
* [gemini] fix import
* [gemini] fix import
* [lazyinit] use new metatensor
* [lazyinit] use new metatensor
* [lazyinit] fix __set__ method
parent 366a035552
commit 152239bbfa
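In short, this commit lets a model be built under `LazyInitContext` (or plain PyTorch, or `ColoInitContext`) and handed directly to the Gemini plugin: `ZeroDDP._preprocess_param` in the diff below materializes lazy parameters and converts them to `ColoParameter`, so the old `convert_to_colo_param` workaround in the plugin is removed. A minimal usage sketch, adapted from the test added in this commit (shown at the bottom of the diff); the `launch_from_torch` call and the toy `nn.Sequential` model are illustrative assumptions, while the plugin arguments mirror the test:

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils.model.experimental import LazyInitContext

# assumes the process was started with torchrun so rank/world size are already set
colossalai.launch_from_torch(config={})

with LazyInitContext():
    # parameters stay unmaterialized here; the Gemini plugin materializes them later
    model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))

optimizer = HybridAdam(model.parameters(), lr=1e-3)
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)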
@@ -2,8 +2,6 @@ import torch
 import torch.distributed as dist
 from packaging import version
 
-aten = torch.ops.aten
-
 __all__ = [
     "_TorchFactoryMethod",
     "_TorchOverrideableFactoryMethod",
@@ -51,6 +49,7 @@ _DistCommMethod = [
 ]
 
 if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    aten = torch.ops.aten
     # TODO: dive deep here
     # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
     _AliasATen = [

@@ -16,10 +16,8 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
 from colossalai.zero.gemini.memory_tracer import MemStats
 
 from .plugin_base import Plugin
@@ -27,50 +25,6 @@ from .plugin_base import Plugin
 __all__ = ['GeminiPlugin']
 
 
-def convert_to_colo_param(module: nn.Module) -> None:
-    """Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
-
-    Args:
-        module (nn.Module): Module to be converted.
-    """
-    converted_modules = set()    # handle shared modules
-    converted_params = dict()    # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
-
-    def convert_recursively(m: nn.Module):
-        for child in m.children():
-            if child not in converted_modules:
-                converted_modules.add(child)
-                convert_recursively(child)
-
-        for name, p in m.named_parameters(recurse=False):
-            assert not isinstance(p, ColoParameter)
-            if p in converted_params:
-                target = converted_params[p]
-            else:
-                target = _convert_to_coloparam(p, p.device, p.dtype)
-                converted_params[p] = target
-            setattr(m, name, target)
-            target.shared_param_modules.append(m)
-
-    convert_recursively(module)
-
-    # optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
-    module._converted_params = converted_params
-
-
-def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
-    """Replace param in optimizer's group with converted ColoParameter.
-
-    Args:
-        optimizer (Optimizer): Optimizer to be replaced.
-        converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
-    """
-    for group in optimizer.param_groups:
-        for i, p in enumerate(group['params']):
-            if p in converted_params:
-                group['params'][i] = converted_params[p]
-
-
 class GeminiCheckpointIO(GeneralCheckpointIO):
 
     def __init__(self) -> None:
@@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper):
 
     def __init__(self, module: nn.Module, gemini_config: dict) -> None:
         super().__init__(module)
-        # TODO(ver217): only support Gemini now
-        convert_to_colo_param(module)
         self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)
 
     def unwrap(self):
@@ -125,8 +77,6 @@ class GeminiModel(ModelWrapper):
 class GeminiOptimizer(OptimizerWrapper):
 
     def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
-        replace_param_in_group(optimizer, module.module._converted_params)
-        del module.module._converted_params
         optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
         super().__init__(optimizer)
 

@@ -1,9 +1,10 @@
-import torch
+import math
 import os
 import tempfile
-import math
+from typing import Callable, Dict, List, Optional
+
+import torch
 from torch.nn.parameter import Parameter
-from typing import Optional, List, Dict, Callable
 
 
 class NVMeOptimizer(torch.optim.Optimizer):
@@ -42,8 +43,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
         self.offloader = None
         self.is_on_nvme: Dict[Parameter, bool] = {}
         self.offloaded_numel: int = 0
-        self.total_numel: int = self._get_numel()
-        self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
+        # As param may be not materialized here, these attributes are initalized when the first step
+        self.total_numel: Optional[int] = None
+        self.can_offload_numel: Optional[int] = None
 
         self.prefetch_params: List[Parameter] = []
         self.param_to_prefetch_idx: Dict[Parameter, int] = {}
@@ -77,6 +79,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
                 self.prefetch_params.append(p)
 
     def _pre_step(self, *state_keys: str) -> None:
+        if self.total_numel is None:
+            self.total_numel = self._get_numel()
+            self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
         self._setup_prefetch_params()
         if self.offloader is None or len(self.prefetch_params) == 0:
             return

@@ -7,7 +7,7 @@ import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
-from colossalai.fx.profiler.tensor import MetaTensor
+from colossalai._analyzer._subclasses import MetaTensor
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 from colossalai.tensor.d_tensor.layout import Layout
 
@@ -37,7 +37,7 @@ _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
 # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
 # These ops cannot be unwrapped using .data
-_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
+_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']
 
 _LEGACY_TENSOR_CONSTRUCTOR = {
     'FloatTensor': torch.float,
@@ -75,6 +75,12 @@ class _MyTensor(Tensor):
         return super().__torch_function__(func, types, args, kwargs)
 
 
+def _data_tolist(tensor: torch.Tensor) -> list:
+    """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.
+    """
+    return tensor.data.tolist()
+
+
 def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.
 
@@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
    # overwrite this method after materialization or distribution
-    tensor.tolist = MethodType(torch.Tensor.tolist, target)
+    tensor.tolist = MethodType(_data_tolist, tensor)
     return tensor
 
 
@@ -144,7 +150,7 @@ class LazyTensor(torch.Tensor):
         if meta_data is None:
             device = kwargs.get('device', 'cpu')
             elem = func(*args, **{**kwargs, 'device': 'meta'})
-            meta_data = MetaTensor(elem, fake_device=device)
+            meta_data = MetaTensor(elem, device=device)
         elem = meta_data._tensor
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
         r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
@@ -255,7 +261,7 @@ class LazyTensor(torch.Tensor):
         tree_map(cls._replace_with_materialized, args)
         tree_map(cls._replace_with_materialized, kwargs)
         is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
-                            or func.__name__ == "__setitem__")
+                            or func.__name__ in ('__setitem__', '__set__'))
 
         is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
 

@@ -46,9 +46,10 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
 
 
 def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
-    if strict_ddp_flag:
+    if strict_ddp_flag and type(local_param) is ColoParameter:
         return local_param.numel_global()
     else:
+        # if local_param is not ColoParameter, we assume it's replicated
         return local_param.numel()
 
 
@@ -67,11 +68,13 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
     """
     params_dict: Dict[int, List[ColoParameter]] = dict()
     for param in param_order.generate():
-        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
+        # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if is_ddp_ignored(param):
             continue
 
-        if strict_ddp_flag:
+        if strict_ddp_flag or type(param) is not ColoParameter:
+            # if model is not initialized with ColoInitContext, we assume it's replicated
+            # TODO(ver217): integrate DTensor
             param_key = dist.get_world_size()
         else:
             param_key = param.process_group.dp_world_size()

@@ -1,7 +1,7 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
+from colossalai.utils.model.experimental import LazyTensor
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -55,7 +56,6 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False) -> None:
-        super().__init__(module, process_group=ColoProcessGroup())
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -67,7 +67,6 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
 
-        self._cast_buffers()
         self._logger = get_dist_logger()
 
         if self.gemini_manager._premade_memstats_:
@@ -91,6 +90,8 @@ class ZeroDDP(ColoDDP):
             for p_name, p_var in m_var.named_parameters(recurse=False):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
+        super().__init__(module, process_group=ColoProcessGroup())
+        self._cast_buffers()
 
     def _post_forward(self):
         """This function is only triggered for inference.
@@ -478,7 +479,8 @@ class ZeroDDP(ColoDDP):
     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
         ddp_pg = ColoProcessGroup()
         for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
+            self._preprocess_param(p)
+            assert type(p) is ColoParameter
 
             # gather sharded parameters in the strict ddp mode
             if strict_ddp_mode:
@@ -531,10 +533,27 @@
 
     def _cast_buffers(self):
         for buffer in self.module.buffers():
+            if isinstance(buffer, LazyTensor):
+                buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.half()
 
+    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
+        """Convert parameter to ColoParameter in-place.
+        Args:
+            p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
+        """
+        if type(p) is ColoParameter:
+            # model is initialized with ColoInitContext
+            return
+        requires_grad = p.requires_grad
+        if isinstance(p, LazyTensor):
+            # model is initialized with LazyInitContext
+            p.materialize()
+        p.__class__ = ColoParameter
+        p.__init__(p, requires_grad=requires_grad)
+
 
 class GeminiDDP(ZeroDDP):
 

@@ -1,21 +1,31 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
+from colossalai.fx import is_compatible_with_meta
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.zero import ColoInitContext
 from tests.kit.model_zoo import model_zoo
 
 
-def check_gemini_plugin(early_stop: bool = True):
+@parameterize('init_method', ['lazy', 'none', 'colo'])
+def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     """check gemini plugin over model zoo
 
     Args:
         early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
     """
+    is_support_meta = is_compatible_with_meta()
+    if not is_support_meta and init_method == 'lazy':
+        return
+
+    from colossalai.utils.model.experimental import LazyInitContext
     passed_models = []
     failed_info = {}    # (model_name, error) pair
 
@@ -40,9 +50,24 @@ def check_gemini_plugin(early_stop: bool = True):
         ]:
             continue
+
+        if init_method == 'lazy' and name in [
+                'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3',
+                'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0',
+                'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf',
+                'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s'
+        ]:
+            continue
+
         try:
+            if init_method == 'colo':
+                ctx = ColoInitContext()
+            elif init_method == 'lazy':
+                ctx = LazyInitContext()
+            else:
+                ctx = nullcontext()
             plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
             booster = Booster(plugin=plugin)
             with ctx:
                 model = model_fn()
             optimizer = HybridAdam(model.parameters(), lr=1e-3)
             criterion = lambda x: x.mean()
@@ -76,6 +101,7 @@ def check_gemini_plugin(early_stop: bool = True):
             torch.cuda.empty_cache()
 
     if dist.get_rank() == 0:
+        print(f'Init method: {init_method}')
         print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
         print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
     assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])