[lazy] support torch 2.0 (#4763)

* [lazy] support _like methods and clamp

* [lazy] pass transformers models

* [lazy] fix device move and requires grad

* [lazy] fix requires grad and refactor api

* [lazy] fix requires grad
pull/4773/head^2
Hongxin Liu 1 year ago committed by GitHub
parent 901ab1eedd
commit 3e05c07bb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,3 +4,4 @@ multi_line_output=3
include_trailing_comma = true include_trailing_comma = true
ignore_comments = true ignore_comments = true
profile = black profile = black
honor_noqa = true

@ -0,0 +1,87 @@
from contextlib import contextmanager
from typing import Callable, Dict, Tuple
import torch
# Public API of this module.
__all__ = [
    "_LEGACY_TENSOR_CONSTRUCTOR",
    "_NO_META_FACTORY",
    "_NORMAL_FACTORY",
    "ConstructorManager",
]

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
# Names of torch factory functions to be wrapped by the lazy-init machinery
# (presumably all of these support the meta backend — the ones that do not
# are listed separately in _NO_META_FACTORY below).
_NORMAL_FACTORY = [
    "arange",
    "full",
    "empty",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randn",
    "randint",
    "randperm",
    "zeros",
    "tensor",
]

# factory function that does not support meta tensor backend
_NO_META_FACTORY = [
    "eye",
]

# Legacy type-specific constructors on the torch module (e.g. torch.FloatTensor),
# mapped to the dtype each one produces.
_LEGACY_TENSOR_CONSTRUCTOR = {
    "FloatTensor": torch.float,
    "DoubleTensor": torch.double,
    "HalfTensor": torch.half,
    "BFloat16Tensor": torch.bfloat16,
    "ByteTensor": torch.uint8,
    "CharTensor": torch.int8,
    "ShortTensor": torch.short,
    "IntTensor": torch.int,
    "LongTensor": torch.long,
    "BoolTensor": torch.bool,
}
class ConstructorManager:
    """Installs and removes monkey-patched constructor (factory) functions on the
    ``torch`` module.

    All state is class-level:
        * ``overwrites`` maps an attribute name on ``torch`` to a ``(new, old)``
          pair of callables.
        * ``changed`` records whether the patched versions are currently installed.
    """

    # function name: (new, old)
    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
    changed: bool = False

    @staticmethod
    def apply(overwrites: Dict[str, Tuple[Callable, Callable]]):
        """Replace the tracked overwrites and install the new constructors.

        Args:
            overwrites: Mapping of torch attribute name to ``(new, old)`` callables.
        """
        ConstructorManager.overwrites.clear()
        ConstructorManager.overwrites.update(overwrites)
        ConstructorManager.redo()

    @staticmethod
    def undo():
        """Restore the original constructors on ``torch``."""
        assert ConstructorManager.changed, "No constructor change to undo"
        for name, (new, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, old)
        ConstructorManager.changed = False

    @staticmethod
    def redo():
        """Install the patched constructors on ``torch``."""
        assert not ConstructorManager.changed, "Constructor already changed"
        for name, (new, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, new)
        ConstructorManager.changed = True

    @staticmethod
    @contextmanager
    def disable():
        """Context manager that temporarily restores the original constructors.

        If the patches are not installed on entry, this is a no-op.
        """
        enabled = ConstructorManager.changed
        if enabled:
            ConstructorManager.undo()
        try:
            yield
        finally:
            # Re-install the patches even if the body raised, so that
            # ``changed`` stays consistent with what is actually set on torch.
            if enabled:
                ConstructorManager.redo()

    @staticmethod
    def clear():
        """Uninstall the patches (if installed) and forget all overwrites."""
        if ConstructorManager.changed:
            ConstructorManager.undo()
        ConstructorManager.overwrites.clear()

@ -1,17 +1,18 @@
from types import MethodType from types import MethodType
from typing import Callable, Dict, Optional, Union from typing import Callable, Optional, Union
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from packaging import version
from torch import Tensor from torch import Tensor
from torch.nn import Parameter from torch.nn import Parameter
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor from colossalai.logging import get_dist_logger
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor from .construction import ConstructorManager
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
import colossalai._analyzer._subclasses._meta_registration # noqa
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [ _NORMAL_FACTORY = [
@ -41,6 +42,9 @@ _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
# These ops cannot be unwrapped using .data # These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"] _CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
# These ops is not related to tensor value and should not be rerun
_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]
_LEGACY_TENSOR_CONSTRUCTOR = { _LEGACY_TENSOR_CONSTRUCTOR = {
"FloatTensor": torch.float, "FloatTensor": torch.float,
"DoubleTensor": torch.double, "DoubleTensor": torch.double,
@ -54,6 +58,20 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
"BoolTensor": torch.bool, "BoolTensor": torch.bool,
} }
# These ops have at least one lazy tensor argument and maybe a scalar argument
# scalar value should be converted to meta tensor
# this is a hack for torch 2.0
_EXPAND_SCALAR_OPS = [
"where",
"clamp",
"clamp_min",
"clamp_max",
"clamp_",
"clamp_min_",
"clamp_max_",
]
_old_tensor_factory = torch.tensor
_EMPTY_DATA = torch.empty(0) _EMPTY_DATA = torch.empty(0)
@ -145,34 +163,48 @@ class LazyTensor(torch.Tensor):
""" """
_repr = True _repr = True
_meta_data: Optional[MetaTensor] = None # shape, dtype, device _meta_data: Optional[torch.Tensor] = None # shape, dtype, device
_pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
default_device: Optional[torch.device] = None default_device: Optional[torch.device] = None
_device: torch.device # fake device of mate tensor
@staticmethod @staticmethod
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
# tips for torch 2.0:
# torch 2.0 disables torch dispatch for subclass of tensor
# MetaTensor is cannot be used
# Now lazy tensor contains device injection and meta tensor
if concrete_data is not None: if concrete_data is not None:
# some ops don't support meta backend and should have concrete data # some ops don't support meta backend and should have concrete data
elem = concrete_data elem = concrete_data
else: else:
if meta_data is None: if meta_data is None:
device = kwargs.get("device", "cpu") with ConstructorManager.disable():
elem = func(*args, **{**kwargs, "device": "meta"}) # to disable create lazy tensor in inner ops, this is a hack for torch 2.0
meta_data = MetaTensor(elem, device=device) meta_data = func(*args, **{**kwargs, "device": "meta"})
elem = meta_data._tensor elem = meta_data
# As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad) r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
r._meta_data = meta_data r._meta_data = meta_data
return r return r
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
self._device = torch.device(kwargs.get("device", None) or "cpu")
if func.__name__ in _NORMAL_FACTORY: if func.__name__ in _NORMAL_FACTORY:
kwargs = {**kwargs, "device": LazyTensor.default_device} kwargs = {**kwargs, "device": LazyTensor.default_device}
self._factory_method = (func, args, kwargs) # (func, args, kwargs) self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace) self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
@property
def device(self) -> torch.device:
return self._materialized_data.device if self._materialized_data is not None else self._device
def __repr__(self):
return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
def materialize(self) -> torch.Tensor: def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
@ -183,20 +215,6 @@ class LazyTensor(torch.Tensor):
self.clean() self.clean()
return _convert_cls(self, target) return _convert_cls(self, target)
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args:
layout (Layout): Distribution layout.
Returns:
torch.Tensor: The distributed tensor (self).
"""
target = self._materialize_data()
self.clean()
local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
return _convert_cls(self, local_tensor)
def clean(self) -> None: def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.""" """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
delattr(self, "_factory_method") delattr(self, "_factory_method")
@ -299,45 +317,80 @@ class LazyTensor(torch.Tensor):
# for early materialized tensor, use its materialized data directly # for early materialized tensor, use its materialized data directly
return x._materialized_data if is_change_meta_op else x._materialized_data.data return x._materialized_data if is_change_meta_op else x._materialized_data.data
t = x if is_inplace else x.clone() t = x if is_inplace else x.clone()
t._op_buffer.append((func, args, kwargs)) if func.__name__ not in _NO_RERUN_OPS:
t._op_buffer.append((func, args, kwargs))
meta = x._meta_data if is_change_meta_op else x._meta_data.data meta = x._meta_data if is_change_meta_op else x._meta_data.data
meta_to_lazy[meta] = t meta_to_lazy[meta] = t
return meta return meta
elif (
version.parse(torch.__version__) >= version.parse("2.0.0")
and func.__name__ in _EXPAND_SCALAR_OPS
and not isinstance(x, torch.Tensor)
):
return _old_tensor_factory(x, device="meta")
return x return x
def wrap(y, i=None): def wrap(y, i=None):
if isinstance(y, MetaTensor): if isinstance(y, torch.Tensor):
if y in meta_to_lazy: if y.is_meta:
# inplace op, just return origin lazy tensor if y in meta_to_lazy:
return meta_to_lazy[y] # inplace op, just return origin lazy tensor
return meta_to_lazy[y]
else:
# out of place op, create new lazy tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
else: else:
# out of place op, create new lazy tensor # for early materialized tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i] return LazyTensor(lambda: None, concrete_data=y)
fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
elif type(y) is Tensor:
# for early materialized tensor
return LazyTensor(lambda: None, concrete_data=y)
return y return y
cls._pre_op_fn() cls._pre_op_fn()
o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) with ConstructorManager.disable():
# to disable create lazy tensor in inner ops, this is a hack for torch 2.0
o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
if isinstance(o, (tuple, list)): if isinstance(o, (tuple, list)):
return type(o)(wrap(y, i=i) for i, y in enumerate(o)) return type(o)(wrap(y, i=i) for i, y in enumerate(o))
return wrap(o) return wrap(o)
@classmethod def to(self, *args, **kwargs) -> torch.Tensor:
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if self._materialized_data is not None:
pass # skip return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs))
device = None
def replace(x):
nonlocal device
if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool):
device = x
return torch.device("meta")
return x
meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))
if meta_data is self._meta_data and device == self.device:
return self
def factory_fn(t: torch.Tensor, **kw):
return t.to(*args, **kwargs)
return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)
def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
return self.to(device=torch.device("cpu"), memory_format=memory_format)
def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
device = torch.device(device or "cuda")
return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)
def clone(self) -> "LazyTensor": def clone(self) -> "LazyTensor":
def factory_fn(): def factory_fn(t: torch.Tensor, **kw):
# if self is materialized, return self # if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self return t.clone()
return new_tensor.clone()
target = LazyTensor(factory_fn, meta_data=self._meta_data) target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
return target return target
@ -353,17 +406,16 @@ class LazyTensor(torch.Tensor):
if id(self) in memo: if id(self) in memo:
return memo[id(self)] return memo[id(self)]
def factory_fn(): def factory_fn(t: torch.Tensor, **kw):
# if self is materialized, return self # if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self return _copy_tensor(t, t.requires_grad)
return _copy_tensor(new_tensor, new_tensor.requires_grad)
if self._materialized_data is not None: if self._materialized_data is not None:
# self is early materialized # self is early materialized
copied = _copy_tensor(self._materialized_data, self.requires_grad) copied = _copy_tensor(self._materialized_data, self.requires_grad)
target = LazyTensor(lambda: None, concrete_data=copied) target = LazyTensor(lambda: None, concrete_data=copied)
else: else:
target = LazyTensor(factory_fn, meta_data=self._meta_data) target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
if isinstance(self, Parameter): if isinstance(self, Parameter):
# hack isinstance check of parameter # hack isinstance check of parameter
@ -394,14 +446,12 @@ class LazyTensor(torch.Tensor):
if other is self: if other is self:
return return
self._op_buffer.append(other._factory_method)
def replace(x): def replace(x):
if x is other: if x is other:
return self return self
return x return x
for func, args, kwargs in other._op_buffer: for func, args, kwargs in [other._factory_method, *other._op_buffer]:
self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs))) self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
def tolist(self) -> list: def tolist(self) -> list:
@ -455,7 +505,6 @@ class LazyInitContext:
default_device: Optional[Union[torch.device, str, int]] = None, default_device: Optional[Union[torch.device, str, int]] = None,
): ):
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {}
self.tensor_cls = tensor_cls self.tensor_cls = tensor_cls
self.old_default_device = LazyTensor.default_device self.old_default_device = LazyTensor.default_device
self.default_device = default_device self.default_device = default_device
@ -478,7 +527,9 @@ class LazyInitContext:
# factory_like functions (eg. torch.empty_like()) # factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
orig_t = args[0] orig_t = args[0]
return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs) return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
return wrapper, target return wrapper, target
@ -513,13 +564,13 @@ class LazyInitContext:
return wrapper, target return wrapper, target
self.overrides = { overrides = {
target: wrap_factory_method(getattr(torch, target)) target: wrap_factory_method(getattr(torch, target))
for target in _NORMAL_FACTORY for target in _NORMAL_FACTORY
if callable(getattr(torch, target, None)) if callable(getattr(torch, target, None))
} }
self.overrides.update( overrides.update(
{ {
target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like")) target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
for target in _NORMAL_FACTORY for target in _NORMAL_FACTORY
@ -527,7 +578,7 @@ class LazyInitContext:
} }
) )
self.overrides.update( overrides.update(
{ {
target: wrap_legacy_constructor(getattr(torch, target), dtype) target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
@ -535,7 +586,7 @@ class LazyInitContext:
} }
) )
self.overrides.update( overrides.update(
{ {
target: wrap_no_meta_factory(getattr(torch, target)) target: wrap_no_meta_factory(getattr(torch, target))
for target in _NO_META_FACTORY for target in _NO_META_FACTORY
@ -543,14 +594,12 @@ class LazyInitContext:
} }
) )
for name, (wrapper, orig) in self.overrides.items(): ConstructorManager.apply(overrides)
setattr(torch, name, wrapper)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.tensor_cls.default_device = self.old_default_device self.tensor_cls.default_device = self.old_default_device
LazyInitContext._replaced = False LazyInitContext._replaced = False
for name, (wrapper, orig) in self.overrides.items(): ConstructorManager.clear()
setattr(torch, name, orig)
@staticmethod @staticmethod
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module: def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
@ -566,23 +615,6 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose) return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
def distribute(
module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
) -> nn.Module:
"""Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (nn.Module): Target ``nn.Module``
layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
"""
def apply_fn(name: str, p: LazyTensor):
p.distribute(device_mesh, sharding_spec_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose)
def _apply_to_lazy_module( def _apply_to_lazy_module(
module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
@ -622,20 +654,17 @@ def _apply_to_lazy_module(
if verbose: if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}") logger = get_dist_logger()
_print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}") logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0])
_print_rank_0( logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0])
f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%" logger.info(
f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%",
ranks=[0],
) )
return module return module
def _print_rank_0(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
def _is_int_tuple(args) -> bool: def _is_int_tuple(args) -> bool:
if not isinstance(args, tuple): if not isinstance(args, tuple):
return False return False

@ -11,14 +11,12 @@ def test_torchvision_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset) sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items(): for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models # TODO(ver217): lazy init does not support weight norm, skip these models
if ( if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") ("transformers_vit", "transformers_blip2")
or name.startswith("transformers_llama")
or name.startswith(("transformers_vit", "transformers_blip2"))
): ):
continue continue
check_lazy_init(entry, verbose=True, default_device=default_device) check_lazy_init(entry, verbose=True, default_device=default_device)
if __name__ == "__main__": if __name__ == "__main__":
test_torchvision_models_lazy_init("torchvision") test_torchvision_models_lazy_init("transformers", "cpu")

@ -0,0 +1,64 @@
import copy
import pytest
import torch
import torch.nn as nn
from lazy_init_utils import SUPPORT_LAZY
from torch.nn import Parameter
from colossalai.lazy import LazyInitContext
@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
def test_lazy_ops():
    """Exercise basic LazyTensor behavior: shape/device/requires_grad metadata,
    device moves, in-place ops, deepcopy and materialization."""
    with LazyInitContext():
        x = torch.rand(2, 3)
        assert tuple(x.shape) == (2, 3)
        assert x.device.type == "cpu"
        # NOTE: this was previously a bare expression (`x.requires_grad is False`)
        # which asserted nothing.
        assert x.requires_grad is False
        y = x.cuda()
        assert tuple(y.shape) == (2, 3)
        assert y.device.type == "cuda"
        assert y.requires_grad is False
        # moving to the device the tensor is already on should be a no-op
        assert x.cpu() is x
        p = Parameter(torch.empty(2, 3))
        assert tuple(p.shape) == (2, 3)
        assert p.device.type == "cpu"
        assert p.requires_grad is True
        assert isinstance(p, Parameter)
    # metadata must survive materialization
    x.materialize()
    assert tuple(x.shape) == (2, 3)
    assert x.device.type == "cpu"
    assert x.requires_grad is False
    y.materialize()
    assert tuple(y.shape) == (2, 3)
    assert y.device.type == "cuda"
    assert y.requires_grad is False
    p.materialize()
    assert tuple(p.shape) == (2, 3)
    assert p.device.type == "cpu"
    assert p.requires_grad is True
    assert isinstance(p, Parameter)

    # in-place ops on a lazy tensor are replayed at materialization
    with LazyInitContext():
        x = torch.empty(2, 3)
        x.uniform_()
    x.materialize()
    assert tuple(x.shape) == (2, 3)

    # deepcopy of a lazy module must materialize to equal weights
    with LazyInitContext():
        model = nn.Linear(3, 4)
        model = model.cuda()
    model_copied = copy.deepcopy(model)
    LazyInitContext.materialize(model)
    assert model.weight.device.type == "cuda"
    assert model.bias.device.type == "cuda"
    LazyInitContext.materialize(model_copied)
    assert model_copied.weight.device.type == "cuda"
    assert model_copied.bias.device.type == "cuda"
    assert torch.equal(model.weight, model_copied.weight)
    assert torch.equal(model.bias, model_copied.bias)
if __name__ == "__main__":
test_lazy_ops()
Loading…
Cancel
Save