mirror of https://github.com/hpcaitech/ColossalAI
[lazyinit] combine lazy tensor with dtensor (#3204)
* [lazyinit] lazy tensor add distribute
* [lazyinit] refactor distribute
* [lazyinit] add test dist lazy init
* [lazyinit] add verbose info for dist lazy init
* [lazyinit] fix rnn flatten weight op
* [lazyinit] polish test
* [lazyinit] polish test
* [lazyinit] fix lazy tensor data setter
* [lazyinit] polish test
* [lazyinit] fix clean
* [lazyinit] make materialize inplace
* [lazyinit] refactor materialize
* [lazyinit] refactor test distribute
* [lazyinit] fix requires_grad
* [lazyinit] fix tolist after materialization
* [lazyinit] refactor distribute module
* [lazyinit] polish docstr
* [lazyinit] polish lazy init context
* [lazyinit] temporarily skip test
* [lazyinit] polish test
* [lazyinit] add docstr
parent 189347963a
commit f8289d4221

@@ -1,11 +1,15 @@
from typing import Callable, List, Optional, Union
from types import MethodType
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.utils._pytree import tree_map

from colossalai.fx.profiler.tensor import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [

@@ -30,6 +34,11 @@ _NO_META_FACTORY = [
_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']

# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
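# Illustration (plain PyTorch; hypothetical helper, not part of this commit): metadata-changing
# ops such as requires_grad_ must see the tensor object itself, because applying them to the
# detached `.data` view silently drops the change.
def _illustrate_change_meta_ops() -> None:
    t = torch.zeros(2)
    t.data.requires_grad_()       # only flips the flag on the detached view
    assert not t.requires_grad    # the original tensor is unchanged
    t.requires_grad_()            # applied to the tensor itself, the flag sticks
    assert t.requires_grad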

_LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float,
'DoubleTensor': torch.double,
@@ -43,6 +52,8 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
'BoolTensor': torch.bool,
}

_EMPTY_DATA = torch.empty(0)


class _MyTensor(Tensor):
"""This class is only for correctness verification.

@@ -64,6 +75,29 @@ class _MyTensor(Tensor):
return super().__torch_function__(func, types, args, kwargs)


def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.

The reason why we change the class of a lazy tensor in-place is that this easily handles shared modules/parameters, which are common in huggingface models.
If we created a new tensor and updated the module by ``setattr(module, name, param)``, shared parameters would not be updated, and we would have to track all shared parameters and update them manually.

Args:
tensor (LazyTensor): the LazyTensor to be converted
target (torch.Tensor): target tensor

Returns:
torch.Tensor: the converted tensor
"""
cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
tensor.data = target
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(torch.Tensor.tolist, target)
return tensor
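# Minimal sketch (hypothetical helper, not part of this commit) of why _convert_cls mutates the
# lazy tensor in place: tied parameters keep their object identity, whereas replacing the
# attribute on one module would silently break the tie.
def _sketch_shared_param_conversion() -> None:
    lin_a = nn.Linear(4, 4, bias=False)
    lin_b = nn.Linear(4, 4, bias=False)
    lin_b.weight = lin_a.weight                        # weight tying, common in huggingface models
    assert lin_a.weight is lin_b.weight
    lin_a.weight.data = torch.randn(4, 4)              # in-place update: both modules stay tied
    assert lin_a.weight is lin_b.weight
    lin_a.weight = nn.Parameter(torch.randn(4, 4))     # setattr-style replacement: the tie is broken
    assert lin_a.weight is not lin_b.weight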


class LazyTensor(torch.Tensor):
"""A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).

@@ -112,14 +146,8 @@ class LazyTensor(torch.Tensor):
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
elem = meta_data._tensor
r = torch.Tensor._make_wrapper_subclass(cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad)
# As a meta tensor's __class__ cannot be modified to torch.Tensor, we use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
r._meta_data = meta_data
return r

@@ -129,15 +157,28 @@ class LazyTensor(torch.Tensor):
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data

def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).

Returns:
torch.Tensor: The materialized tensor.
torch.Tensor: The materialized tensor (self).
"""
target = self._materialize_data()
if isinstance(self, nn.Parameter):
target = nn.Parameter(target, requires_grad=self.requires_grad)
return target
self.clean()
return _convert_cls(self, target)
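# Hedged usage sketch (hypothetical helper; assumes the LazyInitContext defined later in this
# file): factory ops run under the context are only recorded, and materialize() replays them,
# converting the lazy tensor into a plain torch.Tensor in place and returning self.
def _sketch_materialize() -> None:
    with LazyInitContext():
        x = torch.randn(8, 8)       # recorded, not allocated
    y = x.materialize()             # replays the recorded factory op
    assert y is x                   # conversion happens in place
    assert type(x) is torch.Tensor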

def distribute(self, layout: Layout) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

Args:
layout (Layout): Distribution layout.

Returns:
torch.Tensor: The distributed tensor (self).
"""
target = self._materialize_data()
self.clean()
local_tensor = DTensor(target, layout).local_tensor
return _convert_cls(self, local_tensor)
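# Hedged sketch of the distribute() path (hypothetical helper; the Layout construction mirrors
# the test added by this commit, and the 2x2 mesh is illustrative only).
def _sketch_distribute(lazy_param: 'LazyTensor') -> torch.Tensor:
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    sharding_spec = ShardingSpec(dim_size=lazy_param.dim(), dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=sharding_spec,
                    entire_shape=lazy_param.shape)
    # materialize the full tensor, then keep only the local shard (still in place)
    return lazy_param.distribute(layout)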

def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.

@@ -216,6 +257,8 @@ class LazyTensor(torch.Tensor):
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ == "__setitem__")

is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS

if isinstance(func, torch._C.ScriptMethod):
# FIXME(ver217): torch script functions are not verified

@@ -239,10 +282,10 @@ class LazyTensor(torch.Tensor):
if isinstance(x, LazyTensor):
if x._materialized_data is not None:
# for early materialized tensor, use its materialized data directly
return x._materialized_data.data
return x._materialized_data if is_change_meta_op else x._materialized_data.data
t = x if is_inplace else x.clone()
t._op_buffer.append((func, args, kwargs))
meta = x._meta_data.data
meta = x._meta_data if is_change_meta_op else x._meta_data.data
meta_to_lazy[meta] = t
return meta
return x

@@ -290,13 +333,36 @@ class LazyTensor(torch.Tensor):

@data.setter
def data(self, other: 'LazyTensor'):
"""This is sightly different from oringinal `data` setter.

E.g.:
>>> a = torch.randn(3, 3) # a is a Tensor
>>> b = torch.rand(2, 2)
>>> a.data = b
>>> b.add_(1) # this will affect a
>>> x = torch.randn(3, 3) # x is a LazyTensor
>>> y = torch.rand(2, 2) # y is a LazyTensor
>>> x.data = y
>>> y.add_(1) # this will not affect x

"""
if other is self:
return
# TODO(ver217): to avoid infinite recursion, do early materialization
self._materialized_data = other._materialize_data()

self._op_buffer.append(other._factory_method)

def replace(x):
if x is other:
return self
return x

for func, args, kwargs in other._op_buffer:
self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))

def tolist(self) -> list:
t = self.materialize()
# Though self.__class__ is modified to torch.Tensor, on the C++ side it is still a subclass of torch.Tensor,
# and such a subclass does not have the tolist() method
t = self._materialize_data()
return t.tolist()

def __hash__(self):

@@ -421,71 +487,84 @@ class LazyInitContext:
setattr(torch, name, orig)

@staticmethod
def materialize(module: torch.nn.Module, verbose: bool = False):
"""Initialize all ``nn.Parameter`` from ``LazyTensor``.
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
"""Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

Args:
module (torch.nn.Module): Target ``nn.Module``
module (nn.Module): Target ``nn.Module``
verbose (bool): Whether to print lazy initialization rate. Defaults to False.
"""

def apply_fn(name: str, p: LazyTensor):
p.materialize()

return _apply_to_lazy_module(module, apply_fn, verbose)
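# Hedged end-to-end sketch (hypothetical helper): build a module lazily, then materialize it
# in place; `build_model` stands in for any nn.Module factory.
def _sketch_lazy_module_materialize(build_model: Callable[[], nn.Module]) -> nn.Module:
    with LazyInitContext():
        model = build_model()                      # parameters/buffers stay lazy here
    return LazyInitContext.materialize(model)      # same module object, now holding real tensors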

@staticmethod
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

Args:
module (nn.Module): Target ``nn.Module``
layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
"""

def apply_fn(name: str, p: LazyTensor):
p.distribute(layout_dict[name])

return _apply_to_lazy_module(module, apply_fn, verbose)
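# Hedged sketch of the distribute path (hypothetical helper): same flow as materialize(), except
# each parameter/buffer is sharded according to its entry in a per-name layout dict, built e.g.
# like generate_layout_dict in the test added by this commit.
def _sketch_lazy_module_distribute(build_model: Callable[[], nn.Module], layout_dict: dict) -> nn.Module:
    with LazyInitContext():
        model = build_model()
    return LazyInitContext.distribute(model, layout_dict, verbose=True)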


def _apply_to_lazy_module(module: nn.Module,
apply_fn: Callable[[str, torch.Tensor], None],
verbose: bool = False) -> nn.Module:
if verbose:
# verbose info
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
total_numel = 0
non_lazy_numel = 0

for name, p in module.named_parameters():
if verbose:
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
non_lazy_numel = 0

# do post cleaning to handle shared parameter
visited_lazy_tensors: List[LazyTensor] = []
# handle shared module
visited_modules = set()

@torch.no_grad()
def init_recursively(module: nn.Module):
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
# recursively initialize the module
for mod in module.children():
if id(mod) not in visited_modules:
visited_modules.add(id(mod))
init_recursively(mod)

# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if verbose:
param_cnt += 1
if getattr(param, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
non_lazy_numel += param.numel()
if hasattr(param, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(param)
setattr(module, name, param.materialize())

for name, buf in module.named_buffers(recurse=False):
if verbose:
buf_cnt += 1
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
else:
non_lazy_numel += buf.numel()
if hasattr(buf, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(buf)
setattr(module, name, buf.materialize())

init_recursively(module)

for t in visited_lazy_tensors:
t.clean()
param_cnt += 1
total_numel += p.numel()
if getattr(p, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
non_lazy_numel += p.numel()
if isinstance(p, LazyTensor):
apply_fn(name, p)

for name, buf in module.named_buffers():
if verbose:
print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
return module
buf_cnt += 1
total_numel += buf.numel()
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
else:
non_lazy_numel += buf.numel()
if isinstance(buf, LazyTensor):
apply_fn(name, buf)

if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
_print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
_print_rank_0(
f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')

return module


def _print_rank_0(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)


def _is_int_tuple(args) -> bool:

@@ -0,0 +1,110 @@
from functools import partial
from typing import Optional

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.common import print_rank_0
from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
from tests.kit.model_zoo import model_zoo

# from utils import assert_dist_model_equal, set_seed


def find_shard_dim(shape: torch.Size) -> Optional[int]:
for dim, size in enumerate(shape):
if size % 2 == 0:
return dim


def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
shard_dim = find_shard_dim(original_tensor.shape)
dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
entire_shape=original_tensor.shape)
return layout

def _get_current_name(prefix: str, name: str) -> str:
return f'{prefix}.{name}'.lstrip('.')


def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
layout_dict = {}

@torch.no_grad()
def generate_recursively(module: nn.Module, prefix: str = ''):
# recursively initialize the module
for name, mod in module.named_children():
generate_recursively(mod, prefix=_get_current_name(prefix, name))

# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if isinstance(param, LazyTensor):
layout = make_layout(device_mesh, param)
layout_dict[_get_current_name(prefix, name)] = layout

for name, buf in module.named_buffers(recurse=False):
if isinstance(buf, LazyTensor):
layout = make_layout(device_mesh, buf)
layout_dict[_get_current_name(prefix, name)] = layout

generate_recursively(model)

return layout_dict

@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
def run_dist_lazy_init(subset, seed: int = 42):
sub_model_zoo = model_zoo.get_sub_registry(subset)
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
# FIXME(ver217): uncomment this line
# _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
# LazyTensor._pre_op_fn = lambda *args: set_seed(seed)

for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
continue
print_rank_0(name)
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
ctx = LazyInitContext(tensor_cls=_MyTensor)
with ctx:
model = model_fn()
ctx = LazyInitContext()
with ctx:
deferred_model = model_fn()
layout_dict = generate_layout_dict(deferred_model, device_mesh)
ctx.distribute(deferred_model, layout_dict, verbose=True)
# FIXME(ver217): uncomment this line
# assert_dist_model_equal(model, deferred_model, layout_dict)


def run_dist(rank, world_size, port) -> None:
colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
run_dist_lazy_init()


# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
@pytest.mark.skip
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lazy_init():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_dist_lazy_init()

@@ -4,6 +4,7 @@ from typing import Any, Callable, Optional, Tuple
import numpy as np
import torch

from colossalai.tensor.d_tensor.layout_converter import to_global
from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
from tests.kit.model_zoo.registry import ModelAttribute

@@ -67,3 +68,18 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
if verbose:
print(f'{model.__class__.__name__} pass')


def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
state = model.state_dict()
distributed_state = distributed_model.state_dict()

assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}'

for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
assert n1 == n2
t1 = t1.cuda()
t2 = t2.cuda()
if n2 in layout_dict:
t2 = to_global(t2, layout_dict[n2])
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'