
[utils] fixed lazy init context (#1867)

Branch: pull/1874/head
Author: Frank Lee, 2 years ago (committed by GitHub)
Commit: e6ec99d389
Changed files:
  1. colossalai/utils/model/lazy_init_context.py (23 changes)
  2. tests/test_fx/test_complete_workflow.py (27 changes)

colossalai/utils/model/lazy_init_context.py (23 changes)

@@ -1,13 +1,14 @@
 #!/usr/bin/env python
 # coding: utf-8
+import inspect
+import types
+from typing import Callable, List
 import torch
 import torch.nn as nn
-from colossalai.tensor import ColoParameter, ColoTensor
-import types
-import inspect
-from typing import List, Callable
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import substitute_init_recursively
@@ -35,14 +36,15 @@ class LazyInitContext():
         assert not model.weight.is_meta and torch.all(model.weight == 0)
     Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
+            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
     tensor_set_value_func = ['zero_', 'fill_']
-    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
         # TODO: hijack the torch constructor functions as well
         self._to_meta = to_meta
         self._intercepted_nn_init_func_cache = {}
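The new default (`to_meta=True`) means modules built inside the context get their parameters on the meta device: shape and dtype are recorded, but no storage is allocated until materialization. A minimal plain-PyTorch sketch of what a meta parameter looks like (this uses the standard `device='meta'` factory argument, not the constructor hijacking that `LazyInitContext` performs):

import torch.nn as nn

layer = nn.Linear(16, 16, device='meta')   # parameters live on the meta device
assert layer.weight.is_meta                # shape/dtype are known, but there are no values yet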
@@ -212,18 +214,19 @@ class LazyInitContext():
                materialized_tensor = torch.empty_like(tensor, device=device)
                # if this tensor is a meta tensor, it must have an init function
                assert tensor in self._intercepted_nn_init_func_cache
-               tensor = materialized_tensor
+           else:
+               materialized_tensor = tensor
            # apply init function
            if tensor in self._intercepted_nn_init_func_cache:
                init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-               init_func(tensor, *args, **kwargs)
+               init_func(materialized_tensor, *args, **kwargs)
            # convert it to ColoTensor or ColoParameter
            if is_param:
-               tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+               tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
            else:
-               tensor = ColoTensor.from_torch_tensor(tensor)
+               tensor = ColoTensor.from_torch_tensor(materialized_tensor)
            # override the original tensor
            with torch.no_grad():
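The fix keeps the original tensor and the materialized tensor separate: a meta tensor gets fresh storage from `torch.empty_like`, while a tensor that already holds data (for example one assigned via `self.weight = torch.zeros(...)`) is reused as-is, and the cached init function is then applied to `materialized_tensor` rather than the possibly-meta original. A condensed sketch of that branch with simplified names (not the full `lazy_init_parameters` logic):

import torch

def materialize(tensor: torch.Tensor, device: str = 'cpu') -> torch.Tensor:
    if tensor.is_meta:
        # meta tensors carry no data, so allocate real storage first
        return torch.empty_like(tensor, device=device)
    # tensor already holds real data (e.g. self.weight = torch.zeros(...)); reuse it
    return tensor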

tests/test_fx/test_complete_workflow.py (27 changes)

@@ -1,16 +1,18 @@
-import colossalai
-import torch
-import torch.nn as nn
+from functools import partial
 import pytest
-import torch.multiprocessing as mp
+import torch
 import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
+import torch.multiprocessing as mp
+import torch.nn as nn
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
 class MLP(torch.nn.Module):
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)
+    for param in model.parameters():
+        assert param.is_meta
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
@@ -46,6 +51,8 @@ def run_workflow(world_size):
     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta
     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 def run_dist(rank, world_size, port):
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
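Taken together, the test now checks that parameters are meta tensors while the model is built under `LazyInitContext` and become real tensors after `ctx.lazy_init_parameters`. A condensed sketch of the workflow this test exercises; the process group construction and the sharding-pass arguments are assumptions, since only the imports are visible in the hunks above:

import torch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
from colossalai.tensor import ProcessGroup
from colossalai.utils.model.lazy_init_context import LazyInitContext

with LazyInitContext() as ctx:
    model = MLP(16)                          # parameters are created as meta tensors

tracer = ColoTracer()
graph = tracer.trace(model)                  # FX tracing works on the meta model
gm = torch.fx.GraphModule(model, graph)

pg = ProcessGroup(tp_degree=world_size)      # assumed construction; not shown in the hunks
annotated_gm = transformer_mlp_pass(gm, pg)  # assumed arguments for the sharding pass
ctx.lazy_init_parameters(annotated_gm)       # materialize (and shard) the real parameters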
