From 8d8c5407c0f677f9f356a96c2bef7dad83af12dd Mon Sep 17 00:00:00 2001
From: Jiarui Fang <fangjiarui123@gmail.com>
Date: Fri, 25 Mar 2022 18:03:32 +0800
Subject: [PATCH] [zero] refactor model data tracing (#522)
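
Track model data on the CPU as well as on CUDA. ModelDataTracer keeps a
separate _cpu_usage counter, and add_tensor()/delete_tensor() now accept
either a torch.Tensor or a ShardedTensor and charge the payload to the
device it currently lives on. Tracer bookkeeping moves into the shard
strategies and the colo_model_data_* move utilities, ZeroInitContext
registers the sharded data tensor directly, and ShardedTensor.to()/.to_()
now raise so that callers go through colo_model_data_tensor_move(_inline)
instead. A new tests/test_utils/test_tensor_move.py covers the CPU<->CUDA
accounting.

A minimal sketch of the intended accounting, mirroring the new test
(assumes the tracer is started and a CUDA device is visible as cuda:0):

    import torch
    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
    from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

    GLOBAL_MODEL_DATA_TRACER.start()
    t = torch.randn(2, 3)                    # 6 fp32 elements = 24 bytes on the host
    GLOBAL_MODEL_DATA_TRACER.add_tensor(t)   # cpu_usage == 24, cuda_usage == 0
    colo_model_data_tensor_move_inline(t, torch.device('cuda:0'))
    # the move unregisters the CPU payload and registers it again on CUDA:
    # cpu_usage == 0, cuda_usage == 24
    GLOBAL_MODEL_DATA_TRACER.close()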

---
 .../memory_tracer/model_data_memtracer.py     | 28 +++++---
 colossalai/utils/memory_utils/utils.py        | 21 +++---
 colossalai/zero/init_ctx/init_context.py      |  8 +--
 .../bucket_tensor_shard_strategy.py           |  7 ++
 .../zero/shard_utils/tensor_shard_strategy.py | 18 ++++-
 .../zero/sharded_param/sharded_tensor.py      |  5 +-
 tests/test_utils/test_tensor_move.py          | 66 +++++++++++++++++++
 .../test_init_context.py                      |  3 +-
 8 files changed, 128 insertions(+), 28 deletions(-)
 create mode 100644 tests/test_utils/test_tensor_move.py

diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py
index e8cb9f7c6..fafe31690 100644
--- a/colossalai/utils/memory_tracer/model_data_memtracer.py
+++ b/colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -22,6 +22,7 @@ class ModelDataTracer(metaclass=SingletonMeta):
 
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
         self._start_flag = False
 
     def start(self) -> None:
@@ -30,22 +31,33 @@ class ModelDataTracer(metaclass=SingletonMeta):
     def close(self) -> None:
         self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor) -> None:
+    def add_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage += mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage += mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage += mem_use
+        else:
+            raise TypeError(f"Unsupported device type: {t_payload.device.type}")
 
-    def delete_tensor(self, t: torch.Tensor) -> None:
+    def delete_tensor(self, t: Union[torch.Tensor, ShardedTensor]) -> None:
         if not self._start_flag:
             return
-        assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
-        mem_use = _col_tensor_mem_usage(t)
-        self._cuda_usage -= mem_use
+        t_payload = t.payload if isinstance(t, ShardedTensor) else t
+        mem_use = _col_tensor_mem_usage(t_payload)
+        if t_payload.device.type == 'cuda':
+            self._cuda_usage -= mem_use
+        elif t_payload.device.type == 'cpu':
+            self._cpu_usage -= mem_use
+        else:
+            raise TypeError(f"Unsupported device type: {t_payload.device.type}")
 
     def clear(self) -> None:
         self._cuda_usage = 0
+        self._cpu_usage = 0
 
     @property
     def cpu_usage(self):
diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py
index 52bb487d0..df41ac95d 100644
--- a/colossalai/utils/memory_utils/utils.py
+++ b/colossalai/utils/memory_utils/utils.py
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union, Optional
+from typing import Union
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 
@@ -52,11 +52,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
         tgt_t_payload = tgt_t.data
     tgt_dev = tgt_t_payload.device
 
-    if src_dev.type == 'cuda' and tgt_dev.type == 'cpu':
-        GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
-    elif src_dev.type == 'cpu' and tgt_dev.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.delete_tensor(src_t_payload)
     tgt_t_payload.copy_(src_t_payload)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(tgt_t_payload)
 
     # remove payload of src_t
     if isinstance(src_t, ShardedTensor):
@@ -65,7 +63,9 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
         src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
 
 
-def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
+                                       target_device: torch.device,
+                                       use_tracer: bool = True) -> None:
     """ 
     move a tensor to the target_device
     Args:
@@ -84,13 +84,11 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], ta
     # deal with torch.device('cpu') and torch.device('cpu:0')
     if t_payload.device.type == target_device.type:
         return
-
-    if target_device.type == 'cuda':
-        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
-    elif target_device.type == 'cpu':
+    if use_tracer:
         GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
-
     t_payload.data = t_payload.data.to(target_device)
+    if use_tracer:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
 
 
 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
@@ -115,3 +113,4 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     # TODO() optimize the tensor moving with non-blocking
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 9ff4a81c5..32352e469 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -177,13 +177,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             self.initialized_param_list.append(param)
 
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor)
+
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
-            # if param.col_attr.grad and self.shard_grad:
-            #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
-            #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
+
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
         # We must cast them
diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
index 90b447de1..06683af6a 100644
--- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -7,6 +7,7 @@ from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from torch._utils import _flatten_dense_tensors as flatten
 
 from .tensor_shard_strategy import TensorShardStrategy
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 class BucketTensorShardStrategy(TensorShardStrategy):
@@ -17,6 +18,9 @@ class BucketTensorShardStrategy(TensorShardStrategy):
     """
 
     def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
+        for t in tensor_list:
+            GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
+
         tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
         if len(tensor_list) == 0:
             return
@@ -46,3 +50,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
             t.reset_payload(gathered_payload)
             t.is_sharded = False
             offset += tensor_numels[i]
+
+        for t in tensor_list:
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index 31210a190..25914f6f3 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -3,13 +3,16 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 class TensorShardStrategy(BaseShardStrategy):
-    """A naive implementation which shard each tensor evenly over all ranks
+    """
+    A naive implementation which shards each tensor evenly over all ranks
     """
 
     def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
@@ -21,13 +24,22 @@ class TensorShardStrategy(BaseShardStrategy):
             self._gather_tensor(t, process_group)
 
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
+        """ Shard tensor among processes.
+
+        Args:
+            t (ShardedTensor): a tensor to be sharded.
+            process_group (Optional[dist.ProcessGroup], optional): the process group across which the tensor is sharded.
+                Defaults to None.
+        """
         if t.is_sharded:
             return
         if t.payload.device.type == 'cuda':
             assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                 f" but current cuda device is {get_current_device()}"
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
         t.is_sharded = True
 
     def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -44,8 +56,10 @@ class TensorShardStrategy(BaseShardStrategy):
             else:
                 buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
 
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t.payload)
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
         gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
         t.reset_payload(gathered_payload)
-        t.to(target_device)
+        colo_model_data_tensor_move_inline(t, target_device, use_tracer=False)
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t.payload)
         t.is_sharded = False
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index cde257d77..c678f22da 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -56,7 +56,10 @@ class ShardedTensor(object):
         return self._origin_dtype
 
     def to(self, device: torch.device):
-        self._payload = self._payload.to(device)
+        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
+
+    def to_(self, device: torch.device):
+        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
 
     @property
     def shape(self):
diff --git a/tests/test_utils/test_tensor_move.py b/tests/test_utils/test_tensor_move.py
new file mode 100644
index 000000000..223db83ad
--- /dev/null
+++ b/tests/test_utils/test_tensor_move.py
@@ -0,0 +1,66 @@
+import pytest
+
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
+from colossalai.zero.sharded_param import ShardedTensor
+
+import colossalai
+
+import torch
+
+from functools import partial
+import torch.multiprocessing as mp
+from colossalai.utils import free_port
+
+
+def _run_colo_model_data_tensor_move_inline():
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
+
+    for t in [torch.randn(2, 3), ShardedTensor(torch.randn(2, 3))]:
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
+        colo_model_data_tensor_move_inline(t, torch.device(f"cuda:{get_current_device()}"))
+        assert t.device == torch.device(f"cuda:{get_current_device()}")
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
+        GLOBAL_MODEL_DATA_TRACER.clear()
+
+    GLOBAL_MODEL_DATA_TRACER.close()
+
+
+def _run_colo_model_data_tensor_move():
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
+
+    for t in [(torch.ones(2, 3), torch.zeros(2, 3).cuda(get_current_device())),
+              (ShardedTensor(torch.ones(2, 3)), ShardedTensor(torch.zeros(2, 3).cuda(get_current_device())))]:
+        cpu_t, cuda_t = t
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(cpu_t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 2 * 3 * 4
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0
+        colo_model_data_tensor_move(cpu_t, cuda_t)
+        assert GLOBAL_MODEL_DATA_TRACER.cpu_usage == 0
+        assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 2 * 3 * 4
+        GLOBAL_MODEL_DATA_TRACER.clear()
+
+    GLOBAL_MODEL_DATA_TRACER.close()
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_colo_model_data_tensor_move_inline()
+    _run_colo_model_data_tensor_move()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [1, 4])
+def test_tensor_move(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_tensor_move(4)
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index 4b5d9edd8..381612d1f 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -48,6 +48,8 @@ def run_model_test(init_device_type, shard_strategy_class):
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
         if init_device.type == 'cuda':
             assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        else:
+            assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
         GLOBAL_MODEL_DATA_TRACER.clear()
 
 
@@ -65,5 +67,4 @@ def test_zero_init_context(world_size):
 
 
 if __name__ == '__main__':
-    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
     test_zero_init_context(4)