[lazy] support torch 2.0 (#4763)

* [lazy] support _like methods and clamp

* [lazy] pass transformers models

* [lazy] fix device move and requires grad

* [lazy] fix requires grad and refactor api

* [lazy] fix requires grad
pull/4773/head^2
Hongxin Liu 2023-09-21 16:30:23 +08:00 committed by GitHub
parent 901ab1eedd
commit 3e05c07bb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 273 additions and 94 deletions

View File

@ -4,3 +4,4 @@ multi_line_output=3
include_trailing_comma = true
ignore_comments = true
profile = black
honor_noqa = true

View File

@ -0,0 +1,87 @@
from contextlib import contextmanager
from typing import Callable, Dict, Tuple
import torch
# Names exported by this module: the three factory-name tables below and the
# ConstructorManager class.
__all__ = [
"_LEGACY_TENSOR_CONSTRUCTOR",
"_NO_META_FACTORY",
"_NORMAL_FACTORY",
"ConstructorManager",
]
# Names of the regular tensor factory functions (attributes of the ``torch`` module).
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
"arange",
"full",
"empty",
"linspace",
"logspace",
"ones",
"rand",
"randn",
"randint",
"randperm",
"zeros",
"tensor",
]
# Factory functions that do not support the meta tensor backend.
_NO_META_FACTORY = [
"eye",
]
# Legacy typed tensor constructor names (e.g. ``torch.FloatTensor``) mapped to
# the dtype each one corresponds to.
_LEGACY_TENSOR_CONSTRUCTOR = {
"FloatTensor": torch.float,
"DoubleTensor": torch.double,
"HalfTensor": torch.half,
"BFloat16Tensor": torch.bfloat16,
"ByteTensor": torch.uint8,
"CharTensor": torch.int8,
"ShortTensor": torch.short,
"IntTensor": torch.int,
"LongTensor": torch.long,
"BoolTensor": torch.bool,
}
class ConstructorManager:
    """Global switch that patches attributes of the ``torch`` module and restores them.

    All state lives on the class itself, so there is a single registry per
    process. ``overwrites`` maps an attribute name on ``torch`` to a
    ``(new, old)`` pair; ``changed`` records whether the ``new`` callables are
    currently installed.
    """

    # function name: (new, old)
    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
    changed: bool = False

    @staticmethod
    def apply(overwrites: Dict[str, Tuple[Callable, Callable]]):
        """Replace the registry with ``overwrites`` and install the new callables.

        Args:
            overwrites: Mapping from attribute name on ``torch`` to a
                ``(new, old)`` pair of callables.

        Raises:
            AssertionError: If overwrites are already installed.
        """
        # Check *before* touching the registry: clearing it first would throw
        # away the saved originals, making the active patches impossible to undo.
        assert not ConstructorManager.changed, "Constructor already changed"
        # Snapshot first so passing the registry itself does not empty the input.
        pending = dict(overwrites)
        ConstructorManager.overwrites.clear()
        ConstructorManager.overwrites.update(pending)
        ConstructorManager.redo()

    @staticmethod
    def undo():
        """Restore the ``old`` callable for every registered name.

        Raises:
            AssertionError: If no overwrites are currently installed.
        """
        assert ConstructorManager.changed, "No constructor change to undo"
        for name, (_, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, old)
        ConstructorManager.changed = False

    @staticmethod
    def redo():
        """Install the ``new`` callable for every registered name.

        Raises:
            AssertionError: If the overwrites are already installed.
        """
        assert not ConstructorManager.changed, "Constructor already changed"
        for name, (new, _) in ConstructorManager.overwrites.items():
            setattr(torch, name, new)
        ConstructorManager.changed = True

    @staticmethod
    @contextmanager
    def disable():
        """Context manager that temporarily restores the original callables.

        If overwrites are installed on entry they are undone for the duration
        of the ``with`` body and re-installed on exit. If nothing is installed
        this is a no-op.
        """
        enabled = ConstructorManager.changed
        if enabled:
            ConstructorManager.undo()
        try:
            yield
        finally:
            # Re-install even when the body raises; otherwise one failing call
            # would silently leave the patches switched off.
            if enabled:
                ConstructorManager.redo()

    @staticmethod
    def clear():
        """Undo any installed overwrites and empty the registry."""
        if ConstructorManager.changed:
            ConstructorManager.undo()
        ConstructorManager.overwrites.clear()

View File

@ -1,17 +1,18 @@
from types import MethodType
from typing import Callable, Dict, Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from packaging import version
from torch import Tensor
from torch.nn import Parameter
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.logging import get_dist_logger
from .construction import ConstructorManager
import colossalai._analyzer._subclasses._meta_registration # noqa
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@ -41,6 +42,9 @@ _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]
# These ops are not related to the tensor value and should not be rerun
_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]
_LEGACY_TENSOR_CONSTRUCTOR = {
"FloatTensor": torch.float,
"DoubleTensor": torch.double,
@ -54,6 +58,20 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
"BoolTensor": torch.bool,
}
# These ops have at least one lazy tensor argument and maybe a scalar argument
# scalar value should be converted to meta tensor
# this is a hack for torch 2.0
_EXPAND_SCALAR_OPS = [
"where",
"clamp",
"clamp_min",
"clamp_max",
"clamp_",
"clamp_min_",
"clamp_max_",
]
_old_tensor_factory = torch.tensor
_EMPTY_DATA = torch.empty(0)
@ -145,34 +163,48 @@ class LazyTensor(torch.Tensor):
"""
_repr = True
_meta_data: Optional[MetaTensor] = None # shape, dtype, device
_meta_data: Optional[torch.Tensor] = None # shape, dtype, device
_pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None
default_device: Optional[torch.device] = None
_device: torch.device # fake device of mate tensor
@staticmethod
def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
# tips for torch 2.0:
# torch 2.0 disables torch dispatch for subclass of tensor
# MetaTensor cannot be used
# Now lazy tensor contains device injection and meta tensor
if concrete_data is not None:
# some ops don't support meta backend and should have concrete data
elem = concrete_data
else:
if meta_data is None:
device = kwargs.get("device", "cpu")
elem = func(*args, **{**kwargs, "device": "meta"})
meta_data = MetaTensor(elem, device=device)
elem = meta_data._tensor
with ConstructorManager.disable():
# disable creating lazy tensors in inner ops; this is a hack for torch 2.0
meta_data = func(*args, **{**kwargs, "device": "meta"})
elem = meta_data
# As the __class__ of a meta tensor cannot be changed to torch.Tensor, we use an empty real tensor here
r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
r._meta_data = meta_data
return r
def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
self._device = torch.device(kwargs.get("device", None) or "cpu")
if func.__name__ in _NORMAL_FACTORY:
kwargs = {**kwargs, "device": LazyTensor.default_device}
self._factory_method = (func, args, kwargs) # (func, args, kwargs)
self._op_buffer = [] # (func, args, kwargs, replace)
self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
@property
def device(self) -> torch.device:
return self._materialized_data.device if self._materialized_data is not None else self._device
def __repr__(self):
return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
def materialize(self) -> torch.Tensor:
"""Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
@ -183,20 +215,6 @@ class LazyTensor(torch.Tensor):
self.clean()
return _convert_cls(self, target)
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args:
layout (Layout): Distribution layout.
Returns:
torch.Tensor: The distributed tensor (self).
"""
target = self._materialize_data()
self.clean()
local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
return _convert_cls(self, local_tensor)
def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
delattr(self, "_factory_method")
@ -299,45 +317,80 @@ class LazyTensor(torch.Tensor):
# for early materialized tensor, use its materialized data directly
return x._materialized_data if is_change_meta_op else x._materialized_data.data
t = x if is_inplace else x.clone()
t._op_buffer.append((func, args, kwargs))
if func.__name__ not in _NO_RERUN_OPS:
t._op_buffer.append((func, args, kwargs))
meta = x._meta_data if is_change_meta_op else x._meta_data.data
meta_to_lazy[meta] = t
return meta
elif (
version.parse(torch.__version__) >= version.parse("2.0.0")
and func.__name__ in _EXPAND_SCALAR_OPS
and not isinstance(x, torch.Tensor)
):
return _old_tensor_factory(x, device="meta")
return x
def wrap(y, i=None):
if isinstance(y, MetaTensor):
if y in meta_to_lazy:
# inplace op, just return origin lazy tensor
return meta_to_lazy[y]
if isinstance(y, torch.Tensor):
if y.is_meta:
if y in meta_to_lazy:
# inplace op, just return origin lazy tensor
return meta_to_lazy[y]
else:
# out of place op, create new lazy tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
else:
# out of place op, create new lazy tensor
fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
fn.__name__ = func.__name__
lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
return lazy_y
elif type(y) is Tensor:
# for early materialized tensor
return LazyTensor(lambda: None, concrete_data=y)
# for early materialized tensor
return LazyTensor(lambda: None, concrete_data=y)
return y
cls._pre_op_fn()
o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
with ConstructorManager.disable():
# disable creating lazy tensors in inner ops; this is a hack for torch 2.0
o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
if isinstance(o, (tuple, list)):
return type(o)(wrap(y, i=i) for i, y in enumerate(o))
return wrap(o)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
pass # skip
def to(self, *args, **kwargs) -> torch.Tensor:
if self._materialized_data is not None:
return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs))
device = None
def replace(x):
nonlocal device
if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool):
device = x
return torch.device("meta")
return x
meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))
if meta_data is self._meta_data and device == self.device:
return self
def factory_fn(t: torch.Tensor, **kw):
return t.to(*args, **kwargs)
return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)
def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
return self.to(device=torch.device("cpu"), memory_format=memory_format)
def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
device = torch.device(device or "cuda")
return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)
def clone(self) -> "LazyTensor":
def factory_fn():
def factory_fn(t: torch.Tensor, **kw):
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
return new_tensor.clone()
return t.clone()
target = LazyTensor(factory_fn, meta_data=self._meta_data)
target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
return target
@ -353,17 +406,16 @@ class LazyTensor(torch.Tensor):
if id(self) in memo:
return memo[id(self)]
def factory_fn():
def factory_fn(t: torch.Tensor, **kw):
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
return _copy_tensor(new_tensor, new_tensor.requires_grad)
return _copy_tensor(t, t.requires_grad)
if self._materialized_data is not None:
# self is early materialized
copied = _copy_tensor(self._materialized_data, self.requires_grad)
target = LazyTensor(lambda: None, concrete_data=copied)
else:
target = LazyTensor(factory_fn, meta_data=self._meta_data)
target = LazyTensor(factory_fn, self, meta_data=self._meta_data)
if isinstance(self, Parameter):
# hack isinstance check of parameter
@ -394,14 +446,12 @@ class LazyTensor(torch.Tensor):
if other is self:
return
self._op_buffer.append(other._factory_method)
def replace(x):
if x is other:
return self
return x
for func, args, kwargs in other._op_buffer:
for func, args, kwargs in [other._factory_method, *other._op_buffer]:
self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
def tolist(self) -> list:
@ -455,7 +505,6 @@ class LazyInitContext:
default_device: Optional[Union[torch.device, str, int]] = None,
):
assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
self.overrides = {}
self.tensor_cls = tensor_cls
self.old_default_device = LazyTensor.default_device
self.default_device = default_device
@ -478,7 +527,9 @@ class LazyInitContext:
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
return wrapper, target
@ -513,13 +564,13 @@ class LazyInitContext:
return wrapper, target
self.overrides = {
overrides = {
target: wrap_factory_method(getattr(torch, target))
for target in _NORMAL_FACTORY
if callable(getattr(torch, target, None))
}
self.overrides.update(
overrides.update(
{
target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
for target in _NORMAL_FACTORY
@ -527,7 +578,7 @@ class LazyInitContext:
}
)
self.overrides.update(
overrides.update(
{
target: wrap_legacy_constructor(getattr(torch, target), dtype)
for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
@ -535,7 +586,7 @@ class LazyInitContext:
}
)
self.overrides.update(
overrides.update(
{
target: wrap_no_meta_factory(getattr(torch, target))
for target in _NO_META_FACTORY
@ -543,14 +594,12 @@ class LazyInitContext:
}
)
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, wrapper)
ConstructorManager.apply(overrides)
def __exit__(self, exc_type, exc_val, exc_tb):
self.tensor_cls.default_device = self.old_default_device
LazyInitContext._replaced = False
for name, (wrapper, orig) in self.overrides.items():
setattr(torch, name, orig)
ConstructorManager.clear()
@staticmethod
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
@ -566,23 +615,6 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
def distribute(
module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
) -> nn.Module:
"""Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (nn.Module): Target ``nn.Module``
layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
"""
def apply_fn(name: str, p: LazyTensor):
p.distribute(device_mesh, sharding_spec_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose)
def _apply_to_lazy_module(
module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
@ -622,20 +654,17 @@ def _apply_to_lazy_module(
if verbose:
non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
_print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
_print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
_print_rank_0(
f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
logger = get_dist_logger()
logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0])
logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0])
logger.info(
f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%",
ranks=[0],
)
return module
def _print_rank_0(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
print(*args, **kwargs)
def _is_int_tuple(args) -> bool:
if not isinstance(args, tuple):
return False

View File

@ -11,14 +11,12 @@ def test_torchvision_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if (
name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base")
or name.startswith("transformers_llama")
or name.startswith(("transformers_vit", "transformers_blip2"))
if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
("transformers_vit", "transformers_blip2")
):
continue
check_lazy_init(entry, verbose=True, default_device=default_device)
if __name__ == "__main__":
test_torchvision_models_lazy_init("torchvision")
test_torchvision_models_lazy_init("transformers", "cpu")

View File

@ -0,0 +1,64 @@
import copy
import pytest
import torch
import torch.nn as nn
from lazy_init_utils import SUPPORT_LAZY
from torch.nn import Parameter
from colossalai.lazy import LazyInitContext
@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
def test_lazy_ops():
    """Check shape/device/requires_grad of lazy tensors before and after materialization.

    Covers three scenarios: plain tensors and a ``Parameter`` created under
    ``LazyInitContext``, an in-place op recorded on a lazy tensor, and an
    ``nn.Module`` that is moved to CUDA and deep-copied before materialization.

    NOTE(review): the test calls ``.cuda()`` unconditionally, so it presumably
    needs a CUDA device to pass - confirm against the CI configuration.
    """
    # Scenario 1: metadata must be correct while tensors are still lazy.
    with LazyInitContext():
        x = torch.rand(2, 3)
        assert tuple(x.shape) == (2, 3)
        assert x.device.type == "cpu"
        # This line was a bare expression (missing ``assert``), so it never checked anything.
        assert x.requires_grad is False
        y = x.cuda()
        assert tuple(y.shape) == (2, 3)
        assert y.device.type == "cuda"
        assert y.requires_grad is False
        # Moving to the device the tensor already reports must return the same object.
        assert x.cpu() is x
        p = Parameter(torch.empty(2, 3))
        assert tuple(p.shape) == (2, 3)
        assert p.device.type == "cpu"
        assert p.requires_grad is True
        assert isinstance(p, Parameter)

    # ...and the same metadata must hold once each tensor is materialized.
    x.materialize()
    assert tuple(x.shape) == (2, 3)
    assert x.device.type == "cpu"
    assert x.requires_grad is False
    y.materialize()
    assert tuple(y.shape) == (2, 3)
    assert y.device.type == "cuda"
    assert y.requires_grad is False
    p.materialize()
    assert tuple(p.shape) == (2, 3)
    assert p.device.type == "cpu"
    assert p.requires_grad is True
    assert isinstance(p, Parameter)

    # Scenario 2: an in-place op on a lazy tensor must not break materialization.
    with LazyInitContext():
        x = torch.empty(2, 3)
        x.uniform_()
    x.materialize()
    assert tuple(x.shape) == (2, 3)

    # Scenario 3: a lazily built module moved to CUDA and deep-copied; both
    # copies must materialize on CUDA with equal values.
    with LazyInitContext():
        model = nn.Linear(3, 4)
        model = model.cuda()
        model_copied = copy.deepcopy(model)
    LazyInitContext.materialize(model)
    assert model.weight.device.type == "cuda"
    assert model.bias.device.type == "cuda"
    LazyInitContext.materialize(model_copied)
    assert model_copied.weight.device.type == "cuda"
    assert model_copied.bias.device.type == "cuda"
    assert torch.equal(model.weight, model_copied.weight)
    assert torch.equal(model.bias, model_copied.bias)


if __name__ == "__main__":
    test_lazy_ops()