[memory] add model data tensor moving api (#503)

pull/506/head
Jiarui Fang 3 years ago committed by GitHub
parent 65ad47c35c
commit 0035b7be07

@@ -8,7 +8,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
sync_model_param)
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory import report_memory_usage
from .memory_utils.memory_monitor import report_memory_usage
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector

@@ -1,9 +0,0 @@
import torch
from colossalai.utils import get_current_device
def col_cuda_memory_capacity():
"""
Get cuda memory capacity of the current cuda.
"""
return torch.cuda.get_device_properties(get_current_device()).total_memory

@@ -1,19 +0,0 @@
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
def col_move_to_cpu(t: torch.Tensor):
    assert isinstance(t, torch.Tensor)
    if t.device.type == 'cpu':
        return
    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
    t.data = t.data.cpu()
def col_modeldata_allocate(device: torch.device) -> torch.Tensor:
    pass
def col_modeldata_release(t: torch.Tensor):
    pass

@@ -1,11 +0,0 @@
from colossalai.zero.sharded_param import ShardedTensor
from typing import Union
import torch
def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()

@@ -1,6 +1,15 @@
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch
from typing import Union
def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
    if isinstance(t, ShardedTensor):
        target = t.payload
    else:
        target = t
    return target.numel() * target.element_size()
class ModelDataTracer(metaclass=SingletonMeta):
@@ -16,12 +25,12 @@ class ModelDataTracer(metaclass=SingletonMeta):
    def add_tensor(self, t: torch.Tensor):
        assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
        mem_use = col_tensor_mem_usage(t)
        mem_use = _col_tensor_mem_usage(t)
        self._cuda_usage += mem_use
    def delete_tensor(self, t: torch.Tensor):
        assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
        mem_use = col_tensor_mem_usage(t)
        mem_use = _col_tensor_mem_usage(t)
        self._cuda_usage -= mem_use
    @property

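The tracer above only does bookkeeping: `add_tensor` and `delete_tensor` adjust a byte counter computed as `numel() * element_size()`, exposed through the `cuda_usage` property used elsewhere in this PR. A minimal sketch of how it is driven (the 24-byte figure assumes a 2x3 fp32 tensor, as in the test added further below):

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

t = torch.ones(2, 3, device='cuda')        # 6 elements * 4 bytes (fp32) = 24 bytes
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)     # model data counter grows by 24
assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)  # counter shrinks back to 0
assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0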
@@ -63,5 +63,5 @@ def report_memory_usage(message, logger=None, report_cpu=False):
    logger.info(full_log)
    # get the peak memory to report correct data, so reset the counter for the next call
    if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()

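`report_memory_usage` is re-exported from `colossalai.utils` (see the first hunk), so the peak-reset above is what callers observe between two calls; a hedged usage sketch, assuming a CUDA device:

from colossalai.utils import report_memory_usage

report_memory_usage("after forward")   # logs current and peak CUDA memory, then resets the peak counter
report_memory_usage("after backward")  # the peak reported here covers only the work since the previous call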
@@ -0,0 +1,59 @@
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Union
def colo_cuda_memory_capacity():
"""
Get cuda memory capacity of the current cuda.
"""
return torch.cuda.get_device_properties(get_current_device()).total_memory
def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor],
                                tgt_t: Union[ShardedTensor, torch.Tensor]) -> None:
    """
    A Colossal-AI API for moving a model data tensor.
    The source and target tensors can reside on either CPU or GPU.
    NOTE: the payload of the source tensor is emptied after this function returns.
    The function records the amount of model data moved between CPU and GPU.
    Args:
        src_t (Union[ShardedTensor, torch.Tensor]): source tensor
        tgt_t (Union[ShardedTensor, torch.Tensor]): target tensor
    """
    if isinstance(src_t, ShardedTensor):
        src_t_payload = src_t.payload
    else:
        src_t_payload = src_t.data
    src_dev = src_t_payload.device
    if isinstance(tgt_t, ShardedTensor):
        tgt_t_payload = tgt_t.payload
    else:
        tgt_t_payload = tgt_t.data
    tgt_dev = tgt_t_payload.device
    if src_dev.type == 'cuda' and tgt_dev.type == 'cpu':
        GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
    elif src_dev.type == 'cpu' and tgt_dev.type == 'cuda':
        GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
    tgt_t_payload.copy_(src_t_payload)
    # remove the payload of src_t
    if isinstance(src_t, ShardedTensor):
        src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
    else:
        src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_move_to_cpu(t: torch.Tensor):
    assert isinstance(t, torch.Tensor)
    if t.device.type == 'cpu':
        return
    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
    t.data = t.data.cpu()

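A minimal usage sketch of the new helpers above, assuming a CUDA device is available; the byte counts follow the fp32 sizes used in the test added by this PR:

import torch
from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_tensor_move,
                                                  colo_model_data_move_to_cpu)

total_cuda_bytes = colo_cuda_memory_capacity()  # capacity of the current device, in bytes

src = torch.ones(2, 3)                          # model data on CPU (6 * 4 = 24 bytes)
tgt = torch.zeros(2, 3, device='cuda')          # pre-allocated CUDA destination
colo_model_data_tensor_move(src, tgt)           # CPU -> CUDA: the tracer gains 24 bytes
assert src.numel() == 0                         # the source payload has been emptied

colo_model_data_move_to_cpu(tgt)                # evict back to CPU in place: the tracer drops 24 bytes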
@@ -11,8 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.utils.commons.memory import col_cuda_memory_capacity
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -152,7 +151,7 @@ class ShardedModelV2(nn.Module):
            # the way to calculate margin space is based on the assumption that
            # model data is fixed in cuda during training.
            # cuda margin space can be used to store OS.
            self._cuda_margin_space = col_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
            self._cuda_margin_space = colo_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
        self._iter_cnter += 1
@@ -201,7 +200,7 @@
            else:
                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
            if p.col_attr.offload_grad:
                col_move_to_cpu(grad)
                colo_model_data_move_to_cpu(grad)
            if p.col_attr.fp32_grad is not None:
                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))

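The margin-space comment in the hunk above is simple arithmetic; a toy illustration with made-up numbers (the real inputs are `colo_cuda_memory_capacity()` and the peak usage recorded by the runtime memory tracer):

capacity = 16 * 1024**3                            # e.g. a 16 GB device (assumed value)
peak_overall_cuda = 12 * 1024**3                   # peak model + non-model CUDA usage seen by the tracer (assumed)
cuda_margin_space = capacity - peak_overall_cuda   # 4 GB left over to hold optimizer states (OS)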
@@ -25,8 +25,18 @@ class OptimState(Enum):
class ShardedOptimizerV2(ColossalaiOptimizer):
    """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO) stage 3.
    You must use `ShardedOptimizerV2` with `ShardedModelV2`.
    """A wrapper for an optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).
    By default, the ZeRO stage 3 optimizer offloads the Optimizer States (OS) to CPU.
    We apply the Device-aware Operator Placement technique for OS placement from the following paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    GPU margin space is the GPU memory left over after the peak non-model data usage, which is
    measured by a runtime memory tracer, is subtracted from the overall GPU memory.
    We place as many OS chunks in the margin space as possible.
    The size of the margin space can be controlled by `gpu_margin_mem_ratio`.
    If it is set to 0.0, the optimizer behaves the same as the classical ZeRO optimizer.
    NOTE: You must use `ShardedOptimizerV2` with `ShardedModelV2`.
    Args:
        sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the

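How `gpu_margin_mem_ratio` bounds OS placement, as a hedged sketch consistent with the docstring above (the exact bookkeeping inside `ShardedOptimizerV2` may differ):

cuda_margin_space = 4 * 1024**3       # measured as in the previous sketch (assumed value)
gpu_margin_mem_ratio = 0.8            # user-chosen knob; 0.0 keeps all OS on CPU, i.e. classical ZeRO offload
os_budget_on_gpu = cuda_margin_space * gpu_margin_mem_ratio
# OS chunks are placed on the GPU until this budget is used up; the remainder stays on CPU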
@@ -0,0 +1,49 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
from colossalai.utils import free_port
from colossalai.zero.sharded_param import ShardedTensor
import colossalai
import torch
from functools import partial
import torch.multiprocessing as mp
import pytest
def run_tensor_move(rank):
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
    src_t = torch.ones(2, 3).cuda()
    GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24)
    tgt_t = torch.zeros(2, 3)
    colo_model_data_tensor_move(src_t, tgt_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
    assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t)} vs. 6.0"
    src_t = torch.ones(2, 3)
    tgt_t = torch.zeros(2, 3).cuda().half()
    colo_model_data_tensor_move(src_t, tgt_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
    # the src_t has been removed
    assert (src_t.numel() == 0)
    assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t)} vs. 6.0"
    src_t = ShardedTensor(torch.ones(2, 3))
    tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half())
    colo_model_data_tensor_move(src_t, tgt_t)
    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
    assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
def test_tensor_move():
    mp.spawn(run_tensor_move, nprocs=1)
if __name__ == '__main__':
    test_tensor_move()