[lazyinit] combine lazy tensor with dtensor (#3204)

* [lazyinit] lazy tensor add distribute

* [lazyinit] refactor distribute

* [lazyinit] add test dist lazy init

* [lazyinit] add verbose info for dist lazy init

* [lazyinit] fix rnn flatten weight op

* [lazyinit] polish test

* [lazyinit] polish test

* [lazyinit] fix lazy tensor data setter

* [lazyinit] polish test

* [lazyinit] fix clean

* [lazyinit] make materialize inplace

* [lazyinit] refactor materialize

* [lazyinit] refactor test distribute

* [lazyinit] fix requires_grad

* [lazyinit] fix tolist after materialization

* [lazyinit] refactor distribute module

* [lazyinit] polish docstr

* [lazyinit] polish lazy init context

* [lazyinit] temporarily skip test

* [lazyinit] polish test

* [lazyinit] add docstr
pull/3213/head
ver217 2 years ago committed by GitHub
parent 189347963a
commit f8289d4221
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,11 +1,15 @@
from typing import Callable, List, Optional, Union
from types import MethodType
from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.utils._pytree import tree_map
from colossalai.fx.profiler.tensor import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@ -30,6 +34,11 @@ _NO_META_FACTORY = [
_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
_LEGACY_TENSOR_CONSTRUCTOR = {
'FloatTensor': torch.float,
'DoubleTensor': torch.double,
@ -43,6 +52,8 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
'BoolTensor': torch.bool,
}
_EMPTY_DATA = torch.empty(0)
class _MyTensor(Tensor):
"""This class is only for correctness verification.
@ -64,6 +75,29 @@ class _MyTensor(Tensor):
return super().__torch_function__(func, types, args, kwargs)
def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually.
Args:
tensor (LazyTensor): the LazyTensor to be converted
target (torch.Tensor): target tensor
Returns:
torch.Tensor: the converted tensor
"""
cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
tensor.data = target
tensor.requires_grad = target.requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(torch.Tensor.tolist, target)
return tensor
class LazyTensor(torch.Tensor):
"""A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
@ -112,14 +146,8 @@ class LazyTensor(torch.Tensor):
elem = func(*args, **{**kwargs, 'device': 'meta'})
meta_data = MetaTensor(elem, fake_device=device)
elem = meta_data._tensor
r = torch.Tensor._make_wrapper_subclass(cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad)
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
r._meta_data = meta_data
return r
@ -129,15 +157,28 @@ class LazyTensor(torch.Tensor):
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor``.
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
Returns:
torch.Tensor: The materialized tensor.
torch.Tensor: The materialized tensor (self).
"""
target = self._materialize_data()
if isinstance(self, nn.Parameter):
target = nn.Parameter(target, requires_grad=self.requires_grad)
return target
self.clean()
return _convert_cls(self, target)
def distribute(self, layout: Layout) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args:
layout (Layout): Distribution layout.
Returns:
torch.Tensor: The distributed tensor (self).
"""
target = self._materialize_data()
self.clean()
local_tensor = DTensor(target, layout).local_tensor
return _convert_cls(self, local_tensor)
def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
@ -216,6 +257,8 @@ class LazyTensor(torch.Tensor):
is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
or func.__name__ == "__setitem__")
is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
if isinstance(func, torch._C.ScriptMethod):
# FIXME(ver217): torch script functions are not verified
@ -239,10 +282,10 @@ class LazyTensor(torch.Tensor):
if isinstance(x, LazyTensor):
if x._materialized_data is not None:
# for early materialized tensor, use its materialized data directly
return x._materialized_data.data
return x._materialized_data if is_change_meta_op else x._materialized_data.data
t = x if is_inplace else x.clone()
t._op_buffer.append((func, args, kwargs))
meta = x._meta_data.data
meta = x._meta_data if is_change_meta_op else x._meta_data.data
meta_to_lazy[meta] = t
return meta
return x
@ -290,13 +333,36 @@ class LazyTensor(torch.Tensor):
@data.setter
def data(self, other: 'LazyTensor'):
"""This is sightly different from oringinal `data` setter.
E.g.:
>>> a = torch.randn(3, 3) # a is a Tensor
>>> b = torch.rand(2, 2)
>>> a.data = b
>>> b.add_(1) # this will affect a
>>> x = torch.randn(3, 3) # x is a LazyTensor
>>> y = torch.rand(2, 2) # y is a LazyTensor
>>> x.data = y
>>> y.add_(1) # this will not affect x
"""
if other is self:
return
# TODO(ver217): to avoid infinity recursion, do early materialization
self._materialized_data = other._materialize_data()
self._op_buffer.append(other._factory_method)
def replace(x):
if x is other:
return self
return x
for func, args, kwargs in other._op_buffer:
self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
def tolist(self) -> list:
t = self.materialize()
# Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor
# And subclass of torch.Tensor does not have tolist() method
t = self._materialize_data()
return t.tolist()
def __hash__(self):
@ -421,71 +487,84 @@ class LazyInitContext:
setattr(torch, name, orig)
@staticmethod
def materialize(module: torch.nn.Module, verbose: bool = False):
"""Initialize all ``nn.Parameter`` from ``LazyTensor``.
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
"""Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (torch.nn.Module): Target ``nn.Module``
module (nn.Module): Target ``nn.Module``
verbose (bool): Whether to print lazy initialization rate. Defaults to False.
"""
if verbose:
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
non_lazy_numel = 0
# do post cleaning to handle shared parameter
visited_lazy_tensors: List[LazyTensor] = []
# handle shared module
visited_modules = set()
@torch.no_grad()
def init_recursively(module: nn.Module):
nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
# recursively initialize the module
for mod in module.children():
if id(mod) not in visited_modules:
visited_modules.add(id(mod))
init_recursively(mod)
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if verbose:
param_cnt += 1
if getattr(param, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
non_lazy_numel += param.numel()
if hasattr(param, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(param)
setattr(module, name, param.materialize())
for name, buf in module.named_buffers(recurse=False):
if verbose:
buf_cnt += 1
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
else:
non_lazy_numel += buf.numel()
if hasattr(buf, 'materialize'):
# TODO(ver217): apex layers cannot be captured
visited_lazy_tensors.append(buf)
setattr(module, name, buf.materialize())
init_recursively(module)
def apply_fn(name: str, p: LazyTensor):
p.materialize()
return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (nn.Module): Target ``nn.Module``
layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
"""
def apply_fn(name: str, p: LazyTensor):
p.distribute(layout_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose)
for t in visited_lazy_tensors:
t.clean()
def _apply_to_lazy_module(module: nn.Module,
apply_fn: Callable[[str, torch.Tensor], None],
verbose: bool = False) -> nn.Module:
if verbose:
# verbose info
param_cnt = 0
param_lazy_cnt = 0
buf_cnt = 0
buf_lazy_cnt = 0
total_numel = 0
non_lazy_numel = 0
for name, p in module.named_parameters():
if verbose:
param_cnt += 1
total_numel += p.numel()
if getattr(p, '_materialized_data', False) is None:
# if no _materialized_data attr, the tensor is not lazy
param_lazy_cnt += 1
else:
non_lazy_numel += p.numel()
if isinstance(p, LazyTensor):
apply_fn(name, p)
for name, buf in module.named_buffers():
if verbose:
print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
return module
buf_cnt += 1
total_numel += buf.numel()
if getattr(buf, "_materialized_data", False) is None:
# if no _materialized_data attr, the tensor is not lazy
buf_lazy_cnt += 1
else:
non_lazy_numel += buf.numel()
if isinstance(buf, LazyTensor):
apply_fn(name, buf)
if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
_print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
_print_rank_0(
f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
return module
def _print_rank_0(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
def _is_int_tuple(args) -> bool:

@ -0,0 +1,110 @@
from functools import partial
from typing import Optional
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.common import print_rank_0
from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
from tests.kit.model_zoo import model_zoo
# from utils import assert_dist_model_equal, set_seed
def find_shard_dim(shape: torch.Size) -> Optional[int]:
for dim, size in enumerate(shape):
if size % 2 == 0:
return dim
def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
shard_dim = find_shard_dim(original_tensor.shape)
dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
entire_shape=original_tensor.shape)
return layout
def _get_current_name(prefix: str, name: str) -> str:
return f'{prefix}.{name}'.lstrip('.')
def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
layout_dict = {}
@torch.no_grad()
def generate_recursively(module: nn.Module, prefix: str = ''):
# recursively initialize the module
for name, mod in module.named_children():
generate_recursively(mod, prefix=_get_current_name(prefix, name))
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if isinstance(param, LazyTensor):
layout = make_layout(device_mesh, param)
layout_dict[_get_current_name(prefix, name)] = layout
for name, buf in module.named_buffers(recurse=False):
if isinstance(buf, LazyTensor):
layout = make_layout(device_mesh, buf)
layout_dict[_get_current_name(prefix, name)] = layout
generate_recursively(model)
return layout_dict
@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
def run_dist_lazy_init(subset, seed: int = 42):
sub_model_zoo = model_zoo.get_sub_registry(subset)
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
# FIXME(ver217): uncomment this line
# _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
# LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
continue
print_rank_0(name)
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
ctx = LazyInitContext(tensor_cls=_MyTensor)
with ctx:
model = model_fn()
ctx = LazyInitContext()
with ctx:
deferred_model = model_fn()
layout_dict = generate_layout_dict(deferred_model, device_mesh)
ctx.distribute(deferred_model, layout_dict, verbose=True)
# FIXME(ver217): uncomment this line
# assert_dist_model_equal(model, deferred_model, layout_dict)
def run_dist(rank, world_size, port) -> None:
colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
run_dist_lazy_init()
# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
@pytest.mark.skip
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lazy_init():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_dist_lazy_init()

@ -4,6 +4,7 @@ from typing import Any, Callable, Optional, Tuple
import numpy as np
import torch
from colossalai.tensor.d_tensor.layout_converter import to_global
from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
from tests.kit.model_zoo.registry import ModelAttribute
@ -67,3 +68,18 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
if verbose:
print(f'{model.__class__.__name__} pass')
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
state = model.state_dict()
distributed_state = distributed_model.state_dict()
assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}'
for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
assert n1 == n2
t1 = t1.cuda()
t2 = t2.cuda()
if n2 in layout_dict:
t2 = to_global(t2, layout_dict[n2])
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

Loading…
Cancel
Save