[gemini] gemini supports lazy init (#3379)

* [gemini] fix nvme optimizer init

* [gemini] gemini supports lazy init

* [gemini] add init example

* [gemini] add fool model

* [zero] update gemini ddp

* [zero] update init example

* add chunk method

* add chunk method

* [lazyinit] fix lazy tensor tolist

* [gemini] fix buffer materialization

* [misc] remove useless file

* [booster] update gemini plugin

* [test] update gemini plugin test

* [test] fix gemini plugin test

* [gemini] fix import

* [gemini] fix import

* [lazyinit] use new metatensor

* [lazyinit] use new metatensor

* [lazyinit] fix __set__ method
Branch: pull/3525/head
Author: Hongxin Liu, committed via GitHub on 2023-04-12 16:03:25 +08:00
Commit: 152239bbfa (parent: 366a035552)
7 changed files with 80 additions and 72 deletions
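In practice, this change lets a model be built under `LazyInitContext` and handed straight to the Gemini plugin. A minimal usage sketch, pieced together from the test changes below — `MyModel` is a placeholder, and the `launch_from_torch` / `boost()` calls follow the Booster API as of this release rather than anything shown in this diff:

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils.model.experimental import LazyInitContext

colossalai.launch_from_torch(config={})

# Build the model under LazyInitContext: parameters are not materialized here,
# so the full fp32 model never has to fit in memory on a single process.
with LazyInitContext():
    model = MyModel()    # hypothetical model constructor

optimizer = HybridAdam(model.parameters(), lr=1e-3)
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
booster = Booster(plugin=plugin)
# boost() wraps the model with GeminiDDP; lazy tensors are materialized and
# converted to ColoParameter inside ZeroDDP (see the gemini_ddp.py hunk below).
model, optimizer, *_ = booster.boost(model, optimizer)
```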

View File

@@ -2,8 +2,6 @@ import torch
 import torch.distributed as dist
 from packaging import version

-aten = torch.ops.aten
-
 __all__ = [
     "_TorchFactoryMethod",
     "_TorchOverrideableFactoryMethod",
@@ -51,6 +49,7 @@ _DistCommMethod = [
 ]

 if version.parse(torch.__version__) >= version.parse('1.12.0'):
+    aten = torch.ops.aten
     # TODO: dive deep here
     # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
     _AliasATen = [
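Moving `aten = torch.ops.aten` inside the version guard means the symbol is only bound where the aten op tables below are actually built. A tiny sketch of the same guarded-binding pattern — the `_ALIAS_ATEN` list here is an illustrative subset, not this module's real table:

```python
import torch
from packaging import version

# Bind aten and the op tables only on torch >= 1.12; on older torch the
# tables simply stay empty instead of referencing ops that may not resolve.
if version.parse(torch.__version__) >= version.parse('1.12.0'):
    aten = torch.ops.aten
    _ALIAS_ATEN = [aten.detach.default, aten.t.default]    # illustrative subset
else:
    _ALIAS_ATEN = []
```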

View File

@@ -16,10 +16,8 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import save_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
 from colossalai.zero.gemini.memory_tracer import MemStats

 from .plugin_base import Plugin
@@ -27,50 +25,6 @@ from .plugin_base import Plugin

 __all__ = ['GeminiPlugin']

-
-def convert_to_colo_param(module: nn.Module) -> None:
-    """Convert module's paramters to ColoParameter. This is a workaround and will be deprecated when lazy init is compatible with Gemini.
-
-    Args:
-        module (nn.Module): Module to be converted.
-    """
-    converted_modules = set()    # handle shared modules
-    converted_params = dict()    # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference
-
-    def convert_recursively(m: nn.Module):
-        for child in m.children():
-            if child not in converted_modules:
-                converted_modules.add(child)
-                convert_recursively(child)
-
-        for name, p in m.named_parameters(recurse=False):
-            assert not isinstance(p, ColoParameter)
-            if p in converted_params:
-                target = converted_params[p]
-            else:
-                target = _convert_to_coloparam(p, p.device, p.dtype)
-                converted_params[p] = target
-            setattr(m, name, target)
-            target.shared_param_modules.append(m)
-
-    convert_recursively(module)
-
-    # optimizer should replace params in group as well. This attr should be deleted after replacing to avoid memory leak
-    module._converted_params = converted_params
-
-
-def replace_param_in_group(optimizer: Optimizer, converted_params: dict) -> None:
-    """Replace param in optimizer's group with converted ColoParameter.
-
-    Args:
-        optimizer (Optimizer): Optimizer to be replaced.
-        converted_params (dict): Mapping between (torch.Tensor, ColoTensor).
-    """
-    for group in optimizer.param_groups:
-        for i, p in enumerate(group['params']):
-            if p in converted_params:
-                group['params'][i] = converted_params[p]
-

 class GeminiCheckpointIO(GeneralCheckpointIO):

     def __init__(self) -> None:
@@ -113,8 +67,6 @@ class GeminiModel(ModelWrapper):

     def __init__(self, module: nn.Module, gemini_config: dict) -> None:
         super().__init__(module)
-        # TODO(ver217): only support Gemini now
-        convert_to_colo_param(module)
         self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config)

     def unwrap(self):
@@ -125,8 +77,6 @@ class GeminiModel(ModelWrapper):
 class GeminiOptimizer(OptimizerWrapper):

     def __init__(self, module: GeminiDDP, optimizer: Optimizer, zero_optim_config: dict, optim_kwargs: dict) -> None:
-        replace_param_in_group(optimizer, module.module._converted_params)
-        del module.module._converted_params
         optimizer = zero_optim_wrapper(module, optimizer, optim_config=zero_optim_config, **optim_kwargs)
         super().__init__(optimizer)
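The `convert_to_colo_param` / `replace_param_in_group` workaround is no longer needed because the conversion now happens in-place inside `ZeroDDP._preprocess_param` (see gemini_ddp.py below): parameter objects keep their identity, so param groups built before wrapping stay valid. A small self-contained illustration of that identity argument — `ToyParam` stands in for `ColoParameter` and is not library code:

```python
import torch
import torch.nn as nn


class ToyParam(nn.Parameter):
    """Stand-in for ColoParameter; only used to demonstrate identity preservation."""


model = nn.Linear(4, 4)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# In-place class change: same Python objects, new class.
for p in model.parameters():
    p.__class__ = ToyParam

# The optimizer's param groups reference those same objects, so no
# replace_param_in_group step is required any more.
assert all(isinstance(p, ToyParam) for p in optim.param_groups[0]['params'])
```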

View File

@ -1,9 +1,10 @@
import torch import math
import os import os
import tempfile import tempfile
import math from typing import Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from typing import Optional, List, Dict, Callable
class NVMeOptimizer(torch.optim.Optimizer): class NVMeOptimizer(torch.optim.Optimizer):
@ -42,8 +43,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
self.offloader = None self.offloader = None
self.is_on_nvme: Dict[Parameter, bool] = {} self.is_on_nvme: Dict[Parameter, bool] = {}
self.offloaded_numel: int = 0 self.offloaded_numel: int = 0
self.total_numel: int = self._get_numel() # As param may be not materialized here, these attributes are initalized when the first step
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction) self.total_numel: Optional[int] = None
self.can_offload_numel: Optional[int] = None
self.prefetch_params: List[Parameter] = [] self.prefetch_params: List[Parameter] = []
self.param_to_prefetch_idx: Dict[Parameter, int] = {} self.param_to_prefetch_idx: Dict[Parameter, int] = {}
@ -77,6 +79,9 @@ class NVMeOptimizer(torch.optim.Optimizer):
self.prefetch_params.append(p) self.prefetch_params.append(p)
def _pre_step(self, *state_keys: str) -> None: def _pre_step(self, *state_keys: str) -> None:
if self.total_numel is None:
self.total_numel = self._get_numel()
self.can_offload_numel = math.floor(self.total_numel * self.nvme_offload_fraction)
self._setup_prefetch_params() self._setup_prefetch_params()
if self.offloader is None or len(self.prefetch_params) == 0: if self.offloader is None or len(self.prefetch_params) == 0:
return return
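The idea of this hunk is that, with lazy init, parameters may not be materialized when the optimizer is constructed, so numel-dependent bookkeeping is deferred until the first `step()`. A toy version of the same pattern, not the optimizer's actual code:

```python
import math
from typing import List, Optional

import torch.nn as nn


class DeferredBudget:
    """Numel-dependent state is computed on first use rather than in __init__,
    so construction also works when parameters are not materialized yet."""

    def __init__(self, params: List[nn.Parameter], offload_fraction: float) -> None:
        self.params = list(params)
        self.offload_fraction = offload_fraction
        self.total_numel: Optional[int] = None
        self.can_offload_numel: Optional[int] = None

    def pre_step(self) -> None:
        if self.total_numel is None:
            # by the first optimizer step every parameter has real storage
            self.total_numel = sum(p.numel() for p in self.params)
            self.can_offload_numel = math.floor(self.total_numel * self.offload_fraction)


budget = DeferredBudget(list(nn.Linear(4, 4).parameters()), offload_fraction=0.5)
budget.pre_step()    # total_numel == 20, can_offload_numel == 10
```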

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map

-from colossalai.fx.profiler.tensor import MetaTensor
+from colossalai._analyzer._subclasses import MetaTensor
 from colossalai.tensor.d_tensor.d_tensor import DTensor
 from colossalai.tensor.d_tensor.layout import Layout

@@ -37,7 +37,7 @@ _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
 # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
 # These ops cannot be unwrapped using .data
-_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
+_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__']

 _LEGACY_TENSOR_CONSTRUCTOR = {
     'FloatTensor': torch.float,
@@ -75,6 +75,12 @@ class _MyTensor(Tensor):
         return super().__torch_function__(func, types, args, kwargs)


+def _data_tolist(tensor: torch.Tensor) -> list:
+    """tolist() is not allowed on a subclass of Tensor, but Tensor.data returns a plain Tensor.
+    """
+    return tensor.data.tolist()
+
+
 def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     """Convert a lazy tensor's class to target's class, with target's data.

@@ -94,7 +100,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
     # overwrite this method after materialization or distribution
-    tensor.tolist = MethodType(torch.Tensor.tolist, target)
+    tensor.tolist = MethodType(_data_tolist, tensor)
     return tensor


@@ -144,7 +150,7 @@ class LazyTensor(torch.Tensor):
         if meta_data is None:
             device = kwargs.get('device', 'cpu')
             elem = func(*args, **{**kwargs, 'device': 'meta'})
-            meta_data = MetaTensor(elem, fake_device=device)
+            meta_data = MetaTensor(elem, device=device)
         elem = meta_data._tensor
         # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
         r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
@@ -255,7 +261,7 @@ class LazyTensor(torch.Tensor):
         tree_map(cls._replace_with_materialized, args)
         tree_map(cls._replace_with_materialized, kwargs)
         is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
-                            or func.__name__ == "__setitem__")
+                            or func.__name__ in ('__setitem__', '__set__'))

         is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
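The `tolist` fix binds `_data_tolist` to the converted tensor itself, routing the call through `.data` so it operates on a plain Tensor view. A sketch of just that binding idiom — `FancyTensor` is a stand-in subclass, not `LazyTensor`, and this is an illustration rather than the module's code:

```python
from types import MethodType

import torch


def _data_tolist(tensor: torch.Tensor) -> list:
    # .data gives a plain Tensor view, so tolist() is safe even on subclasses
    return tensor.data.tolist()


class FancyTensor(torch.Tensor):
    """Stand-in for a LazyTensor-style subclass (illustration only)."""


t = torch.arange(4.).as_subclass(FancyTensor)
# Instance attributes shadow the class method, mirroring the patched line above.
t.tolist = MethodType(_data_tolist, t)
print(t.tolist())    # [0.0, 1.0, 2.0, 3.0]
```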

View File

@@ -46,9 +46,10 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:


 def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
-    if strict_ddp_flag:
+    if strict_ddp_flag and type(local_param) is ColoParameter:
         return local_param.numel_global()
     else:
+        # if local_param is not ColoParameter, we assume it's replicated
         return local_param.numel()


@@ -67,11 +68,13 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
     """
     params_dict: Dict[int, List[ColoParameter]] = dict()
     for param in param_order.generate():
-        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
+        # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if is_ddp_ignored(param):
             continue

-        if strict_ddp_flag:
+        if strict_ddp_flag or type(param) is not ColoParameter:
+            # if model is not initialized with ColoInitContext, we assume it's replicated
+            # TODO(ver217): integrate DTensor
             param_key = dist.get_world_size()
         else:
             param_key = param.process_group.dp_world_size()
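The classification rule after this hunk: a parameter that carries its own DP group reports that group's size, and anything else (a plain `nn.Parameter`, or a weight built lazily and materialized later) is assumed replicated across all ranks. A toy, process-group-free version of that grouping — `dp_world_size` is a hypothetical attribute standing in for `param.process_group.dp_world_size()`:

```python
from collections import defaultdict
from typing import Dict, List

import torch.nn as nn


def classify_by_dp_degree(params: List[nn.Parameter], world_size: int) -> Dict[int, List[nn.Parameter]]:
    groups: Dict[int, List[nn.Parameter]] = defaultdict(list)
    for p in params:
        # params without an attached DP group fall back to the full world size
        dp_degree = getattr(p, 'dp_world_size', None)
        groups[dp_degree if dp_degree is not None else world_size].append(p)
    return groups
```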

View File

@@ -1,7 +1,7 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
@@ -14,6 +14,7 @@ from colossalai.tensor import ReplicaSpec
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
+from colossalai.utils.model.experimental import LazyTensor

 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -55,7 +56,6 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False) -> None:
-        super().__init__(module, process_group=ColoProcessGroup())
        self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -67,7 +67,6 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()

-        self._cast_buffers()
         self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
@@ -91,6 +90,8 @@ class ZeroDDP(ColoDDP):
             for p_name, p_var in m_var.named_parameters(recurse=False):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
+        super().__init__(module, process_group=ColoProcessGroup())
+        self._cast_buffers()

     def _post_forward(self):
         """This function is only triggered for inference.
@@ -478,7 +479,8 @@ class ZeroDDP(ColoDDP):
     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
         ddp_pg = ColoProcessGroup()
         for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
+            self._preprocess_param(p)
+            assert type(p) is ColoParameter

             # gather sharded parameters in the strict ddp mode
             if strict_ddp_mode:
@@ -531,10 +533,27 @@ class ZeroDDP(ColoDDP):

     def _cast_buffers(self):
         for buffer in self.module.buffers():
+            if isinstance(buffer, LazyTensor):
+                buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.half()

+    def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
+        """Convert parameter to ColoParameter in-place.
+
+        Args:
+            p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted
+        """
+        if type(p) is ColoParameter:
+            # model is initialized with ColoInitContext
+            return
+        requires_grad = p.requires_grad
+        if isinstance(p, LazyTensor):
+            # model is initialized with LazyInitContext
+            p.materialize()
+        p.__class__ = ColoParameter
+        p.__init__(p, requires_grad=requires_grad)
+

 class GeminiDDP(ZeroDDP):
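Buffers get the same treatment as parameters: lazy buffers must be materialized before they can be moved to the GPU and downcast. A minimal sketch of that flow, assuming a CUDA device is available — `ToyLazyTensor` and `cast_buffers` are stand-ins for `LazyTensor` and `ZeroDDP._cast_buffers`, not library code:

```python
import torch
import torch.nn as nn


class ToyLazyTensor(torch.Tensor):
    """Stand-in for LazyTensor: the real materialize() allocates real storage."""

    def materialize(self) -> torch.Tensor:
        return self


def cast_buffers(module: nn.Module) -> None:
    for buffer in module.buffers():
        if isinstance(buffer, ToyLazyTensor):
            buffer.materialize()             # materialize before touching storage
        buffer.data = buffer.cuda()          # move to the current device
        if torch.is_floating_point(buffer):
            buffer.data = buffer.half()      # Gemini keeps fp16 buffers


bn = nn.BatchNorm1d(8)
bn.running_mean = bn.running_mean.as_subclass(ToyLazyTensor)    # pretend it was lazily built
cast_buffers(bn)
```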

View File

@@ -1,21 +1,31 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
+from colossalai.fx import is_compatible_with_meta
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.zero import ColoInitContext
 from tests.kit.model_zoo import model_zoo


-def check_gemini_plugin(early_stop: bool = True):
+@parameterize('init_method', ['lazy', 'none', 'colo'])
+def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     """check gemini plugin over model zoo

     Args:
         early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
     """
+    is_support_meta = is_compatible_with_meta()
+    if not is_support_meta and init_method == 'lazy':
+        return
+
+    from colossalai.utils.model.experimental import LazyInitContext
+
     passed_models = []
     failed_info = {}    # (model_name, error) pair
@@ -40,9 +50,24 @@ def check_gemini_plugin(early_stop: bool = True):
         ]:
             continue

+        if init_method == 'lazy' and name in [
+                'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3',
+                'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0',
+                'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf',
+                'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s'
+        ]:
+            continue
+
         try:
+            if init_method == 'colo':
+                ctx = ColoInitContext()
+            elif init_method == 'lazy':
+                ctx = LazyInitContext()
+            else:
+                ctx = nullcontext()
             plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5)
             booster = Booster(plugin=plugin)
-            model = model_fn()
+            with ctx:
+                model = model_fn()
             optimizer = HybridAdam(model.parameters(), lr=1e-3)
             criterion = lambda x: x.mean()
@@ -76,6 +101,7 @@ def check_gemini_plugin(early_stop: bool = True):
         torch.cuda.empty_cache()

     if dist.get_rank() == 0:
+        print(f'Init method: {init_method}')
         print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
         print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
     assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
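The test now builds each model under one of three contexts, with `nullcontext()` as the no-op default. The selection idiom in isolation — the two factory arguments are hypothetical stand-ins for `ColoInitContext` and `LazyInitContext`:

```python
from contextlib import nullcontext


def select_init_ctx(init_method: str, lazy_ctx_factory, colo_ctx_factory):
    # 'none' (or anything unrecognized) falls through to a plain no-op context,
    # so the model is built eagerly with ordinary nn.Parameters.
    if init_method == 'colo':
        return colo_ctx_factory()
    if init_method == 'lazy':
        return lazy_ctx_factory()
    return nullcontext()
```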