diff --git a/colossalai/utils/model/lazy_init_context.py b/colossalai/utils/model/lazy_init_context.py
index ed94429d4..cf05f9660 100644
--- a/colossalai/utils/model/lazy_init_context.py
+++ b/colossalai/utils/model/lazy_init_context.py
@@ -1,23 +1,24 @@
 #!/usr/bin/env python
 # coding: utf-8
+import inspect
+import types
+from typing import Callable, List
+
 import torch
 import torch.nn as nn
-from colossalai.tensor import ColoParameter, ColoTensor
-import types
-import inspect
-from typing import List, Callable
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import substitute_init_recursively
 
 
 class LazyInitContext():
     """
-    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor 
+    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
     initialization functions for lazy initialization
 
     Note:
-        This API is only experimental and subject to future changes. 
+        This API is only experimental and subject to future changes.
 
     Usage:
         with LazyInitContext() as ctx:
@@ -30,19 +31,20 @@ class LazyInitContext():
             # initialize weights
             ctx.lazy_init_parameters(model)
 
-            # make sure the weight is not a meta tensor 
+            # make sure the weight is not a meta tensor
             # and initialized correctly
             assert not model.weight.is_meta and torch.all(model.weight == 0)
 
     Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
-        extra_torch_tensor_func (List[str]): extra torch tensor functions related 
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
+            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
+        extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
 
     tensor_set_value_func = ['zero_', 'fill_']
 
-    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
         # TODO: hijack the torch constructor functions as well
         self._to_meta = to_meta
         self._intercepted_nn_init_func_cache = {}
@@ -212,18 +214,19 @@ class LazyInitContext():
             materialized_tensor = torch.empty_like(tensor, device=device)
             # if this tensor is a meta tensor, it must have an init function
             assert tensor in self._intercepted_nn_init_func_cache
-            tensor = materialized_tensor
+        else:
+            materialized_tensor = tensor
 
         # apply init function
         if tensor in self._intercepted_nn_init_func_cache:
             init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-            init_func(tensor, *args, **kwargs)
+            init_func(materialized_tensor, *args, **kwargs)
 
         # convert it to ColoTensor or ColoParameter
         if is_param:
-            tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+            tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
         else:
-            tensor = ColoTensor.from_torch_tensor(tensor)
+            tensor = ColoTensor.from_torch_tensor(materialized_tensor)
 
         # override the original tensor
         with torch.no_grad():
diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py
index b17f2cdb6..1d51e0a52 100644
--- a/tests/test_fx/test_complete_workflow.py
+++ b/tests/test_fx/test_complete_workflow.py
@@ -1,16 +1,18 @@
-import colossalai
-import torch
-import torch.nn as nn
+from functools import partial
+
 import pytest
-import torch.multiprocessing as mp
+import torch
 import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
 
 
 class MLP(torch.nn.Module):
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)
 
+    for param in model.parameters():
+        assert param.is_meta
+
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
@@ -46,6 +51,8 @@ def run_workflow(world_size):
 
     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta
 
     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 
 
 def run_dist(rank, world_size, port):
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
 
 
 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)