mirror of https://github.com/hpcaitech/ColossalAI
[lazyinit] combine lazy tensor with dtensor (#3204)
* [lazyinit] lazy tensor add distribute
* [lazyinit] refactor distribute
* [lazyinit] add test dist lazy init
* [lazyinit] add verbose info for dist lazy init
* [lazyinit] fix rnn flatten weight op
* [lazyinit] polish test
* [lazyinit] polish test
* [lazyinit] fix lazy tensor data setter
* [lazyinit] polish test
* [lazyinit] fix clean
* [lazyinit] make materialize inplace
* [lazyinit] refactor materialize
* [lazyinit] refactor test distribute
* [lazyinit] fix requires_grad
* [lazyinit] fix tolist after materialization
* [lazyinit] refactor distribute module
* [lazyinit] polish docstr
* [lazyinit] polish lazy init context
* [lazyinit] temporarily skip test
* [lazyinit] polish test
* [lazyinit] add docstr
parent 189347963a
commit f8289d4221
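The hunks below modify the experimental lazy-init module (presumably `colossalai/utils/model/experimental.py`, judging by the imports in the new test further down) so that a `LazyTensor` can either be materialized in place or sharded into a local tensor via `DTensor` and a `Layout`. A minimal usage sketch of the two entry points added here, assuming only the API as it appears in this diff (`LazyInitContext.materialize` / `LazyInitContext.distribute`); the layout construction is only hinted at and is shown concretely in the new test file:

```python
import torch.nn as nn

from colossalai.utils.model.experimental import LazyInitContext

# Build the model inside the context: factory calls return LazyTensors that
# record their ops instead of allocating real storage.
ctx = LazyInitContext()
with ctx:
    model = nn.Linear(8, 8)

# Single-process path: replay the recorded ops and turn every LazyTensor into
# a real tensor/parameter in place.
LazyInitContext.materialize(model, verbose=True)

# Distributed path (sketch): supply a Layout per parameter/buffer name and
# shard instead of materializing; see generate_layout_dict() in the new test.
# LazyInitContext.distribute(model, layout_dict, verbose=True)
```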
@@ -1,11 +1,15 @@
-from typing import Callable, List, Optional, Union
+from types import MethodType
+from typing import Callable, Optional, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
 from colossalai.fx.profiler.tensor import MetaTensor
+from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor.layout import Layout
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
@@ -30,6 +34,11 @@ _NO_META_FACTORY = [
 
 _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 
+# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
+# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
+# These ops cannot be unwrapped using .data
+_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
+
 _LEGACY_TENSOR_CONSTRUCTOR = {
     'FloatTensor': torch.float,
     'DoubleTensor': torch.double,
@@ -43,6 +52,8 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
     'BoolTensor': torch.bool,
 }
 
+_EMPTY_DATA = torch.empty(0)
+
 
 class _MyTensor(Tensor):
     """This class is only for correctness verification.
@@ -64,6 +75,29 @@ class _MyTensor(Tensor):
         return super().__torch_function__(func, types, args, kwargs)
 
 
+def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
+    """Convert a lazy tensor's class to target's class, with target's data.
+
+    The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
+    If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually.
+
+    Args:
+        tensor (LazyTensor): the LazyTensor to be converted
+        target (torch.Tensor): target tensor
+
+    Returns:
+        torch.Tensor: the converted tensor
+    """
+    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    tensor.__class__ = cls_to_become
+    tensor.data = target
+    tensor.requires_grad = target.requires_grad
+    # subclass of torch.Tensor does not have tolist() method
+    # overwrite this method after materialization or distribution
+    tensor.tolist = MethodType(torch.Tensor.tolist, target)
+    return tensor
+
+
 class LazyTensor(torch.Tensor):
     """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
 
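The `_convert_cls` docstring above justifies the in-place `__class__` swap by pointing at shared (tied) parameters. A small illustration of the failure mode it avoids, using a hypothetical tied-weight pair that is not part of this diff:

```python
import torch
import torch.nn as nn

# Hypothetical weight tying: two modules hold the very same Parameter object.
emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight
assert head.weight is emb.weight

# Rebuilding the parameter and setattr-ing it on one module silently breaks
# the tie; the other module still points at the old object.
head.weight = nn.Parameter(torch.zeros(10, 4))
assert head.weight is not emb.weight

# Mutating the shared object in place (what _convert_cls does via __class__,
# .data and requires_grad) keeps every holder in sync with no bookkeeping.
emb2 = nn.Embedding(10, 4)
head2 = nn.Linear(4, 10, bias=False)
head2.weight = emb2.weight
emb2.weight.data = torch.zeros(10, 4)
assert head2.weight is emb2.weight
assert torch.equal(head2.weight, torch.zeros(10, 4))
```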
@@ -112,14 +146,8 @@ class LazyTensor(torch.Tensor):
         elem = func(*args, **{**kwargs, 'device': 'meta'})
         meta_data = MetaTensor(elem, fake_device=device)
         elem = meta_data._tensor
-        r = torch.Tensor._make_wrapper_subclass(cls,
-                                                elem.size(),
-                                                strides=elem.stride(),
-                                                storage_offset=elem.storage_offset(),
-                                                dtype=elem.dtype,
-                                                layout=elem.layout,
-                                                device=elem.device,
-                                                requires_grad=elem.requires_grad)
+        # As a meta tensor's __class__ cannot be modified to torch.Tensor, we should use an empty real tensor here
+        r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
         r._meta_data = meta_data
         return r
 
@@ -129,15 +157,28 @@ class LazyTensor(torch.Tensor):
         self._materialized_data: Optional[torch.Tensor] = concrete_data    # materialized data
 
     def materialize(self) -> torch.Tensor:
-        """Materialize the ``LazyTensor`` to ``torch.Tensor``.
+        """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
 
         Returns:
-            torch.Tensor: The materialized tensor.
+            torch.Tensor: The materialized tensor (self).
         """
         target = self._materialize_data()
-        if isinstance(self, nn.Parameter):
-            target = nn.Parameter(target, requires_grad=self.requires_grad)
-        return target
+        self.clean()
+        return _convert_cls(self, target)
+
+    def distribute(self, layout: Layout) -> torch.Tensor:
+        """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
+
+        Args:
+            layout (Layout): Distribution layout.
+
+        Returns:
+            torch.Tensor: The distributed tensor (self).
+        """
+        target = self._materialize_data()
+        self.clean()
+        local_tensor = DTensor(target, layout).local_tensor
+        return _convert_cls(self, local_tensor)
 
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
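For reference, this is what the two code paths above do to a single lazy tensor, as a sketch that assumes the behaviour shown in this hunk (the `Layout` argument would be built from a `DeviceMesh` and `ShardingSpec`, as in `make_layout()` in the new test):

```python
import torch

from colossalai.utils.model.experimental import LazyInitContext

ctx = LazyInitContext()
with ctx:
    x = torch.randn(4, 4)    # x is a LazyTensor; no real storage is allocated yet

# materialize(): replay the recorded factory/op buffer, then convert x itself
# into a plain torch.Tensor (or nn.Parameter) holding that data via _convert_cls.
x.materialize()              # type(x) is now torch.Tensor

# distribute(layout): same replay, but the result is wrapped in a DTensor and
# only the local shard for this rank is kept on x.
# x.distribute(layout)
```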
@@ -216,6 +257,8 @@ class LazyTensor(torch.Tensor):
         is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
                             or func.__name__ == "__setitem__")
 
+        is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
+
         if isinstance(func, torch._C.ScriptMethod):
             # FIXME(ver217): torch script functions are not verified
 
@@ -239,10 +282,10 @@ class LazyTensor(torch.Tensor):
             if isinstance(x, LazyTensor):
                 if x._materialized_data is not None:
                     # for early materialized tensor, use its materialized data directly
-                    return x._materialized_data.data
+                    return x._materialized_data if is_change_meta_op else x._materialized_data.data
                 t = x if is_inplace else x.clone()
                 t._op_buffer.append((func, args, kwargs))
-                meta = x._meta_data.data
+                meta = x._meta_data if is_change_meta_op else x._meta_data.data
                 meta_to_lazy[meta] = t
                 return meta
             return x
@@ -290,13 +333,36 @@ class LazyTensor(torch.Tensor):
 
     @data.setter
     def data(self, other: 'LazyTensor'):
+        """This is slightly different from the original `data` setter.
+
+        E.g.:
+            >>> a = torch.randn(3, 3)   # a is a Tensor
+            >>> b = torch.rand(2, 2)
+            >>> a.data = b
+            >>> b.add_(1)               # this will affect a
+            >>> x = torch.randn(3, 3)   # x is a LazyTensor
+            >>> y = torch.rand(2, 2)    # y is a LazyTensor
+            >>> x.data = y
+            >>> y.add_(1)               # this will not affect x
+
+        """
         if other is self:
             return
-        # TODO(ver217): to avoid infinity recursion, do early materialization
-        self._materialized_data = other._materialize_data()
+        self._op_buffer.append(other._factory_method)
+
+        def replace(x):
+            if x is other:
+                return self
+            return x
+
+        for func, args, kwargs in other._op_buffer:
+            self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
 
     def tolist(self) -> list:
-        t = self.materialize()
+        # Though self.__class__ is modified to torch.Tensor, on the C++ side it is still a subclass of torch.Tensor,
+        # and a subclass of torch.Tensor does not have the tolist() method
+        t = self._materialize_data()
         return t.tolist()
 
     def __hash__(self):
@@ -421,73 +487,86 @@ class LazyInitContext:
             setattr(torch, name, orig)
 
     @staticmethod
-    def materialize(module: torch.nn.Module, verbose: bool = False):
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``.
+    def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
+        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
 
         Args:
-            module (torch.nn.Module): Target ``nn.Module``
+            module (nn.Module): Target ``nn.Module``
             verbose (bool): Whether to print lazy initialization rate. Defaults to False.
         """
-        if verbose:
-            param_cnt = 0
-            param_lazy_cnt = 0
-            buf_cnt = 0
-            buf_lazy_cnt = 0
-            non_lazy_numel = 0
-        visited_lazy_tensors: List[LazyTensor] = []
-        # handle shared module
-        visited_modules = set()
-
-        @torch.no_grad()
-        def init_recursively(module: nn.Module):
-            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
-            # recursively initialize the module
-            for mod in module.children():
-                if id(mod) not in visited_modules:
-                    visited_modules.add(id(mod))
-                    init_recursively(mod)
-
-            # initialize tensors directly attached to the current module
-            for name, param in module.named_parameters(recurse=False):
-                if verbose:
-                    param_cnt += 1
-                    if getattr(param, '_materialized_data', False) is None:
-                        # if no _materialized_data attr, the tensor is not lazy
-                        param_lazy_cnt += 1
-                    else:
-                        non_lazy_numel += param.numel()
-                if hasattr(param, 'materialize'):
-                    # TODO(ver217): apex layers cannot be captured
-                    visited_lazy_tensors.append(param)
-                    setattr(module, name, param.materialize())
-
-            for name, buf in module.named_buffers(recurse=False):
-                if verbose:
-                    buf_cnt += 1
-                    if getattr(buf, "_materialized_data", False) is None:
-                        # if no _materialized_data attr, the tensor is not lazy
-                        buf_lazy_cnt += 1
-                    else:
-                        non_lazy_numel += buf.numel()
-                if hasattr(buf, 'materialize'):
-                    # TODO(ver217): apex layers cannot be captured
-                    visited_lazy_tensors.append(buf)
-                    setattr(module, name, buf.materialize())
-
-        init_recursively(module)
-
-        for t in visited_lazy_tensors:
-            t.clean()
-
-        if verbose:
-            print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
-            print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
-            print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
-        return module
+
+        def apply_fn(name: str, p: LazyTensor):
+            p.materialize()
+
+        return _apply_to_lazy_module(module, apply_fn, verbose)
+
+    @staticmethod
+    def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
+        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+
+        Args:
+            module (nn.Module): Target ``nn.Module``
+            layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
+            verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
+        """
+
+        def apply_fn(name: str, p: LazyTensor):
+            p.distribute(layout_dict[name])
+
+        return _apply_to_lazy_module(module, apply_fn, verbose)
+
+
+def _apply_to_lazy_module(module: nn.Module,
+                          apply_fn: Callable[[str, torch.Tensor], None],
+                          verbose: bool = False) -> nn.Module:
+    if verbose:
+        # verbose info
+        param_cnt = 0
+        param_lazy_cnt = 0
+        buf_cnt = 0
+        buf_lazy_cnt = 0
+        total_numel = 0
+        non_lazy_numel = 0
+
+    # do post cleaning to handle shared parameter
+    for name, p in module.named_parameters():
+        if verbose:
+            param_cnt += 1
+            total_numel += p.numel()
+            if getattr(p, '_materialized_data', False) is None:
+                # if no _materialized_data attr, the tensor is not lazy
+                param_lazy_cnt += 1
+            else:
+                non_lazy_numel += p.numel()
+        if isinstance(p, LazyTensor):
+            apply_fn(name, p)
+
+    for name, buf in module.named_buffers():
+        if verbose:
+            buf_cnt += 1
+            total_numel += buf.numel()
+            if getattr(buf, "_materialized_data", False) is None:
+                # if no _materialized_data attr, the tensor is not lazy
+                buf_lazy_cnt += 1
+            else:
+                non_lazy_numel += buf.numel()
+        if isinstance(buf, LazyTensor):
+            apply_fn(name, buf)
+
+    if verbose:
+        non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
+        _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
+        _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
+        _print_rank_0(
+            f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
+
+    return module
+
+
+def _print_rank_0(*args, **kwargs):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        print(*args, **kwargs)
 
 
 def _is_int_tuple(args) -> bool:
     if not isinstance(args, tuple):
         return False
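A hedged sketch of how a `layout_dict` for `LazyInitContext.distribute` can be keyed, assuming only what the docstring above states (keys are the parameter/buffer names reported by `named_parameters()` / `named_buffers()`, which is exactly what `_apply_to_lazy_module` iterates); `make_layout_fn` is a placeholder, not part of this diff:

```python
import torch.nn as nn

from colossalai.utils.model.experimental import LazyTensor


def build_layout_dict(module: nn.Module, make_layout_fn) -> dict:
    # Keys must match the qualified names from named_parameters()/named_buffers(),
    # since _apply_to_lazy_module looks each layout up by that name.
    layout_dict = {}
    for name, p in module.named_parameters():
        if isinstance(p, LazyTensor):
            layout_dict[name] = make_layout_fn(p)
    for name, buf in module.named_buffers():
        if isinstance(buf, LazyTensor):
            layout_dict[name] = make_layout_fn(buf)
    return layout_dict

# LazyInitContext.distribute(module, build_layout_dict(module, make_layout_fn), verbose=True)
```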
@@ -0,0 +1,110 @@
+from functools import partial
+from typing import Optional
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.common import print_rank_0
+from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+from tests.kit.model_zoo import model_zoo
+
+# from utils import assert_dist_model_equal, set_seed
+
+
+def find_shard_dim(shape: torch.Size) -> Optional[int]:
+    for dim, size in enumerate(shape):
+        if size % 2 == 0:
+            return dim
+
+
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
+    shard_dim = find_shard_dim(original_tensor.shape)
+    dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    return layout
+
+
+def _get_current_name(prefix: str, name: str) -> str:
+    return f'{prefix}.{name}'.lstrip('.')
+
+
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+    layout_dict = {}
+
+    @torch.no_grad()
+    def generate_recursively(module: nn.Module, prefix: str = ''):
+        # recursively initialize the module
+        for name, mod in module.named_children():
+            generate_recursively(mod, prefix=_get_current_name(prefix, name))
+
+        # initialize tensors directly attached to the current module
+        for name, param in module.named_parameters(recurse=False):
+            if isinstance(param, LazyTensor):
+                layout = make_layout(device_mesh, param)
+                layout_dict[_get_current_name(prefix, name)] = layout
+
+        for name, buf in module.named_buffers(recurse=False):
+            if isinstance(buf, LazyTensor):
+                layout = make_layout(device_mesh, buf)
+                layout_dict[_get_current_name(prefix, name)] = layout
+
+    generate_recursively(model)
+
+    return layout_dict
+
+
+@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
+def run_dist_lazy_init(subset, seed: int = 42):
+    sub_model_zoo = model_zoo.get_sub_registry(subset)
+    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+    # FIXME(ver217): uncomment this line
+    # _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
+    # LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
+
+    for name, entry in sub_model_zoo.items():
+        # TODO(ver217): lazy init does not support weight norm, skip these models
+        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+            continue
+        print_rank_0(name)
+        model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+        ctx = LazyInitContext(tensor_cls=_MyTensor)
+        with ctx:
+            model = model_fn()
+        ctx = LazyInitContext()
+        with ctx:
+            deferred_model = model_fn()
+        layout_dict = generate_layout_dict(deferred_model, device_mesh)
+        ctx.distribute(deferred_model, layout_dict, verbose=True)
+        # FIXME(ver217): uncomment this line
+        # assert_dist_model_equal(model, deferred_model, layout_dict)
+
+
+def run_dist(rank, world_size, port) -> None:
+    colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
+    run_dist_lazy_init()
+
+
+# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
+@pytest.mark.skip
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_lazy_init():
+    world_size = 4
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_dist_lazy_init()
@@ -4,6 +4,7 @@ from typing import Any, Callable, Optional, Tuple
 import numpy as np
 import torch
 
+from colossalai.tensor.d_tensor.layout_converter import to_global
 from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 from tests.kit.model_zoo.registry import ModelAttribute
 
@@ -67,3 +68,18 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
     if verbose:
         print(f'{model.__class__.__name__} pass')
+
+
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
+    state = model.state_dict()
+    distributed_state = distributed_model.state_dict()
+
+    assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}'
+
+    for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
+        assert n1 == n2
+        t1 = t1.cuda()
+        t2 = t2.cuda()
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
+        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'