mirror of https://github.com/hpcaitech/ColossalAI
[polish] use GLOBAL_MODEL_DATA_TRACER (#417)
parent
23ba3fc450
commit
56bb412e72
|
@ -5,7 +5,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
|
|||
|
||||
from ._base_ophook import BaseOpHook
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from typing import Optional
|
||||
|
||||
|
||||
|
@ -25,7 +25,6 @@ class ZeroHook(BaseOpHook):
|
|||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
tensor_list = []
|
||||
global_model_data_tracer = ModelDataTracer()
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
|
@ -33,7 +32,7 @@ class ZeroHook(BaseOpHook):
|
|||
for param in module.parameters():
|
||||
if param.col_attr.data.device != self.computing_device:
|
||||
param.col_attr.data.to(self.computing_device)
|
||||
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
|
||||
param.data = param.col_attr.data.payload
|
||||
|
||||
if self._memstarts_collector:
|
||||
|
@ -50,7 +49,6 @@ class ZeroHook(BaseOpHook):
|
|||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
tensor_list = []
|
||||
global_model_data_tracer = ModelDataTracer()
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
|
@ -58,7 +56,7 @@ class ZeroHook(BaseOpHook):
|
|||
for param in module.parameters():
|
||||
if param.col_attr.data.device != self.computing_device:
|
||||
param.col_attr.data.to(self.computing_device)
|
||||
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
|
||||
param.data = param.col_attr.data.payload
|
||||
# Store local accumulated grad shard
|
||||
if param.grad is not None:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import torch
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
|
||||
|
||||
def col_move_to_cpu(t: torch.Tensor):
|
||||
|
@ -7,7 +7,7 @@ def col_move_to_cpu(t: torch.Tensor):
|
|||
if t.device.type == 'cpu':
|
||||
return
|
||||
|
||||
ModelDataTracer().delete_tensor(t)
|
||||
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
|
||||
t.data = t.data.cpu()
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from .async_memtracer import get_cuda_memory_used
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
@ -54,7 +54,7 @@ class MemStatsCollector:
|
|||
if self._start_flag:
|
||||
sampling_cnt = self._sampling_cnter.sampling_cnt
|
||||
assert sampling_cnt == len(self._overall_cuda)
|
||||
self._model_data_cuda.append(ModelDataTracer().cuda_usage)
|
||||
self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
|
||||
self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
|
||||
self._sampling_cnter.advance()
|
||||
|
||||
|
|
|
@ -5,10 +5,9 @@ import torch
|
|||
|
||||
class ModelDataTracer(metaclass=SingletonMeta):
|
||||
"""
|
||||
A singleton to trace model data usage during runtime.
|
||||
We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
|
||||
including allocation, releasing and moving.
|
||||
|
||||
A tracer singleton to trace model data usage during runtime.
|
||||
The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
|
||||
To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.
|
||||
NOTE() now the class only trace cuda memory usage
|
||||
"""
|
||||
|
||||
|
@ -32,3 +31,6 @@ class ModelDataTracer(metaclass=SingletonMeta):
|
|||
@property
|
||||
def cuda_usage(self):
|
||||
return self._cuda_usage
|
||||
|
||||
|
||||
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ def test_mem_collector():
|
|||
collector.sample_memstats()
|
||||
|
||||
m_a = torch.randn(10).cuda()
|
||||
ModelDataTracer().add_tensor(m_a)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
|
||||
b = torch.randn(10).cuda()
|
||||
|
||||
# sampling at time 1
|
||||
|
@ -26,7 +26,7 @@ def test_mem_collector():
|
|||
collector.sample_memstats()
|
||||
|
||||
collector.finish_collection()
|
||||
collector.reset()
|
||||
collector.reset_sampling_cnter()
|
||||
|
||||
# do nothing after collection, just advance sampling cnter
|
||||
collector.sample_memstats()
|
||||
|
|
|
@ -3,7 +3,7 @@ import functools
|
|||
import torch
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param import ShardedParamV2
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
|
||||
# Inserts _post_init_method at the end of init method
|
||||
|
||||
|
@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
|
||||
ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
|
||||
if param.col_attr.grad and self.shard_grad:
|
||||
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
|
||||
ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
||||
|
|
|
@ -26,15 +26,15 @@ def run_naive_amp():
|
|||
test_models = ['repeated_computed_layers', 'nested_model']
|
||||
for test_name in test_models:
|
||||
get_component_func = non_distributed_component_funcs.get_callable(test_name)
|
||||
model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
|
||||
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
|
||||
|
||||
# create model
|
||||
amp_model = model_builder(checkpoint=True).cuda()
|
||||
torch_model = copy.deepcopy(amp_model)
|
||||
|
||||
# create optimizer
|
||||
amp_optimizer = optim_builder(amp_model)
|
||||
torch_optimizer = optim_builder(torch_model)
|
||||
amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
|
||||
torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)
|
||||
|
||||
# inject naive amp
|
||||
amp_config = dict(initial_scale=1)
|
||||
|
|
|
@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
|
|||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
from common import CONFIG
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, init_device, shard_strategy):
|
||||
|
@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
|
|||
assert param.col_attr.data.payload.device.type == init_device.type, \
|
||||
f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
|
||||
|
||||
print(f'cuda usgae {ModelDataTracer().cuda_usage}')
|
||||
print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
|
||||
print(f'numel {model_numel_tensor}')
|
||||
if init_device.type == 'cuda':
|
||||
assert (ModelDataTracer().cuda_usage > 0)
|
||||
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
Loading…
Reference in New Issue