mirror of https://github.com/hpcaitech/ColossalAI
[polish] use GLOBAL_MODEL_DATA_TRACER (#417)
parent 23ba3fc450
commit 56bb412e72

@@ -5,7 +5,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Optional
@@ -25,7 +25,6 @@ class ZeroHook(BaseOpHook):

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -33,7 +32,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
-                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload

         if self._memstarts_collector:
@@ -50,7 +49,6 @@ class ZeroHook(BaseOpHook):

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
-        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -58,7 +56,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
-                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
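
The two ZeroHook hunks above are symmetric: both pre_fwd_exec and pre_bwd_exec stop constructing a local `ModelDataTracer()` and charge payloads to the module-level GLOBAL_MODEL_DATA_TRACER instead. A condensed sketch of the shared pre-op pattern, with the col_attr/payload semantics inferred from the diff (col_attr.data wraps a sharded tensor whose payload is the materialized torch.Tensor):

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


def _pre_op_sketch(module: torch.nn.Module, computing_device: torch.device):
    # Collect every parameter's sharded-data wrapper; the real hook hands
    # this list to a shard strategy to gather the full tensors.
    tensor_list = [param.col_attr.data for param in module.parameters()]
    for param in module.parameters():
        if param.col_attr.data.device != computing_device:
            # Moving model data onto the computing device: register the
            # payload so the global tracer's cuda_usage stays accurate.
            param.col_attr.data.to(computing_device)
            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
        param.data = param.col_attr.data.payload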

@@ -1,5 +1,5 @@
 import torch
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


 def col_move_to_cpu(t: torch.Tensor):
@@ -7,7 +7,7 @@ def col_move_to_cpu(t: torch.Tensor):
     if t.device.type == 'cpu':
         return

-    ModelDataTracer().delete_tensor(t)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
     t.data = t.data.cpu()
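
The accounting is symmetric: add_tensor when model data lands on CUDA (the hooks above), delete_tensor just before it leaves (col_move_to_cpu here). A minimal usage sketch, assuming a CUDA device and only the APIs visible in this commit:

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

t = torch.randn(1024, device='cuda')
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)      # start counting t as CUDA model data
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)  # now reflects t's footprint

GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)   # stop counting it ...
t.data = t.data.cpu()                       # ... before the storage moves off-device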

@@ -1,4 +1,4 @@
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from .async_memtracer import get_cuda_memory_used
 from colossalai.utils import get_current_device
@@ -54,7 +54,7 @@ class MemStatsCollector:
         if self._start_flag:
             sampling_cnt = self._sampling_cnter.sampling_cnt
             assert sampling_cnt == len(self._overall_cuda)
-            self._model_data_cuda.append(ModelDataTracer().cuda_usage)
+            self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()

@@ -5,10 +5,9 @@ import torch

 class ModelDataTracer(metaclass=SingletonMeta):
     """
-    A singleton to trace model data usage during runtime.
-    We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation,
-    including allocation, releasing and moving.
+    A tracer singleton to trace model data usage during runtime.
+    The tracer is designed to trace the memory layout change during model-data tensors allocation, releasing, and moving.
+    To achieve this goal, the developers have to call `ModelDataTracer` in the corresponding code explicitly.

     NOTE() now the class only trace cuda memory usage
     """
@@ -32,3 +31,6 @@ class ModelDataTracer(metaclass=SingletonMeta):
     @property
     def cuda_usage(self):
         return self._cuda_usage
+
+
+GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
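
The new module-level instance changes spelling, not behavior: since ModelDataTracer uses SingletonMeta, every ModelDataTracer() call already returned one shared object, and GLOBAL_MODEL_DATA_TRACER simply names it once at import time. A minimal sketch of how such a metaclass typically works (an assumption for illustration, not ColossalAI's actual SingletonMeta source):

class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct the instance on first call; return the cached one after.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class ModelDataTracer(metaclass=SingletonMeta):
    def __init__(self):
        self._cuda_usage = 0


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
assert GLOBAL_MODEL_DATA_TRACER is ModelDataTracer()  # same object either way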

@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 import torch
@@ -14,7 +14,7 @@ def test_mem_collector():
     collector.sample_memstats()

     m_a = torch.randn(10).cuda()
-    ModelDataTracer().add_tensor(m_a)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
     b = torch.randn(10).cuda()

     # sampling at time 1
@@ -26,7 +26,7 @@ def test_mem_collector():
     collector.sample_memstats()

     collector.finish_collection()
-    collector.reset()
+    collector.reset_sampling_cnter()

     # do nothing after collection, just advance sampling cnter
     collector.sample_memstats()
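
Pieced together from the test hunks, the collector's lifecycle looks roughly like the sketch below. start_collection() is an assumption inferred from finish_collection() and the _start_flag guard in sample_memstats(); the diff does not show it:

import torch
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

collector = MemStatsCollector()
collector.start_collection()                # assumed counterpart of finish_collection()

m_a = torch.randn(10).cuda()
GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)    # model data: counted by the tracer
b = torch.randn(10).cuda()                  # non-model data: only in overall usage
collector.sample_memstats()                 # snapshots model-data and overall CUDA memory

collector.finish_collection()
collector.reset_sampling_cnter()
collector.sample_memstats()                 # after collection: only advances the counter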

@@ -3,7 +3,7 @@ import functools
 import torch
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

 # Inserts _post_init_method at the end of init method
@@ -153,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-            ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
         if param.col_attr.grad and self.shard_grad:
             self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-            ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)

@@ -26,15 +26,15 @@ def run_naive_amp():
     test_models = ['repeated_computed_layers', 'nested_model']
     for test_name in test_models:
         get_component_func = non_distributed_component_funcs.get_callable(test_name)
-        model_builder, train_dataloader, _, optim_builder, _ = get_component_func()
+        model_builder, train_dataloader, _, optim_class, _ = get_component_func()

         # create model
         amp_model = model_builder(checkpoint=True).cuda()
         torch_model = copy.deepcopy(amp_model)

         # create optimizer
-        amp_optimizer = optim_builder(amp_model)
-        torch_optimizer = optim_builder(torch_model)
+        amp_optimizer = optim_class(amp_model.parameters(), lr=1e-3)
+        torch_optimizer = optim_class(torch_model.parameters(), lr=1e-3)

         # inject naive amp
         amp_config = dict(initial_scale=1)
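
The naive-amp test change is independent of the tracer rename: the component registry now hands back a plain optimizer class instead of a pre-bound builder, so the call site supplies the parameters and learning rate itself. A hedged sketch of the two conventions (names illustrative, not the registry's actual internals):

import torch

model = torch.nn.Linear(4, 4)

# Old convention (illustrative): a builder closed over the hyperparameters
# and took the model object itself.
def optim_builder(m: torch.nn.Module) -> torch.optim.Optimizer:
    return torch.optim.Adam(m.parameters(), lr=1e-3)

old_style = optim_builder(model)

# New convention: the registry yields the optimizer class; the test passes
# model.parameters() and lr explicitly, as in the hunk above.
optim_class = torch.optim.Adam
new_style = optim_class(model.parameters(), lr=1e-3)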

@@ -14,7 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG
-from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


 def run_dist(rank, world_size, port, init_device, shard_strategy):
@@ -37,10 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert param.col_attr.data.payload.device.type == init_device.type, \
             f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'

-    print(f'cuda usgae {ModelDataTracer().cuda_usage}')
+    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
     print(f'numel {model_numel_tensor}')
     if init_device.type == 'cuda':
-        assert (ModelDataTracer().cuda_usage > 0)
+        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)


 @pytest.mark.dist