mirror of https://github.com/hpcaitech/ColossalAI
[memory] add model data tensor moving api (#503)
parent
65ad47c35c
commit
0035b7be07
@ -1,9 +0,0 @@
|
|||||||
import torch
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
|
|
||||||
|
|
||||||
def col_cuda_memory_capacity():
|
|
||||||
"""
|
|
||||||
Get cuda memory capacity of the current cuda.
|
|
||||||
"""
|
|
||||||
return torch.cuda.get_device_properties(get_current_device()).total_memory
|
|
@ -1,19 +0,0 @@
|
|||||||
import torch
|
|
||||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
|
||||||
|
|
||||||
|
|
||||||
def col_move_to_cpu(t: torch.Tensor):
|
|
||||||
assert isinstance(t, torch.Tensor)
|
|
||||||
if t.device.type == 'cpu':
|
|
||||||
return
|
|
||||||
|
|
||||||
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
|
|
||||||
t.data = t.data.cpu()
|
|
||||||
|
|
||||||
|
|
||||||
def col_modeldata_allocate(device: torch.device) -> torch.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def col_modeldata_release(t: torch.Tensor):
|
|
||||||
pass
|
|
@ -1,11 +0,0 @@
|
|||||||
from colossalai.zero.sharded_param import ShardedTensor
|
|
||||||
from typing import Union
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
|
|
||||||
if isinstance(t, ShardedTensor):
|
|
||||||
target = t.payload
|
|
||||||
else:
|
|
||||||
target = t
|
|
||||||
return target.numel() * target.element_size()
|
|
@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||||
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
def colo_cuda_memory_capacity():
|
||||||
|
"""
|
||||||
|
Get cuda memory capacity of the current cuda.
|
||||||
|
"""
|
||||||
|
return torch.cuda.get_device_properties(get_current_device()).total_memory
|
||||||
|
|
||||||
|
|
||||||
|
def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t: Union[ShardedTensor,
|
||||||
|
torch.Tensor]) -> None:
|
||||||
|
"""
|
||||||
|
A colossal API for model data tensor move.
|
||||||
|
The src and target tensors could be resident on both CPU and GPU.
|
||||||
|
|
||||||
|
NOTE() The source tensor payload will be removed after this function.
|
||||||
|
|
||||||
|
The function will record the communication volume between CPU and GPU.
|
||||||
|
Args:
|
||||||
|
t_src (Union[ShardedTensor, torch.Tensor]): source tensor
|
||||||
|
tgt_t (Union[ShardedTensor, torch.Tensor]): target tensor
|
||||||
|
"""
|
||||||
|
if isinstance(src_t, ShardedTensor):
|
||||||
|
src_t_payload = src_t.payload
|
||||||
|
else:
|
||||||
|
src_t_payload = src_t.data
|
||||||
|
src_dev = src_t_payload.device
|
||||||
|
if isinstance(tgt_t, ShardedTensor):
|
||||||
|
tgt_t_payload = tgt_t.payload
|
||||||
|
else:
|
||||||
|
tgt_t_payload = tgt_t.data
|
||||||
|
tgt_dev = tgt_t_payload.device
|
||||||
|
|
||||||
|
if src_dev.type == 'cuda' and tgt_dev.type == 'cpu':
|
||||||
|
GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
|
||||||
|
elif src_dev.type == 'cpu' and tgt_dev.type == 'cuda':
|
||||||
|
GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
|
||||||
|
tgt_t_payload.copy_(src_t_payload)
|
||||||
|
|
||||||
|
# remove payload of src_t
|
||||||
|
if isinstance(src_t, ShardedTensor):
|
||||||
|
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
|
||||||
|
else:
|
||||||
|
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def colo_model_data_move_to_cpu(t: torch.Tensor):
|
||||||
|
assert isinstance(t, torch.Tensor)
|
||||||
|
if t.device.type == 'cpu':
|
||||||
|
return
|
||||||
|
|
||||||
|
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
|
||||||
|
t.data = t.data.cpu()
|
@ -0,0 +1,49 @@
|
|||||||
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
|
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
|
||||||
|
from colossalai.zero.sharded_param import ShardedTensor
|
||||||
|
import colossalai
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def run_tensor_move(rank):
|
||||||
|
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||||
|
|
||||||
|
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
|
||||||
|
|
||||||
|
src_t = torch.ones(2, 3).cuda()
|
||||||
|
GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
|
||||||
|
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24)
|
||||||
|
tgt_t = torch.zeros(2, 3)
|
||||||
|
|
||||||
|
colo_model_data_tensor_move(src_t, tgt_t)
|
||||||
|
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
|
||||||
|
assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
|
||||||
|
|
||||||
|
src_t = torch.ones(2, 3)
|
||||||
|
tgt_t = torch.zeros(2, 3).cuda().half()
|
||||||
|
colo_model_data_tensor_move(src_t, tgt_t)
|
||||||
|
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
|
||||||
|
# the src_t has been removed
|
||||||
|
assert (src_t.numel() == 0)
|
||||||
|
assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
|
||||||
|
|
||||||
|
src_t = ShardedTensor(torch.ones(2, 3))
|
||||||
|
tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half())
|
||||||
|
colo_model_data_tensor_move(src_t, tgt_t)
|
||||||
|
assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
|
||||||
|
assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_move():
|
||||||
|
mp.spawn(run_tensor_move, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_tensor_move()
|
Loading…
Reference in new issue