From 3ce4463fe6c5bd4c6452b93eabd20dc591852272 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 9 Nov 2022 11:50:33 +0800
Subject: [PATCH] [utils] remove lazy_memory_allocate from ColoInitContext (#1844)

---
 colossalai/utils/model/colo_init_context.py | 24 ++++++-------
 tests/test_tensor/model/test_model.py       | 25 ++++++++------
 tests/test_tensor/model/test_module_spec.py | 30 +++++++++--------
 tests/test_tensor/test_context.py           | 37 +--------------------
 4 files changed, 44 insertions(+), 72 deletions(-)

diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 3824d27f6..95e9d4090 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -1,10 +1,13 @@
-from .utils import InsertPostInitMethodToModuleSubClasses
+from typing import Iterator, Tuple, Union
+
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter
-from colossalai.nn.parallel.layers import register_colo_module, \
-    ColoLinear, ColoEmbedding
 from torch import nn
-from typing import Iterator, Tuple, Union
+
+from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
+from colossalai.tensor import ColoParameter, ColoTensor
+
+from .utils import InsertPostInitMethodToModuleSubClasses
+
 
 
 # find named_params includes replica
@@ -33,17 +36,13 @@ def ColoModulize(module):
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self,
-                 lazy_memory_allocate: bool = False,
-                 device: torch.device = torch.device('cpu'),
-                 dtype: torch.dtype = torch.float):
+    def __init__(self, device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float):
         """
         Args:
-            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
-            device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu').
+            device (torch.device): the device where the initialized parameters reside. Defaults to torch.device('cpu').
+            dtype (torch.dtype): the dtype of the initialized parameters. Defaults to torch.float.
         """
         super().__init__()
-        self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
         self._dtype = dtype
 
@@ -87,7 +86,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             if param in replaced_tensors:
                 colo_param = replaced_tensors[param]
             else:
-                save_torch_payload = True if not self._lazy_memory_allocate else False
                 # detaching tensor is necessary for optimizers.
                 requires_grad = param.requires_grad
                 # TODO(jiaruifang) we initialize a Default PG memory
diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py
index c50393467..361fef8aa 100644
--- a/tests/test_tensor/model/test_model.py
+++ b/tests/test_tensor/model/test_model.py
@@ -1,20 +1,25 @@
-import pytest
 from functools import partial
+
+import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor.colo_parameter import ColoParameter
 
 import colossalai
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.nn.optimizer import ColossalaiOptimizer
-
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import tensor_shard_equal, check_equal, set_seed, \
-    split_param_row_tp1d, split_param_col_tp1d
+from tests.test_tensor.common_utils import (
+    check_equal,
+    set_seed,
+    split_param_col_tp1d,
+    split_param_row_tp1d,
+    tensor_shard_equal,
+)
 
 
 def run_1d_hybrid_tp(model_name):
@@ -169,7 +174,7 @@ def test_colo_optimizer():
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     set_seed(1)
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+    with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
     colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
@@ -266,7 +271,7 @@ def _run_pretrain_load():
     from transformers import BertForMaskedLM
     set_seed(1)
     model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+    with ColoInitContext(device=get_current_device()):
         model = BertForMaskedLM.from_pretrained('bert-base-uncased')
 
     model_pretrained = model_pretrained.cuda()
diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py
index a3eda1d8a..997b416f1 100644
--- a/tests/test_tensor/model/test_module_spec.py
+++ b/tests/test_tensor/model/test_module_spec.py
@@ -1,24 +1,28 @@
 from copy import deepcopy
-import pytest
 from functools import partial
 
+import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
-from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
-from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
-
 
 import colossalai
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-
-from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
-
+from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
+from colossalai.tensor import (
+    ColoTensor,
+    ColoTensorSpec,
+    ComputePattern,
+    ComputeSpec,
+    ProcessGroup,
+    ReplicaSpec,
+    ShardSpec,
+    distspec,
+)
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
 
 
 def run_model_with_spec(mode, model_name):
@@ -134,7 +138,7 @@ def run_linear_with_spec(mode):
 
 
 def run_check_shared_param():
-    from transformers import BertForMaskedLM, BertConfig
+    from transformers import BertConfig, BertForMaskedLM
     hidden_dim = 8
     num_head = 4
     sequence_length = 12
@@ -153,7 +157,7 @@ def run_check_shared_param():
                         num_hidden_layers=num_layer,
                         hidden_dropout_prob=0.,
                         attention_probs_dropout_prob=0.)
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+    with ColoInitContext(device=get_current_device()):
         model = BertForMaskedLM(config)
 
     model = model.cuda()
diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py
index 8171ebfab..0dc9b8c49 100644
--- a/tests/test_tensor/test_context.py
+++ b/tests/test_tensor/test_context.py
@@ -1,40 +1,5 @@
 import pytest
-from colossalai.utils.model.colo_init_context import ColoInitContext
-
 import torch
 
 from colossalai.utils.cuda import get_current_device
-
-
-@pytest.mark.skip
-# FIXME(ver217): support lazy init
-def test_lazy_init():
-    in_dim = 4
-    out_dim = 5
-
-    with ColoInitContext(lazy_memory_allocate=True) as ctx:
-        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
-
-    # lazy_memory_allocate=True, no payload is maintained
-    assert fc.weight._torch_tensor.numel() == 0
-
-    fc.weight.torch_tensor()
-    assert fc.weight._torch_tensor.numel() == in_dim * out_dim
-
-
-@pytest.mark.skip
-def test_device():
-    in_dim = 4
-    out_dim = 5
-
-    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
-        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
-
-    # eval an lazy parameter
-    fc.weight.torch_tensor()
-    assert fc.weight.device == get_current_device()
-
-
-if __name__ == '__main__':
-    test_lazy_init()
-    test_device()
+from colossalai.utils.model.colo_init_context import ColoInitContext
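
After this patch, ColoInitContext takes only device and dtype, so call sites simply drop the removed keyword. A minimal sketch of the updated usage follows; the toy two-layer model and the CPU device are illustrative assumptions only, and the repository's own tests build real models and initialize the distributed environment (e.g. via colossalai.launch) before entering the context.

    import torch

    from colossalai.utils.model.colo_init_context import ColoInitContext

    # Parameters of modules constructed inside the context are converted to
    # ColoParameter and allocated eagerly on the requested device and dtype.
    with ColoInitContext(device=torch.device('cpu'), dtype=torch.float):
        model = torch.nn.Sequential(
            torch.nn.Linear(4, 8),
            torch.nn.ReLU(),
            torch.nn.Linear(8, 5),
        )

    # The keyword was removed from __init__, so passing it now raises a
    # TypeError (unexpected keyword argument 'lazy_memory_allocate').

This mirrors the updated tests above, which now construct the context with device=get_current_device() and nothing else.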