
[utils] fixed lazy init context (#1867)

pull/1874/head
Frank Lee, 2 years ago (committed via GitHub)
commit e6ec99d389
1. colossalai/utils/model/lazy_init_context.py (31 changes)
2. tests/test_fx/test_complete_workflow.py (27 changes)

colossalai/utils/model/lazy_init_context.py

@@ -1,23 +1,24 @@
 #!/usr/bin/env python
 # coding: utf-8
+import inspect
+import types
+from typing import Callable, List
+
 import torch
 import torch.nn as nn
-import types
-import inspect
-from typing import List, Callable
 from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import substitute_init_recursively
 
 
 class LazyInitContext():
     """
     A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
     initialization functions for lazy initialization
 
     Note:
         This API is only experimental and subject to future changes.
 
     Usage:
         with LazyInitContext() as ctx:
@@ -30,19 +31,20 @@ class LazyInitContext():
         # initialize weights
         ctx.lazy_init_parameters(model)
 
         # make sure the weight is not a meta tensor
         # and initialized correctly
         assert not model.weight.is_meta and torch.all(model.weight == 0)
 
     Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
+            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
 
     tensor_set_value_func = ['zero_', 'fill_']
 
-    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
         # TODO: hijack the torch constructor functions as well
         self._to_meta = to_meta
         self._intercepted_nn_init_func_cache = {}
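The hunk above flips the to_meta default to True. A minimal usage sketch of the documented behaviour, assuming the docstring's nn.Linear example (the module and shapes are illustrative, not part of this commit):

# Sketch only: mirrors the docstring's usage block with the new to_meta=True default.
import torch
import torch.nn as nn
from colossalai.utils.model.lazy_init_context import LazyInitContext

with LazyInitContext() as ctx:       # to_meta defaults to True after this commit
    model = nn.Linear(10, 10)        # weight is allocated on the meta device
    model.weight.zero_()             # the value-setting call is recorded for later replay

assert model.weight.is_meta          # nothing materialized yet

ctx.lazy_init_parameters(model)      # materialize and replay the recorded init calls
assert not model.weight.is_meta and torch.all(model.weight == 0)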
@@ -212,18 +214,19 @@ class LazyInitContext():
             materialized_tensor = torch.empty_like(tensor, device=device)
             # if this tensor is a meta tensor, it must have an init function
             assert tensor in self._intercepted_nn_init_func_cache
-            tensor = materialized_tensor
+        else:
+            materialized_tensor = tensor
 
         # apply init function
         if tensor in self._intercepted_nn_init_func_cache:
             init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-            init_func(tensor, *args, **kwargs)
+            init_func(materialized_tensor, *args, **kwargs)
 
         # convert it to ColoTensor or ColoParameter
         if is_param:
-            tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+            tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
         else:
-            tensor = ColoTensor.from_torch_tensor(tensor)
+            tensor = ColoTensor.from_torch_tensor(materialized_tensor)
 
         # override the original tensor
         with torch.no_grad():
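This hunk is the core of the fix: the original meta tensor stays the key into _intercepted_nn_init_func_cache, while the recorded init function and the ColoParameter/ColoTensor conversion now act on the materialized copy. Rebinding tensor to the materialized copy, as the removed line did, makes the later cache lookup miss. A simplified, standalone illustration of that pattern (hypothetical names, plain PyTorch only, not the ColossalAI implementation):

# Sketch only: a meta tensor as the cache key, with the init call replayed on the
# materialized copy.
import torch

init_cache = {}                                     # meta tensor -> recorded init call
meta = torch.empty(4, 4, device='meta')
init_cache[meta] = (torch.nn.init.ones_, (), {})

materialized = torch.empty_like(meta, device='cpu')
# Rebinding `meta = materialized` here (the old behaviour) would make the lookup
# below miss, leaving `materialized` holding uninitialized memory.
if meta in init_cache:
    init_func, args, kwargs = init_cache[meta]
    init_func(materialized, *args, **kwargs)

assert torch.all(materialized == 1)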

tests/test_fx/test_complete_workflow.py

@@ -1,16 +1,18 @@
-import colossalai
-import torch
-import torch.nn as nn
+from functools import partial
+
 import pytest
-import torch.multiprocessing as mp
+import torch
 import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
 
 
 class MLP(torch.nn.Module):
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)
 
+    for param in model.parameters():
+        assert param.is_meta
+
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
@@ -46,6 +51,8 @@ def run_workflow(world_size):
     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta
 
     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
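The shape assertion follows from 1D row-sharding arithmetic: a (16, 16) weight split along dim 0 over world_size ranks leaves a (16 // world_size, 16) shard per rank. A plain PyTorch sketch of just that arithmetic (torch.chunk stands in for the shard pass and is not how the shard is actually produced):

import torch

world_size = 2
full_weight = torch.randn(16, 16)
local_shard = torch.chunk(full_weight, world_size, dim=0)[0]   # row-wise split
assert list(local_shard.shape) == [16 // world_size, 16]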
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 
 
 def run_dist(rank, world_size, port):
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
 
 
 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
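For context, the launch plumbing around run_dist (not shown in this diff) follows the usual spawn pattern implied by the imports at the top of the test. A hedged sketch with a placeholder worker rather than the test's actual run_dist body:

# Sketch only: `example_worker` and `launch` are stand-ins for the test's
# run_dist / test_complete_workflow pair.
from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port


def example_worker(rank, world_size, port):
    # each spawned process receives its rank plus the shared world size and port
    print(f'rank {rank} of {world_size}, rendezvous port {port}')


def launch(world_size):
    run_func = partial(example_worker, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)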
