mirror of https://github.com/hpcaitech/ColossalAI
Browse Source
* [lazy] support _like methods and clamp * [lazy] pass transformers models * [lazy] fix device move and requires grad * [lazy] fix requires grad and refactor api * [lazy] fix requires gradpull/4773/head^2
Hongxin Liu
1 year ago
committed by
GitHub
5 changed files with 273 additions and 94 deletions
@ -0,0 +1,87 @@
|
||||
from contextlib import contextmanager |
||||
from typing import Callable, Dict, Tuple |
||||
|
||||
import torch |
||||
|
||||
__all__ = [ |
||||
"_LEGACY_TENSOR_CONSTRUCTOR", |
||||
"_NO_META_FACTORY", |
||||
"_NORMAL_FACTORY", |
||||
"ConstructorManager", |
||||
] |
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html |
||||
_NORMAL_FACTORY = [ |
||||
"arange", |
||||
"full", |
||||
"empty", |
||||
"linspace", |
||||
"logspace", |
||||
"ones", |
||||
"rand", |
||||
"randn", |
||||
"randint", |
||||
"randperm", |
||||
"zeros", |
||||
"tensor", |
||||
] |
||||
|
||||
# factory function that does not support meta tensor backend |
||||
_NO_META_FACTORY = [ |
||||
"eye", |
||||
] |
||||
|
||||
_LEGACY_TENSOR_CONSTRUCTOR = { |
||||
"FloatTensor": torch.float, |
||||
"DoubleTensor": torch.double, |
||||
"HalfTensor": torch.half, |
||||
"BFloat16Tensor": torch.bfloat16, |
||||
"ByteTensor": torch.uint8, |
||||
"CharTensor": torch.int8, |
||||
"ShortTensor": torch.short, |
||||
"IntTensor": torch.int, |
||||
"LongTensor": torch.long, |
||||
"BoolTensor": torch.bool, |
||||
} |
||||
|
||||
|
||||
class ConstructorManager: |
||||
# function name: (new, old) |
||||
overwrites: Dict[str, Tuple[Callable, Callable]] = {} |
||||
changed: bool = False |
||||
|
||||
@staticmethod |
||||
def apply(overwrites: Dict[Callable, Callable]): |
||||
ConstructorManager.overwrites.clear() |
||||
ConstructorManager.overwrites.update(overwrites) |
||||
ConstructorManager.redo() |
||||
|
||||
@staticmethod |
||||
def undo(): |
||||
assert ConstructorManager.changed, "No constructor change to undo" |
||||
for name, (new, old) in ConstructorManager.overwrites.items(): |
||||
setattr(torch, name, old) |
||||
ConstructorManager.changed = False |
||||
|
||||
@staticmethod |
||||
def redo(): |
||||
assert not ConstructorManager.changed, "Constructor already changed" |
||||
for name, (new, old) in ConstructorManager.overwrites.items(): |
||||
setattr(torch, name, new) |
||||
ConstructorManager.changed = True |
||||
|
||||
@staticmethod |
||||
@contextmanager |
||||
def disable(): |
||||
enabled = ConstructorManager.changed |
||||
if enabled: |
||||
ConstructorManager.undo() |
||||
yield |
||||
if enabled: |
||||
ConstructorManager.redo() |
||||
|
||||
@staticmethod |
||||
def clear(): |
||||
if ConstructorManager.changed: |
||||
ConstructorManager.undo() |
||||
ConstructorManager.overwrites.clear() |
@ -0,0 +1,64 @@
|
||||
import copy |
||||
|
||||
import pytest |
||||
import torch |
||||
import torch.nn as nn |
||||
from lazy_init_utils import SUPPORT_LAZY |
||||
from torch.nn import Parameter |
||||
|
||||
from colossalai.lazy import LazyInitContext |
||||
|
||||
|
||||
@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") |
||||
def test_lazy_ops(): |
||||
with LazyInitContext(): |
||||
x = torch.rand(2, 3) |
||||
assert tuple(x.shape) == (2, 3) |
||||
assert x.device.type == "cpu" |
||||
x.requires_grad is False |
||||
y = x.cuda() |
||||
assert tuple(y.shape) == (2, 3) |
||||
assert y.device.type == "cuda" |
||||
assert y.requires_grad is False |
||||
assert x.cpu() is x |
||||
p = Parameter(torch.empty(2, 3)) |
||||
assert tuple(p.shape) == (2, 3) |
||||
assert p.device.type == "cpu" |
||||
assert p.requires_grad is True |
||||
assert isinstance(p, Parameter) |
||||
x.materialize() |
||||
assert tuple(x.shape) == (2, 3) |
||||
assert x.device.type == "cpu" |
||||
assert x.requires_grad is False |
||||
y.materialize() |
||||
assert tuple(y.shape) == (2, 3) |
||||
assert y.device.type == "cuda" |
||||
assert y.requires_grad is False |
||||
p.materialize() |
||||
assert tuple(p.shape) == (2, 3) |
||||
assert p.device.type == "cpu" |
||||
assert p.requires_grad is True |
||||
assert isinstance(p, Parameter) |
||||
|
||||
with LazyInitContext(): |
||||
x = torch.empty(2, 3) |
||||
x.uniform_() |
||||
x.materialize() |
||||
assert tuple(x.shape) == (2, 3) |
||||
|
||||
with LazyInitContext(): |
||||
model = nn.Linear(3, 4) |
||||
model = model.cuda() |
||||
model_copied = copy.deepcopy(model) |
||||
LazyInitContext.materialize(model) |
||||
assert model.weight.device.type == "cuda" |
||||
assert model.bias.device.type == "cuda" |
||||
LazyInitContext.materialize(model_copied) |
||||
assert model_copied.weight.device.type == "cuda" |
||||
assert model_copied.bias.device.type == "cuda" |
||||
assert torch.equal(model.weight, model_copied.weight) |
||||
assert torch.equal(model.bias, model_copied.bias) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
test_lazy_ops() |
Loading…
Reference in new issue