mirror of https://github.com/hpcaitech/ColossalAI
[lazy] support torch 2.0 (#4763)
* [lazy] support _like methods and clamp
* [lazy] pass transformers models
* [lazy] fix device move and requires grad
* [lazy] fix requires grad and refactor api
* [lazy] fix requires grad
parent 901ab1eedd
commit 3e05c07bb8
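For context, a minimal usage sketch of the feature this commit targets (the import path and the materialize call match the code and tests in the diff below; the model itself is illustrative):

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext

    # inside the context, factory calls such as the torch.empty used by nn.Linear are intercepted,
    # so parameters become LazyTensors backed by meta tensors and roughly no real memory is allocated
    with LazyInitContext():
        model = nn.Linear(3, 4)

    # replay the recorded constructor calls and turn every LazyTensor into a real torch.Tensor
    LazyInitContext.materialize(model)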
@ -4,3 +4,4 @@ multi_line_output=3
include_trailing_comma = true
ignore_comments = true
profile = black
honor_noqa = true
@ -0,0 +1,87 @@
from contextlib import contextmanager
from typing import Callable, Dict, Tuple

import torch

__all__ = [
    "_LEGACY_TENSOR_CONSTRUCTOR",
    "_NO_META_FACTORY",
    "_NORMAL_FACTORY",
    "ConstructorManager",
]

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
    "arange",
    "full",
    "empty",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randn",
    "randint",
    "randperm",
    "zeros",
    "tensor",
]

# factory functions that do not support the meta tensor backend
_NO_META_FACTORY = [
    "eye",
]

_LEGACY_TENSOR_CONSTRUCTOR = {
    "FloatTensor": torch.float,
    "DoubleTensor": torch.double,
    "HalfTensor": torch.half,
    "BFloat16Tensor": torch.bfloat16,
    "ByteTensor": torch.uint8,
    "CharTensor": torch.int8,
    "ShortTensor": torch.short,
    "IntTensor": torch.int,
    "LongTensor": torch.long,
    "BoolTensor": torch.bool,
}


class ConstructorManager:
    # function name: (new, old)
    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
    changed: bool = False

    @staticmethod
    def apply(overwrites: Dict[Callable, Callable]):
        ConstructorManager.overwrites.clear()
        ConstructorManager.overwrites.update(overwrites)
        ConstructorManager.redo()

    @staticmethod
    def undo():
        assert ConstructorManager.changed, "No constructor change to undo"
        for name, (new, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, old)
        ConstructorManager.changed = False

    @staticmethod
    def redo():
        assert not ConstructorManager.changed, "Constructor already changed"
        for name, (new, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, new)
        ConstructorManager.changed = True

    @staticmethod
    @contextmanager
    def disable():
        enabled = ConstructorManager.changed
        if enabled:
            ConstructorManager.undo()
        yield
        if enabled:
            ConstructorManager.redo()

    @staticmethod
    def clear():
        if ConstructorManager.changed:
            ConstructorManager.undo()
        ConstructorManager.overwrites.clear()
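A hedged sketch of how this manager is meant to be driven; the fake_empty wrapper below is made up for illustration, while apply/disable/clear behave as defined above (assuming the new module lands at colossalai.lazy.construction, as the relative import in the next file suggests):

    import torch

    from colossalai.lazy.construction import ConstructorManager

    old_empty = torch.empty

    def fake_empty(*args, **kwargs):
        print("intercepted torch.empty")
        return old_empty(*args, **kwargs)

    # each entry maps a factory name to (new, old) so undo()/clear() can restore the original
    ConstructorManager.apply({"empty": (fake_empty, old_empty)})
    torch.empty(2, 3)              # prints "intercepted torch.empty"

    with ConstructorManager.disable():
        torch.empty(2, 3)          # original factory, no interception

    ConstructorManager.clear()     # puts torch.empty back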
@ -1,17 +1,18 @@
from types import MethodType
from typing import Callable, Dict, Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from packaging import version
from torch import Tensor
from torch.nn import Parameter
from torch.utils._pytree import tree_map

from colossalai._analyzer._subclasses import MetaTensor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.logging import get_dist_logger

from .construction import ConstructorManager

import colossalai._analyzer._subclasses._meta_registration  # noqa

# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@ -41,6 +42,9 @@ _EARLY_MATERIALIZED_OPS = ["__getitem__", "split"]
# These ops cannot be unwrapped using .data
_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"]

# These ops are not related to tensor values and should not be rerun
_NO_RERUN_OPS = ["__get__", "numel", "size", "dim"]

_LEGACY_TENSOR_CONSTRUCTOR = {
    "FloatTensor": torch.float,
    "DoubleTensor": torch.double,
@ -54,6 +58,20 @@ _LEGACY_TENSOR_CONSTRUCTOR = {
    "BoolTensor": torch.bool,
}

# These ops take at least one lazy tensor argument and possibly scalar arguments;
# scalar arguments should be converted to meta tensors as well.
# This is a hack for torch 2.0.
_EXPAND_SCALAR_OPS = [
    "where",
    "clamp",
    "clamp_min",
    "clamp_max",
    "clamp_",
    "clamp_min_",
    "clamp_max_",
]
_old_tensor_factory = torch.tensor

_EMPTY_DATA = torch.empty(0)

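A hedged illustration of why this list exists: under torch 2.0, ops such as clamp can reach the lazy handler with plain Python scalars mixed in, and those scalars are lifted to meta tensors via _old_tensor_factory so the meta-device computation succeeds. Roughly (the public import path is the one used by the tests further down):

    import torch

    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        x = torch.rand(2, 3)        # LazyTensor backed by a meta tensor
        y = x.clamp(0.0, 1.0)       # 0.0 / 1.0 are internally re-created roughly as
                                    # _old_tensor_factory(scalar, device="meta") before dispatch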
@ -145,34 +163,48 @@ class LazyTensor(torch.Tensor):
    """

    _repr = True
    _meta_data: Optional[MetaTensor] = None  # shape, dtype, device
    _meta_data: Optional[torch.Tensor] = None  # shape, dtype, device
    _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None

    default_device: Optional[torch.device] = None
    _device: torch.device  # fake device of the meta tensor

    @staticmethod
    def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
        # tips for torch 2.0:
        # torch 2.0 disables torch dispatch for subclasses of tensor,
        # so MetaTensor cannot be used.
        # A lazy tensor now carries both the device injection and the meta tensor.
        if concrete_data is not None:
            # some ops don't support meta backend and should have concrete data
            elem = concrete_data
        else:
            if meta_data is None:
                device = kwargs.get("device", "cpu")
                elem = func(*args, **{**kwargs, "device": "meta"})
                meta_data = MetaTensor(elem, device=device)
            elem = meta_data._tensor
                with ConstructorManager.disable():
                    # disable creating lazy tensors in inner ops; this is a hack for torch 2.0
                    meta_data = func(*args, **{**kwargs, "device": "meta"})
            elem = meta_data
        # As a meta tensor's __class__ cannot be changed to torch.Tensor, use an empty real tensor here
        r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
        r._meta_data = meta_data

        return r

    def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
        self._device = torch.device(kwargs.get("device", None) or "cpu")
        if func.__name__ in _NORMAL_FACTORY:
            kwargs = {**kwargs, "device": LazyTensor.default_device}
        self._factory_method = (func, args, kwargs)  # (func, args, kwargs)
        self._op_buffer = []  # (func, args, kwargs, replace)
        self._materialized_data: Optional[torch.Tensor] = concrete_data  # materialized data

    @property
    def device(self) -> torch.device:
        return self._materialized_data.device if self._materialized_data is not None else self._device

    def __repr__(self):
        return f"LazyTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

    def materialize(self) -> torch.Tensor:
        """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).

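A hedged peek at what the constructor above produces under the torch 2.0 code path; the attribute names come from the diff, but inspecting them like this is purely illustrative:

    import torch

    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        x = torch.ones(4, 4)            # routed through the wrapped factory, returns a LazyTensor

    print(type(x).__name__)             # LazyTensor
    print(x._meta_data.is_meta)         # True: shape/dtype live on a meta tensor
    func, args, kwargs = x._factory_method
    print(func.__name__, args)          # roughly: ones (4, 4)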
@ -183,20 +215,6 @@
        self.clean()
        return _convert_cls(self, target)

    def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
        """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

        Args:
            layout (Layout): Distribution layout.

        Returns:
            torch.Tensor: The distributed tensor (self).
        """
        target = self._materialize_data()
        self.clean()
        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
        return _convert_cls(self, local_tensor)

    def clean(self) -> None:
        """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
        delattr(self, "_factory_method")
@ -299,45 +317,80 @@
                    # for early materialized tensor, use its materialized data directly
                    return x._materialized_data if is_change_meta_op else x._materialized_data.data
                t = x if is_inplace else x.clone()
                t._op_buffer.append((func, args, kwargs))
                if func.__name__ not in _NO_RERUN_OPS:
                    t._op_buffer.append((func, args, kwargs))
                meta = x._meta_data if is_change_meta_op else x._meta_data.data
                meta_to_lazy[meta] = t
                return meta
            elif (
                version.parse(torch.__version__) >= version.parse("2.0.0")
                and func.__name__ in _EXPAND_SCALAR_OPS
                and not isinstance(x, torch.Tensor)
            ):
                return _old_tensor_factory(x, device="meta")
            return x

        def wrap(y, i=None):
            if isinstance(y, MetaTensor):
                if y in meta_to_lazy:
                    # inplace op, just return the original lazy tensor
                    return meta_to_lazy[y]
            if isinstance(y, torch.Tensor):
                if y.is_meta:
                    if y in meta_to_lazy:
                        # inplace op, just return the original lazy tensor
                        return meta_to_lazy[y]
                    else:
                        # out of place op, create a new lazy tensor
                        fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
                        fn.__name__ = func.__name__
                        lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
                        return lazy_y
                else:
                    # out of place op, create a new lazy tensor
                    fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
                    fn.__name__ = func.__name__
                    lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
                    return lazy_y
            elif type(y) is Tensor:
                # for early materialized tensor
                return LazyTensor(lambda: None, concrete_data=y)
                # for early materialized tensor
                return LazyTensor(lambda: None, concrete_data=y)
            return y

        cls._pre_op_fn()
        o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        with ConstructorManager.disable():
            # disable creating lazy tensors in inner ops; this is a hack for torch 2.0
            o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        if isinstance(o, (tuple, list)):
            return type(o)(wrap(y, i=i) for i, y in enumerate(o))
        return wrap(o)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        pass  # skip
    def to(self, *args, **kwargs) -> torch.Tensor:
        if self._materialized_data is not None:
            return LazyTensor(lambda: None, concrete_data=self._materialized_data.to(*args, **kwargs))

        device = None

        def replace(x):
            nonlocal device
            if isinstance(x, (str, int, torch.device)) and not isinstance(x, bool):
                device = x
                return torch.device("meta")
            return x

        meta_data = self._meta_data.to(*tree_map(replace, args), **tree_map(replace, kwargs))

        if meta_data is self._meta_data and device == self.device:
            return self

        def factory_fn(t: torch.Tensor, **kw):
            return t.to(*args, **kwargs)

        return LazyTensor(factory_fn, self, meta_data=meta_data, device=device)

    def cpu(self, memory_format: torch.memory_format = torch.preserve_format):
        return self.to(device=torch.device("cpu"), memory_format=memory_format)

    def cuda(self, device=None, non_blocking=False, memory_format: torch.memory_format = torch.preserve_format):
        device = torch.device(device or "cuda")
        return self.to(device=device, non_blocking=non_blocking, memory_format=memory_format)

    def clone(self) -> "LazyTensor":
        def factory_fn():
        def factory_fn(t: torch.Tensor, **kw):
            # if self is materialized, return self
            new_tensor = self.materialize() if type(self) is LazyTensor else self
            return new_tensor.clone()
            return t.clone()

        target = LazyTensor(factory_fn, meta_data=self._meta_data)
        target = LazyTensor(factory_fn, self, meta_data=self._meta_data)

        return target

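To make the buffering above concrete, a hedged sketch of lazy op recording and replay (behaviour follows the handler above; the exact internals may differ):

    import torch

    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        w = torch.randn(4, 4)
        w2 = w * 2              # out-of-place op: recorded so it can be replayed later
        w2.add_(1)              # inplace op: buffered on the same LazyTensor

    real = w2.materialize()     # replays randn -> mul -> add_ on a real tensor
    print(real.shape)           # torch.Size([4, 4])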
@ -353,17 +406,16 @@
        if id(self) in memo:
            return memo[id(self)]

        def factory_fn():
        def factory_fn(t: torch.Tensor, **kw):
            # if self is materialized, return self
            new_tensor = self.materialize() if type(self) is LazyTensor else self
            return _copy_tensor(new_tensor, new_tensor.requires_grad)
            return _copy_tensor(t, t.requires_grad)

        if self._materialized_data is not None:
            # self is early materialized
            copied = _copy_tensor(self._materialized_data, self.requires_grad)
            target = LazyTensor(lambda: None, concrete_data=copied)
        else:
            target = LazyTensor(factory_fn, meta_data=self._meta_data)
            target = LazyTensor(factory_fn, self, meta_data=self._meta_data)

        if isinstance(self, Parameter):
            # hack isinstance check of parameter
@ -394,14 +446,12 @@
        if other is self:
            return

        self._op_buffer.append(other._factory_method)

        def replace(x):
            if x is other:
                return self
            return x

        for func, args, kwargs in other._op_buffer:
        for func, args, kwargs in [other._factory_method, *other._op_buffer]:
            self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))

    def tolist(self) -> list:
@ -455,7 +505,6 @@ class LazyInitContext:
        default_device: Optional[Union[torch.device, str, int]] = None,
    ):
        assert tensor_cls is LazyTensor or tensor_cls is _MyTensor
        self.overrides = {}
        self.tensor_cls = tensor_cls
        self.old_default_device = LazyTensor.default_device
        self.default_device = default_device
@ -478,7 +527,9 @@
            # factory_like functions (e.g. torch.empty_like())
            def wrapper(*args, **kwargs):
                orig_t = args[0]
                return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
                return self.tensor_cls(
                    orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
                )

            return wrapper, target

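A hedged example of the _like fix above ("support _like methods" in the commit message): the wrapper now forwards the source tensor's shape to the underlying factory, so inside the context something like this yields a correctly shaped LazyTensor:

    import torch

    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        x = torch.rand(2, 3)
        y = torch.empty_like(x)          # wrapper effectively calls torch.empty(*x.shape, ...)
        assert tuple(y.shape) == (2, 3)  # shape is preserved; dtype/device follow x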
@ -513,13 +564,13 @@

            return wrapper, target

        self.overrides = {
        overrides = {
            target: wrap_factory_method(getattr(torch, target))
            for target in _NORMAL_FACTORY
            if callable(getattr(torch, target, None))
        }

        self.overrides.update(
        overrides.update(
            {
                target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like"))
                for target in _NORMAL_FACTORY
@ -527,7 +578,7 @@
            }
        )

        self.overrides.update(
        overrides.update(
            {
                target: wrap_legacy_constructor(getattr(torch, target), dtype)
                for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
@ -535,7 +586,7 @@
            }
        )

        self.overrides.update(
        overrides.update(
            {
                target: wrap_no_meta_factory(getattr(torch, target))
                for target in _NO_META_FACTORY
@ -543,14 +594,12 @@
            }
        )

        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, wrapper)
        ConstructorManager.apply(overrides)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.tensor_cls.default_device = self.old_default_device
        LazyInitContext._replaced = False
        for name, (wrapper, orig) in self.overrides.items():
            setattr(torch, name, orig)
        ConstructorManager.clear()

    @staticmethod
    def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
@ -566,23 +615,6 @@

        return _apply_to_lazy_module(module, apply_fn, verbose)

    @staticmethod
    def distribute(
        module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False
    ) -> nn.Module:
        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

        Args:
            module (nn.Module): Target ``nn.Module``
            layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
            verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
        """

        def apply_fn(name: str, p: LazyTensor):
            p.distribute(device_mesh, sharding_spec_dict[name])

        return _apply_to_lazy_module(module, apply_fn, verbose)


def _apply_to_lazy_module(
    module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False
@ -622,20 +654,17 @@

    if verbose:
        non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
        _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}")
        _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}")
        _print_rank_0(
            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%"
        logger = get_dist_logger()
        logger.info(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}", ranks=[0])
        logger.info(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}", ranks=[0])
        logger.info(
            f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%",
            ranks=[0],
        )

    return module


def _print_rank_0(*args, **kwargs):
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)


def _is_int_tuple(args) -> bool:
    if not isinstance(args, tuple):
        return False
@ -11,14 +11,12 @@ def test_torchvision_models_lazy_init(subset, default_device):
    sub_model_zoo = model_zoo.get_sub_registry(subset)
    for name, entry in sub_model_zoo.items():
        # TODO(ver217): lazy init does not support weight norm, skip these models
        if (
            name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base")
            or name.startswith("transformers_llama")
            or name.startswith(("transformers_vit", "transformers_blip2"))
        if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
            ("transformers_vit", "transformers_blip2")
        ):
            continue
        check_lazy_init(entry, verbose=True, default_device=default_device)


if __name__ == "__main__":
    test_torchvision_models_lazy_init("torchvision")
    test_torchvision_models_lazy_init("transformers", "cpu")
@ -0,0 +1,64 @@
import copy

import pytest
import torch
import torch.nn as nn
from lazy_init_utils import SUPPORT_LAZY
from torch.nn import Parameter

from colossalai.lazy import LazyInitContext


@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0")
def test_lazy_ops():
    with LazyInitContext():
        x = torch.rand(2, 3)
        assert tuple(x.shape) == (2, 3)
        assert x.device.type == "cpu"
        assert x.requires_grad is False
        y = x.cuda()
        assert tuple(y.shape) == (2, 3)
        assert y.device.type == "cuda"
        assert y.requires_grad is False
        assert x.cpu() is x
        p = Parameter(torch.empty(2, 3))
        assert tuple(p.shape) == (2, 3)
        assert p.device.type == "cpu"
        assert p.requires_grad is True
        assert isinstance(p, Parameter)
    x.materialize()
    assert tuple(x.shape) == (2, 3)
    assert x.device.type == "cpu"
    assert x.requires_grad is False
    y.materialize()
    assert tuple(y.shape) == (2, 3)
    assert y.device.type == "cuda"
    assert y.requires_grad is False
    p.materialize()
    assert tuple(p.shape) == (2, 3)
    assert p.device.type == "cpu"
    assert p.requires_grad is True
    assert isinstance(p, Parameter)

    with LazyInitContext():
        x = torch.empty(2, 3)
        x.uniform_()
    x.materialize()
    assert tuple(x.shape) == (2, 3)

    with LazyInitContext():
        model = nn.Linear(3, 4)
        model = model.cuda()
    model_copied = copy.deepcopy(model)
    LazyInitContext.materialize(model)
    assert model.weight.device.type == "cuda"
    assert model.bias.device.type == "cuda"
    LazyInitContext.materialize(model_copied)
    assert model_copied.weight.device.type == "cuda"
    assert model_copied.bias.device.type == "cuda"
    assert torch.equal(model.weight, model_copied.weight)
    assert torch.equal(model.bias, model_copied.bias)


if __name__ == "__main__":
    test_lazy_ops()