From 8d8c5407c0f677f9f356a96c2bef7dad83af12dd Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Fri, 25 Mar 2022 18:03:32 +0800
Subject: [PATCH] [zero] refactor model data tracing (#522)

---
 .../memory_tracer/model_data_memtracer.py     | 28 +++++---
 colossalai/utils/memory_utils/utils.py        | 21 +++---
 colossalai/zero/init_ctx/init_context.py      |  8 +--
 .../bucket_tensor_shard_strategy.py           |  7 ++
 .../zero/shard_utils/tensor_shard_strategy.py | 18 ++++-
 .../zero/sharded_param/sharded_tensor.py      |  5 +-
 tests/test_utils/test_tensor_move.py          | 66 +++++++++++++++++++
 .../test_init_context.py                      |  3 +-
 8 files changed, 128 insertions(+), 28 deletions(-)
 create mode 100644 tests/test_utils/test_tensor_move.py

diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py
index e8cb9f7c6..fafe31690 100644
--- a/colossalai/utils/memory_tracer/model_data_memtracer.py
+++ b/colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -22,6 +22,7 @@ class ModelDataTracer(metaclass=SingletonMeta):
 
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
         self._start_flag = False
 
     def start(self) -> None:
@@ -30,22 +31,33 @@ class ModelDataTracer(metaclass=SingletonMeta):
     def close(self) -> None:
         self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor) -> None:
+    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage += mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage += mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage += mem_use
+        else:
+            raise TypeError(f"add_tensor() got an unsupported device type {t_payload.device.type}")
 
-    def delete_tensor(self, t: torch.Tensor) -> None:
+    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage -= mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage -= mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage -= mem_use
+        else:
+            raise TypeError(f"delete_tensor() got an unsupported device type {t_payload.device.type}")
 
     def clear(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
 
     @property
     def cpu_usage(self):
diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py
index 52bb487d0..df41ac95d 100644
--- a/colossalai/utils/memory_utils/utils.py
+++ b/colossalai/utils/memory_utils/utils.py
@@ -3,7 +3,7 @@
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union, Optional
+from typing import Union
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
@@ -52,11 +52,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
         tgt_t_payload = tgt_t.data
     tgt_dev = tgt_t_payload.device
 
-    if src_dev.type == 'cuda' and tgt_dev.type == 'cpu':
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
-    elif src_dev.type == 'cpu' and tgt_dev.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
     tgt_t_payload.copy_(src_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
 
     # remove payload of src_t
     if isinstance(src_t, ShardedTensor):
@@ -65,7 +63,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
         src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
 
 
-def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
+                                       target_device: torch.device,
+                                       use_tracer: bool = True) -> None:
     """
     move a tensor to the target_device
     Args:
@@ -84,13 +84,11 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], ta
     # deal with torch.device('cpu') and torch.device('cpu:0')
     if t_payload.device.type == target_device.type:
         return
-
-    if target_device.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
-    elif target_device.type == 'cpu':
+    if use_tracer:
         GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
-
     t_payload.data = t_payload.data.to(target_device)
+    if use_tracer:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
 
 
 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
@@ -115,3 +113,4 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     # TODO() optimize the tensor moving with non-blocking
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 9ff4a81c5..32352e469 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -177,13 +177,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             self.initialized_param_list.append(param)
 
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor)
+
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
-            # if param.col_attr.grad and self.shard_grad:
-            #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
-            #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
+
             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
             # We must cast them
diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
index 90b447de1..06683af6a 100644
--- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -7,6 +7,7 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from torch._utils import _flatten_dense_tensors as flatten
 
 from .tensor_shard_strategy import TensorShardStrategy
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 class BucketTensorShardStrategy(TensorShardStrategy):
@@ -17,6 +18,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
     """
 
     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
         if len(tensor_list) == 0:
             return
+
+        for t in tensor_list:
+            GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
@@ -46,3 +50,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
             t.reset_payload(gathered_payload)
             t.is_sharded = False
             offset += tensor_numels[i]
+
+        for t in tensor_list:
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index 31210a190..25914f6f3 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -3,13 +3,16 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 class TensorShardStrategy(BaseShardStrategy):
-    """A naive implementation which shard each tensor evenly over all ranks
+    """
+    A naive implementation which shards each tensor evenly over all ranks.
     """
 
     def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
@@ -21,13 +24,22 @@ class TensorShardStrategy(BaseShardStrategy):
             self._gather_tensor(t, process_group)
 
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
+        """ Shard the tensor across the processes of the group.
+
+        Args:
+            t (ShardedTensor): a tensor to be sharded.
+            process_group (Optional[dist.ProcessGroup], optional): the process group over which the tensor is sharded.
+                Defaults to None.
+        """
         if t.is_sharded:
             return
         if t.payload.device.type == 'cuda':
             assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                 f" but current cuda device is {get_current_device()}"
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
         t.is_sharded = True
 
     def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -44,8 +56,10 @@ class TensorShardStrategy(BaseShardStrategy):
             else:
                 buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
 
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
-        t.to(target_device)
+        colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
         t.is_sharded = False
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index cde257d77..c678f22da 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -56,7 +56,10 @@ class ShardedTensor(object):
         return self._origin_dtype
 
     def to(self, device: torch.device):
-        self._payload = self._payload.to(device)
+        raise RuntimeError("Use colo_model_data_tensor_move instead of calling .to() on a ShardedTensor")
+
+    def to_(self, device: torch.device):
+        raise RuntimeError("Use colo_model_data_tensor_move_inline instead of calling .to_() on a ShardedTensor")
 
     @property
     def shape(self):
diff --git a/tests/test_utils/test_tensor_move.py b/tests/test_utils/test_tensor_move.py
new file mode 100644
index 000000000..223db83ad
--- /dev/null
+++ b/tests/test_utils/test_tensor_move.py
@@ -0,0 +1,66 @@
+import pytest
+
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
+from colossalai.zero.sharded_param import ShardedTensor
+
+import colossalai
+
+import torch
+
+from functools import partial
+import torch.multiprocessing as mp
+from colossalai.utils import free_port
+
+
+def _run_colo_model_data_tensor_move_inline():
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
+
+    for t in [torch.randn(2, 3), ShardedTensor(torch.randn(2, 3))]:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
+        colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
+        assert t.device == torch.device(f"cuda:{get_current_device()}")
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
+        GLOBAL_MODEL_DATA_TRACER.clear()
+
+    GLOBAL_MODEL_DATA_TRACER.close()
+
+
+def _run_colo_model_data_tensor_move():
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
+
+    for t in [(torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device())),
+              (ShardedTensor(torch.ones(2, 3)), ShardedTensor(torch.zeros(2, 3).cuda(get_current_device())))]:
+        cpu_t, cuda_t = t
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
+        colo_model_data_tensor_move(cpu_t, cuda_t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
+        GLOBAL_MODEL_DATA_TRACER.clear()
+
+    GLOBAL_MODEL_DATA_TRACER.close()
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_colo_model_data_tensor_move_inline()
+    _run_colo_model_data_tensor_move()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 4])
+def test_tensor_move(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_tensor_move(4)
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index 4b5d9edd8..381612d1f 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -48,6 +48,8 @@ def run_model_test(init_device_type, shard_strategy_class):
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
         if init_device.type == 'cuda':
             assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        else:
+            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
 
         GLOBAL_MODEL_DATA_TRACER.clear()
 
@@ -65,5 +67,4 @@ def test_zero_init_context(world_size):
 
 
 if __name__ == '__main__':
-    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
    test_zero_init_context(4)
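
The snippet below is a minimal usage sketch of the tracing protocol this patch
establishes; it is illustrative only, not part of the diff. It assumes a
CUDA-capable build of ColossalAI with the modules touched above, and the sizes
follow the new test: a 2 x 3 float32 tensor occupies 2 * 3 * 4 = 24 bytes.

    import torch

    from colossalai.utils.cuda import get_current_device
    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
    from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
    from colossalai.zero.sharded_param import ShardedTensor

    GLOBAL_MODEL_DATA_TRACER.start()

    # add_tensor() now accepts a ShardedTensor directly and books its payload
    # against the counter of the device the payload currently lives on.
    t = ShardedTensor(torch.randn(2, 3))    # payload starts on the CPU
    GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4

    # Device moves go through the colo_* helpers, which call delete_tensor()
    # before the payload changes device and add_tensor() afterwards, so the
    # CPU and CUDA counters track where model data actually lives.
    colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
    assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
    assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4

    GLOBAL_MODEL_DATA_TRACER.clear()
    GLOBAL_MODEL_DATA_TRACER.close()

This is also why ShardedTensor.to() and .to_() now raise: an untracked .to()
would silently desynchronize the tracer from the real placement of model data.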