mirror of https://github.com/hpcaitech/ColossalAI
[zero] refactor model data tracing (#522)
parent 3601b2bad0
commit 8d8c5407c0
@@ -22,6 +22,7 @@ class ModelDataTracer(metaclass=SingletonMeta):
 
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
         self._start_flag = False
 
     def start(self) -> None:
@@ -30,22 +31,33 @@ class ModelDataTracer(metaclass=SingletonMeta):
     def close(self) -> None:
         self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor) -> None:
+    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage += mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage += mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage += mem_use
+        else:
+            raise TypeError
 
-    def delete_tensor(self, t: torch.Tensor) -> None:
+    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage -= mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage -= mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage -= mem_use
+        else:
+            raise TypeError
 
     def clear(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
 
     @property
     def cpu_usage(self):
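The tracer now books CPU and CUDA model data separately and accepts a ShardedTensor directly. A minimal usage sketch, assuming the import path shown above and fp32 tensors (4 bytes per element):

# Minimal sketch of the refactored tracer, assuming the import path above.
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

GLOBAL_MODEL_DATA_TRACER.start()            # bookkeeping is a no-op before start()
t = torch.randn(2, 3)                       # fp32 CPU tensor: 2 * 3 * 4 = 24 bytes
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)      # lands in cpu_usage, since t is on CPU
assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 24
assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)   # books the bytes back out
GLOBAL_MODEL_DATA_TRACER.close()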
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union, Optional
+from typing import Union
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 
@@ -52,11 +52,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     tgt_t_payload = tgt_t.data
     tgt_dev = tgt_t_payload.device
 
-    if src_dev.type == 'cuda' and tgt_dev.type == 'cpu':
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
-    elif src_dev.type == 'cpu' and tgt_dev.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
     tgt_t_payload.copy_(src_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
 
     # remove payload of src_t
     if isinstance(src_t, ShardedTensor):
@@ -65,7 +63,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
         src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
 
 
-def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
+                                       target_device: torch.device,
+                                       use_tracer: bool = True) -> None:
     """
     move a tensor to the target_device
     Args:
@@ -84,13 +84,11 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], ta
     # deal with torch.device('cpu') and torch.device('cpu:0)
     if t_payload.device.type == target_device.type:
         return
-    if target_device.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
-    elif target_device.type == 'cpu':
+    if use_tracer:
         GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.to(target_device)
+    if use_tracer:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
 
 
 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
@@ -115,3 +113,4 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     # TODO() optimize the tensor moving with non-blocking
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
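colo_model_data_tensor_move_inline gains a use_tracer flag, so internal callers (such as the gather path further down) can move a payload without double-booking it. A hedged sketch of both modes, patterned on the test added in this commit:

# Hedged sketch of the new use_tracer flag, patterned on the test below.
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

GLOBAL_MODEL_DATA_TRACER.start()
t = torch.randn(2, 3)
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)                # 24 bytes of CPU model data
colo_model_data_tensor_move_inline(t, torch.device('cuda:0'))
# default use_tracer=True: the 24 bytes migrate from cpu_usage to cuda_usage
colo_model_data_tensor_move_inline(t, torch.device('cpu'), use_tracer=False)
# use_tracer=False: the payload moves, but the tracer's books stay untouched
GLOBAL_MODEL_DATA_TRACER.close()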
@@ -177,13 +177,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         self.initialized_param_list.append(param)
 
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor)
+
         if self.shard_param:
             self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
-        # if param.col_attr.grad and self.shard_grad:
-        #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
-        #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
         # We must cast them
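After this hunk, ZeroInitContext registers every initialized parameter's sharded_data_tensor with the tracer, and add_tensor() itself decides whether the bytes land in cpu_usage or cuda_usage. A hedged sketch of the context in use; the constructor arguments shown are assumptions inferred from the fields this diff touches, not a confirmed signature:

# Hedged usage sketch; argument names are assumptions, not a confirmed signature.
import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

with ZeroInitContext(target_device=torch.device('cuda:0'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = nn.Linear(16, 16)
# every parameter's sharded_data_tensor is now registered with
# GLOBAL_MODEL_DATA_TRACER, on whichever device it ended up on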
@@ -7,6 +7,7 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from torch._utils import _flatten_dense_tensors as flatten
 
 from .tensor_shard_strategy import TensorShardStrategy
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 class BucketTensorShardStrategy(TensorShardStrategy):
@@ -17,6 +18,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
     """
 
     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
+        for t in tensor_list:
+            GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
+
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
         if len(tensor_list) == 0:
             return
@@ -46,3 +50,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
             t.reset_payload(gathered_payload)
             t.is_sharded = False
             offset += tensor_numels[i]
+
+        for t in tensor_list:
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
@@ -3,13 +3,16 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 class TensorShardStrategy(BaseShardStrategy):
-    """A naive implementation which shard each tensor evenly over all ranks
+    """
+    A naive implementation which shard each tensor evenly over all ranks
     """
 
     def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
@@ -21,13 +24,22 @@ class TensorShardStrategy(BaseShardStrategy):
             self._gather_tensor(t, process_group)
 
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
+        """ Shard tensor among processes.
+
+        Args:
+            t (ShardedTensor): a tensor to be sharded.
+            process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
+                Defaults to None.
+        """
         if t.is_sharded:
             return
         if t.payload.device.type == 'cuda':
             assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                 f" but current cuda device is {get_current_device()}"
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
         t.is_sharded = True
 
     def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -44,8 +56,10 @@ class TensorShardStrategy(BaseShardStrategy):
             else:
                 buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
 
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
-        t.to(target_device)
+        colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         t.is_sharded = False
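The strategies follow a common bracket: delete_tensor() on the payload about to be replaced, add_tensor() on its successor, so the tracer always reflects the live payload across shard and gather. A sketch of that bracket in isolation, assuming only the tracer API from this commit and a ShardedTensor-like object exposing payload and reset_payload():

# Illustrative delete/add bracket around a payload swap (hedged sketch).
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

def swap_payload(t, new_payload: torch.Tensor) -> None:
    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)   # un-book the payload being replaced
    t.reset_payload(new_payload)                        # swap in the replacement
    GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)      # book the new payload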
@@ -56,7 +56,10 @@ class ShardedTensor(object):
         return self._origin_dtype
 
     def to(self, device: torch.device):
-        self._payload = self._payload.to(device)
+        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
+
+    def to_(self, device: torch.device):
+        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
 
     @property
     def shape(self):
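With .to() and the new to_() raising, a ShardedTensor can only change device through the tracked move helpers. A minimal sketch, assuming a single-GPU setup:

# Sketch of moving a ShardedTensor after this commit; st.to(device) now raises.
import torch
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

st = ShardedTensor(torch.ones(2, 3))
colo_model_data_tensor_move_inline(st, torch.device('cuda:0'))  # instead of st.to(...)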
@@ -0,0 +1,66 @@
+import pytest
+
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
+from colossalai.zero.sharded_param import ShardedTensor
+
+import colossalai
+
+import torch
+
+from functools import partial
+import torch.multiprocessing as mp
+from colossalai.utils import free_port
+
+
+def _run_colo_model_data_tensor_move_inline():
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
+
+    for t in [torch.randn(2, 3), ShardedTensor(torch.randn(2, 3))]:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
+        colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
+        assert t.device == torch.device(f"cuda:{get_current_device()}")
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
+        GLOBAL_MODEL_DATA_TRACER.clear()
+
+    GLOBAL_MODEL_DATA_TRACER.close()
+
+
+def _run_colo_model_data_tensor_move():
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
+
+    for t in [(torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device())),
+              (ShardedTensor(torch.ones(2, 3)), ShardedTensor(torch.zeros(2, 3).cuda(get_current_device())))]:
+        cpu_t, cuda_t = t
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
+        colo_model_data_tensor_move(cpu_t, cuda_t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
+        GLOBAL_MODEL_DATA_TRACER.clear()
+
+    GLOBAL_MODEL_DATA_TRACER.close()
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_colo_model_data_tensor_move_inline()
+    _run_colo_model_data_tensor_move()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 4])
+def test_tensor_move(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_tensor_move(4)
@@ -48,6 +48,8 @@ def run_model_test(init_device_type, shard_strategy_class):
             f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
         if init_device.type == 'cuda':
             assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        else:
+            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
     GLOBAL_MODEL_DATA_TRACER.clear()
 
 
@@ -65,5 +67,4 @@ def test_zero_init_context(world_size):
 
 
 if __name__ == '__main__':
-    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
     test_zero_init_context(4)