mirror of https://github.com/hpcaitech/ColossalAI
[zero] refactor model data tracing (#537)
parent a590ed0ba3
commit 705f56107c
@@ -5,8 +5,6 @@ import torch.distributed as dist
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils import BaseShardStrategy

from ._base_ophook import BaseOpHook

@@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.utils import get_current_device

import torch
from typing import Tuple


class SamplingCounter:

@@ -40,6 +41,20 @@ class MemStatsCollector:

        self._start_flag = False

    @property
    def overall_cuda(self):
        return self._overall_cuda

    @property
    def model_data_cuda(self):
        return self._model_data_cuda

    @property
    def non_model_data_cuda(self):
        """Non-model data stats, i.e. overall CUDA usage minus model data usage.
        """
        return [(v1 - v2) for v1, v2 in zip(self._overall_cuda, self._model_data_cuda)]

    def start_collection(self):
        self._start_flag = True

@@ -58,7 +73,7 @@ class MemStatsCollector:
        self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
        self._sampling_cnter.advance()

    def fetch_memstats(self) -> (int, int):
    def fetch_memstats(self) -> Tuple[int, int]:
        """
        Returns the CUDA usage of model data and the overall CUDA usage.
        """

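To make the sampling flow in these hunks concrete, here is a minimal, self-contained sketch of the pattern: overall and model-data CUDA usage are recorded as parallel lists, and non-model data is derived as their elementwise difference. The class name, method names, and the fixed sample values below are illustrative only, not the ColossalAI API.

from typing import List, Tuple


class MiniMemStatsCollector:
    """Illustrative stand-in for MemStatsCollector (not the real API)."""

    def __init__(self) -> None:
        self._model_data_cuda: List[int] = []
        self._overall_cuda: List[int] = []

    def sample(self, model_data_bytes: int, overall_bytes: int) -> None:
        # In the real collector these values come from the model data tracer
        # and from colo_cuda_memory_used(); here they are passed in directly.
        self._model_data_cuda.append(model_data_bytes)
        self._overall_cuda.append(overall_bytes)

    @property
    def non_model_data_cuda(self) -> List[int]:
        # Non-model data = overall usage minus model data, per sampling step.
        return [o - m for o, m in zip(self._overall_cuda, self._model_data_cuda)]

    def fetch_memstats(self) -> Tuple[int, int]:
        # Latest (model data, overall) sample pair.
        return self._model_data_cuda[-1], self._overall_cuda[-1]


if __name__ == "__main__":
    collector = MiniMemStatsCollector()
    collector.sample(model_data_bytes=24, overall_bytes=40)
    collector.sample(model_data_bytes=24, overall_bytes=64)
    print(collector.non_model_data_cuda)   # [16, 40]
    print(collector.fetch_memstats())      # (24, 64)
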
@@ -1,7 +1,8 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch
from typing import Union
from typing import Union, Tuple, Optional
from colossalai.logging import DistributedLogger


def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
@@ -12,60 +13,78 @@ def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    return target.numel() * target.element_size()


def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Trace the model memory usage.
    Args:
        model (torch.nn.Module): a torch model

    Returns:
        Tuple[int, int]: cuda memory usage in Byte, cpu memory usage in Byte
    """

    def _get_tensor_mem_use(t: Optional[torch.Tensor]):
        if t is None:
            return 0, 0
        assert isinstance(t, torch.Tensor)
        _cpu_mem_usage, _cuda_mem_usage = 0, 0
        if t.device.type == 'cpu':
            _cpu_mem_usage += t.numel() * t.element_size()
        elif t.device.type == 'cuda':
            _cuda_mem_usage += t.numel() * t.element_size()
        return _cuda_mem_usage, _cpu_mem_usage

    cuda_mem_usage = 0
    cpu_mem_usage = 0
    for param in model.parameters():
        if hasattr(param, 'col_attr'):
            param_cuda, param_cpu = param.col_attr.get_memory_usage()
            cuda_mem_usage += param_cuda
            cpu_mem_usage += param_cpu
        else:
            t_cuda, t_cpu = _get_tensor_mem_use(param.data)
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu
            t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu

    return cuda_mem_usage, cpu_mem_usage


class ModelDataTracer(metaclass=SingletonMeta):
    """
    A tracer singleton to trace model data usage during runtime.
    The tracer is designed to trace memory layout changes while model-data tensors are allocated, released, and moved.
    To achieve this goal, developers have to call `ModelDataTracer` explicitly in the corresponding code.
    NOTE: the class currently only traces CUDA memory usage.
    You have to register a model on the singleton first.
    """

    def __init__(self) -> None:
        self._cuda_usage = 0
        self._cpu_usage = 0
        self._start_flag = False
        self._logger = DistributedLogger("ModelDataTracer")
        self._model = None

    def start(self) -> None:
        self._start_flag = True
    def _get_mem_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the registered model.
        Returns:
            Tuple[int, int]: cuda, cpu mem usage
        """
        if self._model is None:
            self._logger.warning("The global ModelDataTracer is in use, but no model is registered on it.")
            return 0, 0
        return col_model_data_mem_usage(self._model)

    def close(self) -> None:
        self._start_flag = False

    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
        if not self._start_flag:
            return
        t_payload = t.payload if isinstance(t, ShardedTensor) else t
        mem_use = _col_tensor_mem_usage(t_payload)
        if t_payload.device.type == 'cuda':
            self._cuda_usage += mem_use
        elif t_payload.device.type == 'cpu':
            self._cpu_usage += mem_use
        else:
            raise TypeError

    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
        if not self._start_flag:
            return
        t_payload = t.payload if isinstance(t, ShardedTensor) else t
        mem_use = _col_tensor_mem_usage(t_payload)
        if t_payload.device.type == 'cuda':
            self._cuda_usage -= mem_use
        elif t_payload.device.type == 'cpu':
            self._cpu_usage -= mem_use
        else:
            raise TypeError

    def clear(self) -> None:
        self._cuda_usage = 0
        self._cpu_usage = 0
    def register_model(self, model) -> None:
        self._model = model

    @property
    def cpu_usage(self):
        return self._cpu_usage
        _, cpu_usage = self._get_mem_usage()
        return cpu_usage

    @property
    def cuda_usage(self):
        return self._cuda_usage
        cuda_usage, _ = self._get_mem_usage()
        return cuda_usage


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()

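A short usage sketch of the helpers above. The toy `nn.Linear` module and the byte counts (fp32 parameters) are illustrative; the property reads assume the refactored code path in this hunk, where `cpu_usage`/`cuda_usage` query the registered model via `_get_mem_usage`, and where `_get_tensor_mem_use` returns `(0, 0)` for a missing grad.

import torch.nn as nn

from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER, col_model_data_mem_usage

# Toy CPU-resident module: 4 * 4 fp32 weights -> 64 bytes of model data.
model = nn.Linear(4, 4, bias=False)

# Direct query, independent of the singleton.
cuda_bytes, cpu_bytes = col_model_data_mem_usage(model)
print(cuda_bytes, cpu_bytes)                 # expected: 0 64

# Singleton query, after registering the model (refactored code path).
GLOBAL_MODEL_DATA_TRACER.register_model(model)
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)   # expected: 0
print(GLOBAL_MODEL_DATA_TRACER.cpu_usage)    # expected: 64
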
@@ -1,5 +1,4 @@
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
import torch


@@ -14,7 +13,6 @@ def test_mem_collector():
    collector.sample_memstats()

    m_a = torch.randn(10).cuda()
    GLOBAL_MODEL_DATA_TRACER.add_tensor(m_a)
    b = torch.randn(10).cuda()

    # sampling at time 1

@@ -35,8 +33,7 @@ def test_mem_collector():
    cuda_use, overall_use = collector.fetch_memstats()
    print(cuda_use, overall_use)

    print(collector._model_data_cuda)
    print(collector._overall_cuda)
    print(collector.overall_cuda)


if __name__ == '__main__':

@@ -1,7 +1,6 @@
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

from typing import Union

@@ -52,9 +51,7 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
        tgt_t_payload = tgt_t.data
    tgt_dev = tgt_t_payload.device

    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
    tgt_t_payload.copy_(src_t_payload)
    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)

    # remove payload of src_t
    if isinstance(src_t, ShardedTensor):

@@ -84,11 +81,7 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
    # deal with torch.device('cpu') and torch.device('cpu:0')
    if t_payload.device.type == target_device.type:
        return
    if use_tracer:
        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
    t_payload.data = t_payload.data.to(target_device)
    if use_tracer:
        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)


def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:

@@ -111,9 +104,7 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
        return

    # TODO: optimize the tensor move with a non-blocking copy
    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
    t_payload.data = t_payload.data.cpu()
    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)


def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:

@@ -129,5 +120,4 @@ def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device
    t_payload = t.payload if isinstance(t, ShardedTensor) else t

    ret = t_payload.to(target_device)
    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
    return ret

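The move helpers above bracket every device change with a delete/add pair on the tracer: forget the old placement, move the payload, record the new placement. A minimal sketch of that bracketing pattern, assuming a simple counter-based tracer; `SimpleTracer` and `move_inline` are illustrative stand-ins, not ColossalAI APIs.

import torch


class SimpleTracer:
    """Illustrative counter-based tracer (stand-in, not the ColossalAI class)."""

    def __init__(self) -> None:
        self.cuda_usage = 0
        self.cpu_usage = 0

    def _mem(self, t: torch.Tensor) -> int:
        return t.numel() * t.element_size()

    def add_tensor(self, t: torch.Tensor) -> None:
        if t.device.type == 'cuda':
            self.cuda_usage += self._mem(t)
        else:
            self.cpu_usage += self._mem(t)

    def delete_tensor(self, t: torch.Tensor) -> None:
        if t.device.type == 'cuda':
            self.cuda_usage -= self._mem(t)
        else:
            self.cpu_usage -= self._mem(t)


def move_inline(tracer: SimpleTracer, t: torch.Tensor, target_device: torch.device) -> None:
    # Bracket the device change: forget the old placement, record the new one.
    tracer.delete_tensor(t)
    t.data = t.data.to(target_device)
    tracer.add_tensor(t)


tracer = SimpleTracer()
x = torch.ones(2, 3)                         # 6 fp32 elements -> 24 bytes on CPU
tracer.add_tensor(x)
if torch.cuda.is_available():
    move_inline(tracer, x, torch.device('cuda'))
print(tracer.cpu_usage, tracer.cuda_usage)   # 0 24 on a GPU machine, 24 0 otherwise
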
@@ -4,8 +4,6 @@ from typing import Optional
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16

@@ -130,7 +128,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        The callback function when entering the context.
        """
        self.logger = get_dist_logger("ZeroInitContext")
        GLOBAL_MODEL_DATA_TRACER.start()

    def _post_context_exec(self):
        """The callback function when exiting the context.

@@ -141,12 +138,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            param.col_attr.remove_torch_payload()

        del self.initialized_param_list
        GLOBAL_MODEL_DATA_TRACER.close()
        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
        self.logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])

    def _post_init_method(self, module: torch.nn.Module):
        """

@@ -176,9 +167,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

            self.initialized_param_list.append(param)

            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor)

            if self.shard_param:
                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)

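The exit log above converts byte counts to MB and parameter counts to millions using the same divide-by-1e6 convention. A small standalone sketch of that arithmetic; the `nn.Linear` module and its size are illustrative, not taken from the commit.

import torch.nn as nn

model = nn.Linear(1024, 1024)                     # illustrative module
numel = sum(p.numel() for p in model.parameters())
data_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

model_data_mem_MB = data_bytes / 1e6              # same convention as the log: MB = bytes / 1e6
print(f"Model Data Memory {model_data_mem_MB:.2f} MB")
print(f"Model Number Parameter {numel / 1e6:.2f} M")
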
@@ -7,7 +7,6 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten

from .tensor_shard_strategy import TensorShardStrategy
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


class BucketTensorShardStrategy(TensorShardStrategy):

@@ -18,8 +17,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
    """

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        for t in tensor_list:
            GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)

        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:

@@ -50,6 +47,3 @@ class BucketTensorShardStrategy(TensorShardStrategy):
            t.reset_payload(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]

        for t in tensor_list:
            GLOBAL_MODEL_DATA_TRACER.add_tensor(t)

@@ -7,7 +7,6 @@ from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, col
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER


class TensorShardStrategy(BaseShardStrategy):

@@ -36,10 +35,8 @@ class TensorShardStrategy(BaseShardStrategy):
        if t.payload.device.type == 'cuda':
            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                f" but current cuda device is {get_current_device()}"
        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
        sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
        t.reset_payload(sharded_payload)
        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
        t.is_sharded = True

    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):

@@ -56,10 +53,8 @@ class TensorShardStrategy(BaseShardStrategy):
            else:
                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))

        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
        dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
        gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.reset_payload(gathered_payload)
        colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
        t.is_sharded = False

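To make the gather step above concrete: each rank holds an equal-sized (possibly padded) flat shard, the shards are concatenated in rank order, and the result is narrowed to the original element count and reshaped to the original shape. A single-process sketch with plain torch, no process group; the sizes are illustrative.

import torch

origin_shape = (3, 5)                     # 15 elements
origin_numel = 3 * 5
world_size = 4
shard_numel = (origin_numel + world_size - 1) // world_size   # 4, last shard padded

full = torch.arange(origin_numel, dtype=torch.float32)
padded = torch.cat([full, torch.zeros(shard_numel * world_size - origin_numel)])
shards = list(padded.chunk(world_size))   # what each rank would hold after sharding

# What dist.all_gather + torch.cat achieves across ranks, done locally here:
gathered = torch.cat(shards)
restored = torch.narrow(gathered, 0, 0, origin_numel).reshape(origin_shape)
assert torch.equal(restored, full.reshape(origin_shape))
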
@@ -11,6 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy

@@ -83,6 +84,7 @@ class ShardedModelV2(nn.Module):
        # Init Memory Statistics Collector
        self._use_memory_tracer = use_memory_tracer
        if self._use_memory_tracer:
            GLOBAL_MODEL_DATA_TRACER.register_model(self)
            self._memstats_collector = MemStatsCollector()
        else:
            self._memstats_collector = None

@@ -147,14 +149,16 @@ class ShardedModelV2(nn.Module):
    def _update_memstats(self):
        if self._iter_cnter == 0 and self._memstats_collector:
            self._memstats_collector.finish_collection()
            self.logger.info(f'model data cuda, {self._memstats_collector.model_data_cuda}')
            self.logger.info(f'non-model data cuda, {self._memstats_collector.non_model_data_cuda}')

        if self._memstats_collector:
            self._memstats_collector.reset_sampling_cnter()
            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
            # the way to calculate margin space is based on the assumption that
            # model data is fixed in cuda during training.
            # cuda margin space can be used to store OS.
            self._cuda_margin_space = colo_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)

            self._cuda_margin_space = colo_cuda_memory_capacity() - max(self._memstats_collector.overall_cuda)
        self._iter_cnter += 1

    @torch.no_grad()

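The margin-space formula in this hunk is total device capacity minus the peak overall usage observed during forward/backward; under the stated assumption that model data stays resident on the GPU, the remainder can host optimizer states (OS). A hedged sketch using plain torch.cuda calls with made-up sample values.

import torch

# Peak overall CUDA usage sampled during one fwd/bwd iteration (illustrative values, bytes).
overall_cuda_samples = [1.2e9, 2.4e9, 3.1e9, 2.0e9]

if torch.cuda.is_available():
    capacity = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
else:
    capacity = 8e9          # pretend an 8 GB device when no GPU is present

# cuda margin space = capacity - max fwd/bwd usage; usable for optimizer states
# as long as model data stays resident on the GPU.
cuda_margin_space = capacity - max(overall_cuda_samples)
print(f"cuda margin space: {cuda_margin_space / 1e9:.2f} GB")
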
@@ -9,7 +9,6 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from torch import Tensor

@@ -218,9 +217,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        # We must set grad to None,
        # because we will judge whether local grad accumulation
        # is enabled by whether grad is None.
        for group in self.param_groups:
            for p in group['params']:
                GLOBAL_MODEL_DATA_TRACER.delete_tensor(p.grad)
        self.optim.zero_grad(set_to_none=True)

    def sync_grad(self):

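A minimal illustration of the comment above: with `zero_grad(set_to_none=True)` the `.grad` attributes become `None`, so a later `grad is None` check can serve as the "no local accumulation yet" signal. Plain PyTorch, no ColossalAI involved; the toy model is illustrative.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(1, 4)).sum()
loss.backward()
assert all(p.grad is not None for p in model.parameters())   # grads accumulated

optim.zero_grad(set_to_none=True)                             # grads dropped, not zeroed
assert all(p.grad is None for p in model.parameters())        # "no local accumulation" flag
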
@@ -1,4 +1,3 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception

@@ -13,22 +12,15 @@ import torch.multiprocessing as mp
def run_tensor_move(rank):
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
    GLOBAL_MODEL_DATA_TRACER.start()

    src_t = torch.ones(2, 3).cuda()
    GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24)
    tgt_t = torch.zeros(2, 3)

    colo_model_data_tensor_move(src_t, tgt_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
    assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"

    src_t = torch.ones(2, 3)
    tgt_t = torch.zeros(2, 3).cuda().half()
    colo_model_data_tensor_move(src_t, tgt_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
    # the src_t has been removed
    assert (src_t.numel() == 0)
    assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"

@@ -36,15 +28,11 @@ def run_tensor_move(rank):
    src_t = ShardedTensor(torch.ones(2, 3))
    tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half())
    colo_model_data_tensor_move(src_t, tgt_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
    assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"

    assert (tgt_t.device.type == 'cuda')
    colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
    assert (tgt_t.device.type == 'cpu')
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"

    GLOBAL_MODEL_DATA_TRACER.close()


@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")

@@ -1,52 +1,28 @@
import pytest

from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.zero.sharded_param import ShardedTensor

import colossalai

import torch

from functools import partial
import torch.multiprocessing as mp
from colossalai.utils import free_port


def _run_colo_model_data_tensor_move_inline():
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
    GLOBAL_MODEL_DATA_TRACER.start()

    for t in [torch.randn(2, 3), ShardedTensor(torch.randn(2, 3))]:
        GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
        colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
        assert t.device == torch.device(f"cuda:{get_current_device()}")
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
        GLOBAL_MODEL_DATA_TRACER.clear()

    GLOBAL_MODEL_DATA_TRACER.close()


def _run_colo_model_data_tensor_move():
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
    GLOBAL_MODEL_DATA_TRACER.start()

    for t in [(torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device())),
              (ShardedTensor(torch.ones(2, 3)), ShardedTensor(torch.zeros(2, 3).cuda(get_current_device())))]:
        cpu_t, cuda_t = t
        GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_t)
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
        colo_model_data_tensor_move(cpu_t, cuda_t)
        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
        GLOBAL_MODEL_DATA_TRACER.clear()

    GLOBAL_MODEL_DATA_TRACER.close()


def run_dist(rank, world_size, port):

@@ -10,19 +10,21 @@ import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_tracer.model_data_memtracer import col_model_data_mem_usage
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from tests.components_to_test.registry import non_distributed_component_funcs

from colossalai.logging import get_dist_logger
from common import CONFIG


@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_zero_init")

    for get_components_func in non_distributed_component_funcs:
        model_builder, _, _, _, _ = get_components_func()
        model_numel_tensor = torch.zeros(1, dtype=torch.int)

@@ -32,6 +34,8 @@ def run_model_test(init_device_type, shard_strategy_class):
            init_device = torch.device("cpu")
        else:
            continue

        model_numel_tensor = torch.zeros(1, dtype=torch.int)
        with ZeroInitContext(convert_fp16=True,
                             target_device=init_device,
                             shard_strategy=shard_strategy_class(),

@@ -46,11 +50,13 @@ def run_model_test(init_device_type, shard_strategy_class):
            assert param.col_attr.sharded_data_tensor.is_sharded
            assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
        if init_device.type == 'cuda':
            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
        else:
            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
        GLOBAL_MODEL_DATA_TRACER.clear()

        cuda_mem_use, cpu_mem_use = col_model_data_mem_usage(model)
        model_data_cuda_mem_MB = cuda_mem_use / 1e6
        logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
        logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
        logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])


def run_dist(rank, world_size, port):