
[utils] remove lazy_memory_allocate from ColoInitContext (#1844)

Branch: pull/1849/head
Author: Jiarui Fang, 2 years ago, committed by GitHub
Parent commit: 3ce4463fe6

Changed files:
  1. colossalai/utils/model/colo_init_context.py (24 changed lines)
  2. tests/test_tensor/model/test_model.py (25 changed lines)
  3. tests/test_tensor/model/test_module_spec.py (30 changed lines)
  4. tests/test_tensor/test_context.py (37 changed lines)

colossalai/utils/model/colo_init_context.py (24 changed lines)

@@ -1,10 +1,13 @@
-from .utils import InsertPostInitMethodToModuleSubClasses
+from typing import Iterator, Tuple, Union
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter
-from colossalai.nn.parallel.layers import register_colo_module, \
-    ColoLinear, ColoEmbedding
 from torch import nn
-from typing import Iterator, Tuple, Union
+from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
+from colossalai.tensor import ColoParameter, ColoTensor
+from .utils import InsertPostInitMethodToModuleSubClasses
 
 # find named_params includes replica
@@ -33,17 +36,13 @@ def ColoModulize(module):
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self,
-                 lazy_memory_allocate: bool = False,
-                 device: torch.device = torch.device('cpu'),
-                 dtype: torch.dtype = torch.float):
+    def __init__(self, device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float):
         """
         Args:
-            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
-            device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu').
+            device (torch.device): the device where initialized parameters are resident. Defaults to torch.device('cpu').
+            dtype (torch.dtype): the dtype of initialized parameters. Defaults to torch.float.
         """
         super().__init__()
-        self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
         self._dtype = dtype
@@ -87,7 +86,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             if param in replaced_tensors:
                 colo_param = replaced_tensors[param]
             else:
-                save_torch_payload = True if not self._lazy_memory_allocate else False
                 # detaching tensor is necessary for optimizers.
                 requires_grad = param.requires_grad
                 # TODO(jiaruifang) we initialize a Default PG memory
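
For callers, the net effect is simply dropping the removed keyword. A minimal usage sketch under the new signature (the Linear module here is a stand-in for illustration, not code from this commit):

import torch
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Parameters created inside the context are turned into ColoParameters and
# materialized eagerly on the requested device; there is no lazy path anymore.
with ColoInitContext(device=get_current_device(), dtype=torch.float):
    fc = torch.nn.Linear(4, 5)    # stand-in module for illustration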

tests/test_tensor/model/test_model.py (25 changed lines)

@@ -1,20 +1,25 @@
-import pytest
 from functools import partial
+import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor.colo_parameter import ColoParameter
 import colossalai
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.tensor import ColoTensor, ProcessGroup
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import tensor_shard_equal, check_equal, set_seed, \
-    split_param_row_tp1d, split_param_col_tp1d
+from tests.test_tensor.common_utils import (
+    check_equal,
+    set_seed,
+    split_param_col_tp1d,
+    split_param_row_tp1d,
+    tensor_shard_equal,
+)
 
 def run_1d_hybrid_tp(model_name):
@@ -169,7 +174,7 @@ def test_colo_optimizer():
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     set_seed(1)
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+    with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
     colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
@@ -266,7 +271,7 @@ def _run_pretrain_load():
     from transformers import BertForMaskedLM
     set_seed(1)
     model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+    with ColoInitContext(device=get_current_device()):
         model = BertForMaskedLM.from_pretrained('bert-base-uncased')
 
     model_pretrained = model_pretrained.cuda()

tests/test_tensor/model/test_module_spec.py (30 changed lines)

@@ -1,24 +1,28 @@
 from copy import deepcopy
-import pytest
 from functools import partial
+import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
-from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
-from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
 import colossalai
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
+from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
+from colossalai.tensor import (
+    ColoTensor,
+    ColoTensorSpec,
+    ComputePattern,
+    ComputeSpec,
+    ProcessGroup,
+    ReplicaSpec,
+    ShardSpec,
+    distspec,
+)
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
 
 def run_model_with_spec(mode, model_name):
@@ -134,7 +138,7 @@ def run_linear_with_spec(mode):
 def run_check_shared_param():
-    from transformers import BertForMaskedLM, BertConfig
+    from transformers import BertConfig, BertForMaskedLM
     hidden_dim = 8
     num_head = 4
     sequence_length = 12
@@ -153,7 +157,7 @@ def run_check_shared_param():
                         num_hidden_layers=num_layer,
                         hidden_dropout_prob=0.,
                         attention_probs_dropout_prob=0.)
-    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
+    with ColoInitContext(device=get_current_device()):
         model = BertForMaskedLM(config)
     model = model.cuda()

tests/test_tensor/test_context.py (37 changed lines)

@@ -1,40 +1,5 @@
 import pytest
-from colossalai.utils.model.colo_init_context import ColoInitContext
 import torch
 from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
-
-
-@pytest.mark.skip
-# FIXME(ver217): support lazy init
-def test_lazy_init():
-    in_dim = 4
-    out_dim = 5
-
-    with ColoInitContext(lazy_memory_allocate=True) as ctx:
-        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
-
-    # lazy_memory_allocate=True, no payload is maintained
-    assert fc.weight._torch_tensor.numel() == 0
-    fc.weight.torch_tensor()
-    assert fc.weight._torch_tensor.numel() == in_dim * out_dim
-
-
-@pytest.mark.skip
-def test_device():
-    in_dim = 4
-    out_dim = 5
-
-    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
-        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
-
-    # eval an lazy parameter
-    fc.weight.torch_tensor()
-    assert fc.weight.device == get_current_device()
-
-
-if __name__ == '__main__':
-    test_lazy_init()
-    test_device()
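
With the lazy-allocation path gone, the removed checks can only be expressed against eager initialization. A rough replacement sketch for the device check, assuming parameters land directly on the device passed to ColoInitContext (this check is hypothetical and not part of this commit):

import torch
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

def check_eager_device():
    # Hypothetical check: with eager allocation the weight should already
    # reside on the requested device once the module's __init__ has run.
    with ColoInitContext(device=get_current_device()):
        fc = torch.nn.Linear(4, 5, bias=True)
    assert fc.weight.device == get_current_device()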
