[polish] use GLOBAL_MODEL_DATA_TRACER (#417)

Jiarui Fang 2022-03-15 11:29:46 +08:00 committed by GitHub
parent 23ba3fc450
commit 56bb412e72
8 changed files with 25 additions and 25 deletions


@@ -5,7 +5,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Optional
@@ -25,7 +25,6 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -33,7 +32,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
-                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload

         if self._memstarts_collector:
@@ -50,7 +49,6 @@ class ZeroHook(BaseOpHook):
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
-        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -58,7 +56,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
-                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
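Why the ZeroHook hunks above are pure polish rather than a behavior change: ModelDataTracer is declared with SingletonMeta (see the tracer file further down), so the ModelDataTracer() constructed inside each hook call already returned the one shared instance; the commit simply binds that instance to a module-level name once. A minimal, self-contained sketch of the pattern, assuming a conventional singleton metaclass (the library's actual SingletonMeta may differ in detail):

# Sketch: a singleton metaclass caches one instance per class, so calling the
# class repeatedly and importing a module-level alias are interchangeable.
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class ModelDataTracer(metaclass=SingletonMeta):
    def __init__(self):
        self._cuda_usage = 0


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()

# Every later call returns the same object, so the rename cannot change behavior.
assert ModelDataTracer() is GLOBAL_MODEL_DATA_TRACER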


@@ -1,5 +1,5 @@
 import torch
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


 def col_move_to_cpu(t: torch.Tensor):
@@ -7,7 +7,7 @@ def col_move_to_cpu(t: torch.Tensor):
     if t.device.type == 'cpu':
         return

-    ModelDataTracer().delete_tensor(t)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
    t.data = t.data.cpu()
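The ordering in col_move_to_cpu matters: the tensor is removed from the tracer's CUDA accounting before its storage moves to host memory. A hedged sketch with a stand-in tracer (the byte-counting here is an assumption about what cuda_usage tracks, not the library's code):

import torch

class _StandInTracer:
    """Illustrative only; stands in for GLOBAL_MODEL_DATA_TRACER."""

    def __init__(self):
        self.cuda_usage = 0

    def add_tensor(self, t: torch.Tensor):
        self.cuda_usage += t.numel() * t.element_size()

    def delete_tensor(self, t: torch.Tensor):
        self.cuda_usage -= t.numel() * t.element_size()

TRACER = _StandInTracer()

def col_move_to_cpu(t: torch.Tensor):
    if t.device.type == 'cpu':
        return
    TRACER.delete_tensor(t)  # drop it from CUDA accounting while it is still on GPU
    t.data = t.data.cpu()    # then move the payload; t keeps its Python identity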


@@ -1,4 +1,4 @@
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from .async_memtracer import get_cuda_memory_used
 from colossalai.utils import get_current_device
@@ -54,7 +54,7 @@ class MemStatsCollector:
         if self._start_flag:
             sampling_cnt = self._sampling_cnter.sampling_cnt
             assert sampling_cnt == len(self._overall_cuda)
-            self._model_data_cuda.append(ModelDataTracer().cuda_usage)
+            self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))

         self._sampling_cnter.advance()
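Each sample appends to _model_data_cuda and _overall_cuda in lockstep, so the two series stay index-aligned and non-model memory can be read off per sample point. A simplified sketch of that invariant (torch.cuda.memory_allocated stands in for get_cuda_memory_used; that substitution is an assumption):

import torch

model_data_cuda = []  # bytes the tracer attributes to model data
overall_cuda = []     # total CUDA memory in use at the same instant

def sample_memstats(tracer_cuda_usage: int):
    # the real collector asserts the counters line up before appending
    assert len(model_data_cuda) == len(overall_cuda)
    model_data_cuda.append(tracer_cuda_usage)
    overall_cuda.append(torch.cuda.memory_allocated())

# non-model memory (activations, buffers, allocator slack) at sample i is then
#   overall_cuda[i] - model_data_cuda[i]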


@@ -5,10 +5,9 @@ import torch
 class ModelDataTracer(metaclass=SingletonMeta):
     """
-    A singleton to trace model data usage during runtime.
-    We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
-    including allocation, releasing and moving.
+    A tracer singleton to trace model data usage during runtime.
+    The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
+    To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
     NOTE() now the class only trace cuda memory usage
     """
@@ -32,3 +31,6 @@ class ModelDataTracer(metaclass=SingletonMeta):
     @property
     def cuda_usage(self):
         return self._cuda_usage
+
+
+GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
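After this file change, call sites import the shared instance instead of re-invoking the constructor. A hedged usage sketch mirroring the test further down (requires a CUDA device; that cuda_usage reports bytes is an assumption):

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

t = torch.randn(10).cuda()                  # allocate some model data on GPU
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)      # register it with the shared tracer
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)  # cumulative CUDA usage grows
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)   # unregister when the tensor is released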


@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 import torch
@@ -14,7 +14,7 @@ def test_mem_collector():
     collector.sample_memstats()

     m_a = torch.randn(10).cuda()
-    ModelDataTracer().add_tensor(m_a)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
     b = torch.randn(10).cuda()

     # sampling at time 1
@@ -26,7 +26,7 @@ def test_mem_collector():
     collector.sample_memstats()

     collector.finish_collection()
-    collector.reset()
+    collector.reset_sampling_cnter()
     # do nothing after collection, just advance sampling cnter
     collector.sample_memstats()
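The hunk above also renames reset() to reset_sampling_cnter(), making it explicit that only the sampling counter rewinds while the collected statistics are kept. A sketch of the lifecycle this test walks through (sample_memstats, finish_collection, and reset_sampling_cnter come from the diff; start_collection and the method internals are assumptions):

from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()
collector.start_collection()       # assumed counterpart of finish_collection()
collector.sample_memstats()        # records one (model data, overall) CUDA pair
collector.finish_collection()
collector.reset_sampling_cnter()   # rewind the counter; keep the recorded series
collector.sample_memstats()        # after collection: just advances the counter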


@@ -3,7 +3,7 @@ import functools
 import torch
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 # Inserts _post_init_method at the end of init method
@@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-                ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
             if param.col_attr.grad and self.shard_grad:
                 self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-                ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
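ZeroInitContext registers the payload of each sharded tensor, i.e. only the shard this rank actually holds, so cuda_usage reflects per-rank model data rather than full parameter sizes. A self-contained sketch of that accounting idea (the byte-counting tracer is illustrative, not the library's code):

import torch

class _StandInTracer:
    def __init__(self):
        self.cuda_usage = 0

    def add_tensor(self, t: torch.Tensor):
        self.cuda_usage += t.numel() * t.element_size()

tracer = _StandInTracer()

full = torch.randn(1024)    # a parameter before sharding
shard = full.chunk(4)[0]    # pretend world_size == 4; this rank keeps one shard
tracer.add_tensor(shard)    # only the local payload is counted

assert tracer.cuda_usage == shard.numel() * shard.element_size()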


@@ -26,15 +26,15 @@ def run_naive_amp():
     test_models = ['repeated_computed_layers', 'nested_model']
     for test_name in test_models:
         get_component_func = non_distributed_component_funcs.get_callable(test_name)
-        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+        model_builder, train_dataloader, _, optim_class, _ = get_component_func()

         # create model
         amp_model = model_builder(checkpoint=True).cuda()
         torch_model = copy.deepcopy(amp_model)

         # create optimizer
-        amp_optimizer = optim_builder(amp_model)
-        torch_optimizer = optim_builder(torch_model)
+        amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
+        torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)

         # inject naive amp
         amp_config = dict(initial_scale=1)
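This test change reflects a reworked contract in the test-component registry: the tuple used to carry a builder that took a model and returned an optimizer, and now carries the optimizer class itself, with the caller supplying the parameters and learning rate. An illustrative sketch of the two contracts (Adam is a placeholder; the registry's real optimizer class is not shown in this diff):

import torch

model = torch.nn.Linear(4, 4)

# old contract: the registry returned a ready-made builder
def optim_builder(m: torch.nn.Module) -> torch.optim.Optimizer:
    return torch.optim.Adam(m.parameters(), lr=1e-3)

opt_old = optim_builder(model)

# new contract: the registry returns the class; the caller picks params and lr
optim_class = torch.optim.Adam
opt_new = optim_class(model.parameters(), lr=1e-3)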


@@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


 def run_dist(rank, world_size, port, init_device, shard_strategy):
@@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert param.col_attr.data.payload.device.type == init_device.type, \
             f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

-    print(f'cuda usgae {ModelDataTracer().cuda_usage}')
+    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
     print(f'numel {model_numel_tensor}')
     if init_device.type == 'cuda':
-        assert (ModelDataTracer().cuda_usage > 0)
+        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)


 @pytest.mark.dist