[Gemini] add bert for MemtracerWrapper unit tests (#1982)

pull/1984/head
Jiarui Fang 2022-11-18 14:58:28 +08:00 committed by GitHub
parent e481489aa6
commit 3712ac7f90
4 changed files with 32 additions and 11 deletions

View File

@@ -28,6 +28,9 @@ class _Wrapper():
     def show_mem_stats(self):
         self._ophook_list[0].show_mem_stats()
 
+    def named_buffers(self):
+        return self._model.named_buffers()
+
 
 def MemtracerWrapper(model):
     ophook_list = [MemTracerOpHook()]
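Because _Wrapper is a plain class rather than an nn.Module (see the hunk header above), nn.Module helpers such as named_buffers() are not available on the wrapper and must be forwarded to the wrapped model by hand. A minimal sketch of that delegation pattern, using a stand-in class rather than the real MemtracerWrapper:

import torch.nn as nn


class _ToyWrapper:
    """Stand-in for the diff's _Wrapper: a plain object holding a model,
    so nn.Module helpers such as named_buffers() must be forwarded manually."""

    def __init__(self, model: nn.Module):
        self._model = model

    def __call__(self, *args, **kwargs):
        return self._model(*args, **kwargs)

    def named_buffers(self):
        # Same delegation as in the diff: expose the inner model's buffers.
        return self._model.named_buffers()


if __name__ == "__main__":
    wrapped = _ToyWrapper(nn.BatchNorm1d(4))   # BatchNorm registers running_mean/var buffers
    for name, buff in wrapped.named_buffers():
        print(name, tuple(buff.shape), buff.device)

This is exactly what the new test relies on when it iterates model.named_buffers() to move buffers to CUDA.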

View File

@@ -7,6 +7,7 @@ from colossalai.gemini.ophooks import BaseOpHook
 class MemTracerOpHook(BaseOpHook):
     """
     TODO() what if parameters are sharded by multiple submodules.
+    register buff on its father node
     """
 
     def __init__(self):
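The new docstring line records the limitation that motivates the test workaround below: the op hook observes leaf modules, but a buffer can be registered on a parent ("father") module, as HuggingFace's BertEmbeddings does with position_ids. A small sketch with toy modules, not ColossalAI code, showing why a per-leaf scan never sees such a buffer:

import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    """Mimics BertEmbeddings: leaf submodules plus a buffer on the parent itself."""

    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(10, 8)   # leaf module
        self.layer_norm = nn.LayerNorm(8)            # leaf module
        # Buffer registered on the *parent*, not on any leaf submodule.
        self.register_buffer("position_ids", torch.arange(16).unsqueeze(0))

    def forward(self, ids):
        return self.layer_norm(self.word_embeddings(ids))


if __name__ == "__main__":
    m = ToyEmbeddings()

    # What a per-leaf-module hook can enumerate: no position_ids anywhere.
    for name, sub in m.named_modules():
        if len(list(sub.children())) == 0:           # leaf modules only
            print(name, [b for b, _ in sub.named_buffers(recurse=False)])

    # The buffer only shows up on the non-leaf parent.
    print("parent buffers:", [b for b, _ in m.named_buffers(recurse=False)])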

View File

@@ -1,8 +1,13 @@
+from functools import partial
+
+import pytest
 import torch
-import torch.nn as nn
+import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.gemini.memory_tracer import MemtracerWrapper
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -17,16 +22,20 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
     model.backward(loss)
 
 
-def test_tracer():
-    # reset the manager, in case that there exists memory information left
-    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module']
+def run_tracer(rank, world_size, port, grad_check=True):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
 
         # init model on cpu
-        model = MemtracerWrapper(model_builder())
+        # TODO() memtrace hook can not handle buff registered on a non-leaf module (for example the BertEmbedding).
+        # a simple method is that always puts buff on cuda and viewed them as non-model data.
+        model = MemtracerWrapper(model_builder(grad_check))
+        for n, buff in model.named_buffers():
+            buff.data = buff.data.cuda()
 
         for i, (data, label) in enumerate(train_dataloader):
             if i > 1:
                 break
@@ -38,5 +47,13 @@ def test_tracer():
         # model._ophook_list[0].print_non_model_data()
 
 
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1])
+@rerun_if_address_is_in_use()
+def test_tracer(world_size):
+    run_func = partial(run_tracer, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_tracer()
+    test_tracer(1)

View File

@@ -3,21 +3,21 @@
 from functools import partial
 
-import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+from common import CONFIG, check_grads_padding, run_fwd_bwd
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import CONFIG, check_grads_padding, run_fwd_bwd
 
 
 @parameterize("enable_autocast", [True])