diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py
index 00cb532d9..6427a147a 100644
--- a/colossalai/utils/model/experimental.py
+++ b/colossalai/utils/model/experimental.py
@@ -1,11 +1,15 @@
-from typing import Callable, List, Optional, Union
+from types import MethodType
+from typing import Callable, Optional, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 from torch.utils._pytree import tree_map
 
 from colossalai.fx.profiler.tensor import MetaTensor
+from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor.layout import Layout
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
@@ -30,6 +34,11 @@ _NO_META_FACTORY = [
 
 _EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
 
+# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
+# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
+# These ops cannot be unwrapped using .data
+_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
+
 _LEGACY_TENSOR_CONSTRUCTOR = {
     'FloatTensor': torch.float,
     'DoubleTensor': torch.double,
@@ -43,6 +52,8 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
     'BoolTensor': torch.bool,
 }
 
+_EMPTY_DATA = torch.empty(0)
+
 
 class _MyTensor(Tensor):
     """This class is only for correctness verification.
@@ -64,6 +75,29 @@ class _MyTensor(Tensor):
         return super().__torch_function__(func, types, args, kwargs)
 
 
+def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
+    """Convert a lazy tensor's class to the target's class, carrying the target's data.
+
+    The class of the lazy tensor is changed in-place because this easily handles shared modules/parameters, which are common in huggingface models.
+    If we created a new tensor and updated the module via ``setattr(module, name, param)``, shared parameters would not be updated, and we would have to track and update all of them manually.
+
+    Args:
+        tensor (LazyTensor): the LazyTensor to be converted
+        target (torch.Tensor): target tensor
+
+    Returns:
+        torch.Tensor: the converted tensor
+    """
+    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    tensor.__class__ = cls_to_become
+    tensor.data = target
+    tensor.requires_grad = target.requires_grad
+    # a subclass of torch.Tensor does not have the tolist() method,
+    # so overwrite this method after materialization or distribution
+    tensor.tolist = MethodType(torch.Tensor.tolist, target)
+    return tensor
+
+
 class LazyTensor(torch.Tensor):
     """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
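# A minimal, self-contained sketch (not part of this patch) of why `_convert_cls` above mutates
# the lazy tensor in-place rather than building a new tensor: tied parameters stay tied after an
# in-place update, while replacing the attribute with a fresh tensor silently breaks the tie.
# `SharedPair` is a hypothetical toy module used only for illustration.
import torch
import torch.nn as nn

class SharedPair(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4, bias=False)
        self.fc2 = nn.Linear(4, 4, bias=False)
        self.fc2.weight = self.fc1.weight    # weight tying, common in huggingface models

model = SharedPair()
model.fc1.weight.data = torch.ones(4, 4)    # in-place update: both modules observe it
assert model.fc2.weight is model.fc1.weight

setattr(model.fc1, 'weight', nn.Parameter(torch.zeros(4, 4)))    # new tensor: sharing is lost
assert model.fc2.weight is not model.fc1.weight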
@@ -112,14 +146,8 @@ class LazyTensor(torch.Tensor):
                 elem = func(*args, **{**kwargs, 'device': 'meta'})
                 meta_data = MetaTensor(elem, fake_device=device)
             elem = meta_data._tensor
-        r = torch.Tensor._make_wrapper_subclass(cls,
-                                                elem.size(),
-                                                strides=elem.stride(),
-                                                storage_offset=elem.storage_offset(),
-                                                dtype=elem.dtype,
-                                                layout=elem.layout,
-                                                device=elem.device,
-                                                requires_grad=elem.requires_grad)
+        # Since the __class__ of a meta tensor cannot later be changed to torch.Tensor, we use an empty real tensor here
+        r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
         r._meta_data = meta_data
         return r
 
@@ -129,15 +157,28 @@ class LazyTensor(torch.Tensor):
         self._materialized_data: Optional[torch.Tensor] = concrete_data    # materialized data
 
     def materialize(self) -> torch.Tensor:
-        """Materialize the ``LazyTensor`` to ``torch.Tensor``.
+        """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying its __class__ (in-place).
 
         Returns:
-            torch.Tensor: The materialized tensor.
+            torch.Tensor: The materialized tensor (self).
         """
         target = self._materialize_data()
-        if isinstance(self, nn.Parameter):
-            target = nn.Parameter(target, requires_grad=self.requires_grad)
-        return target
+        self.clean()
+        return _convert_cls(self, target)
+
+    def distribute(self, layout: Layout) -> torch.Tensor:
+        """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying its __class__ (in-place), according to the given layout.
+
+        Args:
+            layout (Layout): Distribution layout.
+
+        Returns:
+            torch.Tensor: The distributed tensor (self).
+        """
+        target = self._materialize_data()
+        self.clean()
+        local_tensor = DTensor(target, layout).local_tensor
+        return _convert_cls(self, local_tensor)
 
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
@@ -216,6 +257,8 @@ class LazyTensor(torch.Tensor):
         is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
                             or func.__name__ == "__setitem__")
 
+        is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
+
         if isinstance(func, torch._C.ScriptMethod):
             # FIXME(ver217): torch script functions are not verified
 
@@ -239,10 +282,10 @@ class LazyTensor(torch.Tensor):
             if isinstance(x, LazyTensor):
                 if x._materialized_data is not None:
                     # for early materialized tensor, use its materialized data directly
-                    return x._materialized_data.data
+                    return x._materialized_data if is_change_meta_op else x._materialized_data.data
                 t = x if is_inplace else x.clone()
                 t._op_buffer.append((func, args, kwargs))
-                meta = x._meta_data.data
+                meta = x._meta_data if is_change_meta_op else x._meta_data.data
                 meta_to_lazy[meta] = t
                 return meta
             return x
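# A minimal, self-contained sketch (not this file's implementation) of the deferred-execution
# idea behind LazyTensor: the factory call and every subsequent op are recorded in an op buffer
# (`_op_buffer` above) and only replayed when the tensor is materialized. `TinyLazy` and its
# single recorded op are hypothetical and exist only to illustrate the record-then-replay flow.
import torch

class TinyLazy:

    def __init__(self, factory, *args, **kwargs):
        self._op_buffer = [(factory, args, kwargs)]    # record the factory method first

    def add_(self, value):
        # record the op instead of executing it
        self._op_buffer.append((torch.Tensor.add_, (value,), {}))
        return self

    def materialize(self) -> torch.Tensor:
        factory, args, kwargs = self._op_buffer[0]
        t = factory(*args, **kwargs)
        for op, op_args, op_kwargs in self._op_buffer[1:]:
            op(t, *op_args, **op_kwargs)    # replay recorded ops on the real tensor
        return t

x = TinyLazy(torch.zeros, 2, 2).add_(1.0)
assert torch.equal(x.materialize(), torch.ones(2, 2))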
@@ -290,13 +333,36 @@ class LazyTensor(torch.Tensor):
 
     @data.setter
     def data(self, other: 'LazyTensor'):
+        """This is slightly different from the original `data` setter.
+
+        E.g.:
+            >>> a = torch.randn(3, 3)    # a is a Tensor
+            >>> b = torch.rand(2, 2)
+            >>> a.data = b
+            >>> b.add_(1)    # this will affect a
+            >>> x = torch.randn(3, 3)    # x is a LazyTensor (created inside a LazyInitContext)
+            >>> y = torch.rand(2, 2)    # y is a LazyTensor
+            >>> x.data = y
+            >>> y.add_(1)    # this will not affect x
+
+        """
         if other is self:
             return
-        # TODO(ver217): to avoid infinity recursion, do early materialization
-        self._materialized_data = other._materialize_data()
+
+        self._op_buffer.append(other._factory_method)
+
+        def replace(x):
+            if x is other:
+                return self
+            return x
+
+        for func, args, kwargs in other._op_buffer:
+            self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
 
     def tolist(self) -> list:
-        t = self.materialize()
+        # Though self.__class__ is modified to torch.Tensor, on the C++ side it is still a subclass of torch.Tensor,
+        # and a subclass of torch.Tensor does not have the tolist() method
+        t = self._materialize_data()
         return t.tolist()
 
     def __hash__(self):
@@ -421,71 +487,84 @@ class LazyInitContext:
             setattr(torch, name, orig)
 
     @staticmethod
-    def materialize(module: torch.nn.Module, verbose: bool = False):
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``.
+    def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
+        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
 
         Args:
-            module (torch.nn.Module): Target ``nn.Module``
+            module (nn.Module): Target ``nn.Module``
             verbose (bool): Whether to print lazy initialization rate. Defaults to False.
         """
-        if verbose:
-            param_cnt = 0
-            param_lazy_cnt = 0
-            buf_cnt = 0
-            buf_lazy_cnt = 0
-            non_lazy_numel = 0
-
-        # do post cleaning to handle shared parameter
-        visited_lazy_tensors: List[LazyTensor] = []
-        # handle shared module
-        visited_modules = set()
-
-        @torch.no_grad()
-        def init_recursively(module: nn.Module):
-            nonlocal param_cnt, param_lazy_cnt, buf_cnt, buf_lazy_cnt, non_lazy_numel
-            # recursively initialize the module
-            for mod in module.children():
-                if id(mod) not in visited_modules:
-                    visited_modules.add(id(mod))
-                    init_recursively(mod)
-
-            # initialize tensors directly attached to the current module
-            for name, param in module.named_parameters(recurse=False):
-                if verbose:
-                    param_cnt += 1
-                    if getattr(param, '_materialized_data', False) is None:
-                        # if no _materialized_data attr, the tensor is not lazy
-                        param_lazy_cnt += 1
-                    else:
-                        non_lazy_numel += param.numel()
-                if hasattr(param, 'materialize'):
-                    # TODO(ver217): apex layers cannot be captured
-                    visited_lazy_tensors.append(param)
-                    setattr(module, name, param.materialize())
-
-            for name, buf in module.named_buffers(recurse=False):
-                if verbose:
-                    buf_cnt += 1
-                    if getattr(buf, "_materialized_data", False) is None:
-                        # if no _materialized_data attr, the tensor is not lazy
-                        buf_lazy_cnt += 1
-                    else:
-                        non_lazy_numel += buf.numel()
-                if hasattr(buf, 'materialize'):
-                    # TODO(ver217): apex layers cannot be captured
-                    visited_lazy_tensors.append(buf)
-                    setattr(module, name, buf.materialize())
-
-        init_recursively(module)
+        def apply_fn(name: str, p: LazyTensor):
+            p.materialize()
+
+        return _apply_to_lazy_module(module, apply_fn, verbose)
+
+    @staticmethod
+    def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
+        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+
+        Args:
+            module (nn.Module): Target ``nn.Module``
+            layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
+            verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
+        """
+
+        def apply_fn(name: str, p: LazyTensor):
+            p.distribute(layout_dict[name])
+
+        return _apply_to_lazy_module(module, apply_fn, verbose)
+
-        for t in visited_lazy_tensors:
-            t.clean()
+def _apply_to_lazy_module(module: nn.Module,
+                          apply_fn: Callable[[str, torch.Tensor], None],
+                          verbose: bool = False) -> nn.Module:
+    if verbose:
+        # verbose info
+        param_cnt = 0
+        param_lazy_cnt = 0
+        buf_cnt = 0
+        buf_lazy_cnt = 0
+        total_numel = 0
+        non_lazy_numel = 0
+
+    for name, p in module.named_parameters():
+        if verbose:
+            param_cnt += 1
+            total_numel += p.numel()
+            if getattr(p, '_materialized_data', False) is None:
+                # if no _materialized_data attr, the tensor is not lazy
+                param_lazy_cnt += 1
+            else:
+                non_lazy_numel += p.numel()
+        if isinstance(p, LazyTensor):
+            apply_fn(name, p)
+    for name, buf in module.named_buffers():
         if verbose:
-            print(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
-            print(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
-            print(f'Non-lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M)')
-        return module
+            buf_cnt += 1
+            total_numel += buf.numel()
+            if getattr(buf, "_materialized_data", False) is None:
+                # if no _materialized_data attr, the tensor is not lazy
+                buf_lazy_cnt += 1
+            else:
+                non_lazy_numel += buf.numel()
+        if isinstance(buf, LazyTensor):
+            apply_fn(name, buf)
+
+    if verbose:
+        non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
+        _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
+        _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
+        _print_rank_0(
+            f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
+
+    return module
+
+
+def _print_rank_0(*args, **kwargs):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        print(*args, **kwargs)
 
 
 def _is_int_tuple(args) -> bool:
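# A usage sketch of the two entry points added above, kept deliberately small; the module here is
# a stand-in, and the distribution branch is shown as comments because it needs a DeviceMesh and
# per-name Layouts (see make_layout() / generate_layout_dict() in the new test file below).
import torch.nn as nn

from colossalai.utils.model.experimental import LazyInitContext

with LazyInitContext():
    model = nn.Linear(8, 8)    # parameters are created as LazyTensors, initialization is deferred

# Option 1: materialize every parameter/buffer in-place
LazyInitContext.materialize(model, verbose=True)

# Option 2: distribute instead of materializing; layout_dict maps dotted parameter/buffer names
# to Layout objects, e.g. built as in generate_layout_dict() below:
#     layout_dict = {'weight': make_layout(device_mesh, model.weight), ...}
#     LazyInitContext.distribute(model, layout_dict, verbose=True)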
diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_utils/test_lazy_init/test_distribute.py
new file mode 100644
index 000000000..37b2c5da1
--- /dev/null
+++ b/tests/test_utils/test_lazy_init/test_distribute.py
@@ -0,0 +1,110 @@
+from functools import partial
+from typing import Optional
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.common import print_rank_0
+from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+from tests.kit.model_zoo import model_zoo
+
+# from utils import assert_dist_model_equal, set_seed
+
+
+def find_shard_dim(shape: torch.Size) -> Optional[int]:
+    for dim, size in enumerate(shape):
+        if size % 2 == 0:
+            return dim
+
+
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
+    shard_dim = find_shard_dim(original_tensor.shape)
+    dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    return layout
+
+
+def _get_current_name(prefix: str, name: str) -> str:
+    return f'{prefix}.{name}'.lstrip('.')
+
+
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+    layout_dict = {}
+
+    @torch.no_grad()
+    def generate_recursively(module: nn.Module, prefix: str = ''):
+        # recursively initialize the module
+        for name, mod in module.named_children():
+            generate_recursively(mod, prefix=_get_current_name(prefix, name))
+
+        # initialize tensors directly attached to the current module
+        for name, param in module.named_parameters(recurse=False):
+            if isinstance(param, LazyTensor):
+                layout = make_layout(device_mesh, param)
+                layout_dict[_get_current_name(prefix, name)] = layout
+
+        for name, buf in module.named_buffers(recurse=False):
+            if isinstance(buf, LazyTensor):
+                layout = make_layout(device_mesh, buf)
+                layout_dict[_get_current_name(prefix, name)] = layout
+
+    generate_recursively(model)
+
+    return layout_dict
+
+
+@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
+def run_dist_lazy_init(subset, seed: int = 42):
+    sub_model_zoo = model_zoo.get_sub_registry(subset)
+    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+    # FIXME(ver217): uncomment this line
+    # _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
+    # LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
+
+    for name, entry in sub_model_zoo.items():
+        # TODO(ver217): lazy init does not support weight norm, skip these models
+        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+            continue
+        print_rank_0(name)
+        model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+        ctx = LazyInitContext(tensor_cls=_MyTensor)
+        with ctx:
+            model = model_fn()
+        ctx = LazyInitContext()
+        with ctx:
+            deferred_model = model_fn()
+        layout_dict = generate_layout_dict(deferred_model, device_mesh)
+        ctx.distribute(deferred_model, layout_dict, verbose=True)
+        # FIXME(ver217): uncomment this line
+        # assert_dist_model_equal(model, deferred_model, layout_dict)
+
+
+def run_dist(rank, world_size, port) -> None:
+    colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
+    run_dist_lazy_init()
+
+
+# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
+@pytest.mark.skip
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_lazy_init():
+    world_size = 4
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_dist_lazy_init()
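# A small illustration of the keys that generate_layout_dict() above produces: they are the same
# fully-qualified dotted names yielded by module.named_parameters(), because _get_current_name()
# joins the recursion prefix with the local name. `Net` is a hypothetical toy module.
import torch.nn as nn

class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

net = Net()
print([name for name, _ in net.named_parameters()])
# ['encoder.0.weight', 'encoder.0.bias'] -- the dotted names used as layout_dict keys,
# matching what LazyInitContext.distribute() looks up via layout_dict[name]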
diff --git a/tests/test_utils/test_lazy_init/utils.py b/tests/test_utils/test_lazy_init/utils.py
index 47ba534bc..a8aeb4c89 100644
--- a/tests/test_utils/test_lazy_init/utils.py
+++ b/tests/test_utils/test_lazy_init/utils.py
@@ -4,6 +4,7 @@ from typing import Any, Callable, Optional, Tuple
 import numpy as np
 import torch
 
+from colossalai.tensor.d_tensor.layout_converter import to_global
 from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 from tests.kit.model_zoo.registry import ModelAttribute
 
@@ -67,3 +68,18 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
         assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
     if verbose:
         print(f'{model.__class__.__name__} pass')
+
+
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
+    state = model.state_dict()
+    distributed_state = distributed_model.state_dict()
+
+    assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}'
+
+    for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
+        assert n1 == n2
+        t1 = t1.cuda()
+        t2 = t2.cuda()
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
+        assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
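# A self-contained illustration of the comparison pattern used by assert_dist_model_equal() above:
# state_dict() preserves registration order, so zipping two state dicts of identically structured
# modules pairs the right tensors, and names can be asserted equal before comparing values.
# (In the real helper, sharded entries are first gathered back with to_global().)
import torch
import torch.nn as nn

torch.manual_seed(0)
m1 = nn.Linear(2, 2)
torch.manual_seed(0)
m2 = nn.Linear(2, 2)

for (n1, t1), (n2, t2) in zip(m1.state_dict().items(), m2.state_dict().items()):
    assert n1 == n2
    assert torch.equal(t1, t2), f'{n1} mismatch'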