diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 938826b55..f078bbfa6 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -4,6 +4,9 @@ from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from typing import Optional
 
 
 @OPHOOKS.register_module
@@ -12,14 +15,17 @@ class ZeroHook(BaseOpHook):
     A hook to process sharded param for ZeRO method.
     """
 
-    def __init__(self, shard_strategy: BaseShardStrategy):
+    def __init__(self, shard_strategy: BaseShardStrategy, memstats_collector: Optional[MemStatsCollector] = None):
         super().__init__()
         self.shard_strategy = shard_strategy
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
+        self._memstats_collector = memstats_collector
+
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
+        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -27,8 +33,12 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
+                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
 
+        if self._memstats_collector:
+            self._memstats_collector.sample_memstats()
+
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters():
@@ -40,6 +50,7 @@ class ZeroHook(BaseOpHook):
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
+        global_model_data_tracer = ModelDataTracer()
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.data)
@@ -47,6 +58,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
+                global_model_data_tracer.add_tensor(param.col_attr.data.payload)
             param.data = param.col_attr.data.payload
             # Store local accumulated grad shard
             if param.grad is not None:
@@ -60,6 +72,8 @@ class ZeroHook(BaseOpHook):
                 # The grad here must be locally computed full grad in this backward pass
                 assert param.grad.shape == param.col_attr.data.origin_shape
             param.col_attr.bwd_count += 1
+        if self._memstats_collector:
+            self._memstats_collector.sample_memstats()
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
         tensor_list = []
diff --git a/colossalai/utils/memory_tracer/allocator.py b/colossalai/utils/memory_tracer/allocator.py
index 368aae2da..26c36ef79 100644
--- a/colossalai/utils/memory_tracer/allocator.py
+++ b/colossalai/utils/memory_tracer/allocator.py
@@ -1,60 +1,19 @@
 import torch
 
-from colossalai.utils.commons.singleton_meta import SingletonMeta
-from colossalai.zero.sharded_param import ShardedTensor
-
-from typing import Union
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
 
 
-def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
-    if isinstance(t, ShardedTensor):
-        target = t.payload
-    else:
-        target = t
-    return target.numel() * target.element_size()
+def col_move_to_cpu(t: torch.Tensor):
+    assert isinstance(t, torch.Tensor)
+    if t.device.type == 'cpu':
+        return
+
+    ModelDataTracer().delete_tensor(t)
+    t.data = t.data.cpu()
 
 
-class ModelDataTracer(metaclass=SingletonMeta):
-    """
-    A singleton to trace model data usage during runtime.
-    """
-
-    def __init__(self) -> None:
-        self._cpu_usage = 0
-        self._cuda_usage = 0
-
-    def trace_tensor(self, t: torch.Tensor):
-        mem_use = col_tensor_mem_usage(t)
-        if t.device.type == 'cpu':
-            self._cpu_usage += mem_use
-        elif t.device.type == 'cuda':
-            self._cuda_usage += mem_use
-        else:
-            raise RuntimeError
-
-    def detach_tensor(self, t: torch.Tensor):
-        mem_use = col_tensor_mem_usage(t)
-        if t.device.type == 'cpu':
-            self._cpu_usage -= mem_use
-        elif t.device.type == 'cuda':
-            self._cuda_usage -= mem_use
-        else:
-            raise RuntimeError
-
-    @property
-    def cpu_usage(self):
-        return self._cpu_usage
-
-    @property
-    def cuda_usage(self):
-        return self._cuda_usage
-
-
-GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
-
-
-def col_allocate_payload(device: torch.device) -> torch.Tensor:
+def col_modeldata_allocate(device: torch.device) -> torch.Tensor:
     pass
 
 
-def col_release_payload(t: torch.Tensor):
+def col_modeldata_release(t: torch.Tensor):
     pass
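Usage sketch (not part of the patch), assuming a CUDA device is available and nothing else has been registered with the tracer in this process: `col_move_to_cpu` drops a tensor from the traced CUDA model-data usage before moving its storage to the CPU, which is how `ShardedModelV2` offloads reduced gradients further below.

```python
import torch

from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer

grad = torch.randn(1024, device='cuda')
ModelDataTracer().add_tensor(grad)        # traced as CUDA model data

col_move_to_cpu(grad)                     # removes it from the tracer, then moves the payload
assert grad.device.type == 'cpu'
assert ModelDataTracer().cuda_usage == 0  # holds only in a fresh process with nothing else traced
```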
diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py
index 8f968acfb..fe65651ae 100644
--- a/colossalai/utils/memory_tracer/async_memtracer.py
+++ b/colossalai/utils/memory_tracer/async_memtracer.py
@@ -6,7 +6,7 @@ from colossalai.utils import get_current_device
 import torch
 
 
-def _get_cuda_memory_used(device: torch.device) -> int:
+def get_cuda_memory_used(device: torch.device) -> int:
     """
     Get the free memory info of device.
     :param device: device id
@@ -87,7 +87,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                _get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
             )
             sleep(self.interval)
         return max_usage
diff --git a/colossalai/utils/memory_tracer/commons.py b/colossalai/utils/memory_tracer/commons.py
new file mode 100644
index 000000000..28fc2abd3
--- /dev/null
+++ b/colossalai/utils/memory_tracer/commons.py
@@ -0,0 +1,11 @@
+from colossalai.zero.sharded_param import ShardedTensor
+from typing import Union
+import torch
+
+
+def col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int:
+    if isinstance(t, ShardedTensor):
+        target = t.payload
+    else:
+        target = t
+    return target.numel() * target.element_size()
diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
new file mode 100644
index 000000000..6da89f6ba
--- /dev/null
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -0,0 +1,81 @@
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+from .async_memtracer import get_cuda_memory_used
+from colossalai.utils import get_current_device
+
+import torch
+
+
+class SamplingCounter:
+
+    def __init__(self) -> None:
+        self._sampling_cnt = 0
+
+    def advance(self):
+        self._sampling_cnt += 1
+
+    @property
+    def sampling_cnt(self):
+        return self._sampling_cnt
+
+    def reset(self):
+        self._sampling_cnt = 0
+
+
+class MemStatsCollector:
+
+    def __init__(self) -> None:
+        """
+        Collects memory statistics. It has two phases:
+        1. Collection phase: collect memory usage statistics.
+        2. Runtime phase: do not collect statistics; use the collected ones.
+        """
+        self._sampling_cnter = SamplingCounter()
+        self._model_data_cuda = []
+        self._overall_cuda = []
+
+        # TODO(jiaruifang) CPU memory stats are not collected for now
+        self._model_data_cpu = []
+        self._overall_cpu = []
+
+        self._start_flag = False
+
+    def start_collection(self):
+        self._start_flag = True
+
+    def finish_collection(self):
+        self._start_flag = False
+
+    def sample_memstats(self) -> None:
+        """
+        Sample memory statistics.
+        Record the current model data CUDA memory usage as well as the overall system CUDA memory usage.
+        """
+        if self._start_flag:
+            sampling_cnt = self._sampling_cnter.sampling_cnt
+            assert sampling_cnt == len(self._overall_cuda)
+            self._model_data_cuda.append(ModelDataTracer().cuda_usage)
+            self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._sampling_cnter.advance()
+
+    def fetch_memstats(self) -> (int, int):
+        """
+        Returns the CUDA memory usage of model data and the overall CUDA memory usage at the current sampling point.
+        """
+        sampling_cnt = self._sampling_cnter.sampling_cnt
+        if len(self._model_data_cuda) < sampling_cnt:
+            raise RuntimeError
+        return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt])
+
+    def reset_sampling_cnter(self) -> None:
+        self._sampling_cnter.reset()
+
+    def clear(self) -> None:
+        self._model_data_cuda = []
+        self._overall_cuda = []
+
+        self._model_data_cpu = []
+        self._overall_cpu = []
+
+        self._start_flag = False
+        self._sampling_cnter.reset()
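Usage sketch (not part of the patch) of the collector's two phases, assuming a CUDA device is available; in `ShardedModelV2` further below, `sample_memstats()` is driven by `ZeroHook` rather than called by hand.

```python
import torch

from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()

# Collection phase (iteration 0 in ShardedModelV2)
collector.start_collection()
x = torch.randn(1024, device='cuda')
collector.sample_memstats()           # appends model-data CUDA usage and overall CUDA usage
collector.finish_collection()
collector.reset_sampling_cnter()

# Runtime phase: the previously recorded sample is read back instead of re-collected
model_data_cuda, overall_cuda = collector.fetch_memstats()
print(model_data_cuda, overall_cuda)
```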
+ """ + self._sampling_cnter = SamplingCounter() + self._model_data_cuda = [] + self._overall_cuda = [] + + # TODO(jiaruifang) Now no cpu mem stats collecting + self._model_data_cpu = [] + self._overall_cpu = [] + + self._start_flag = False + + def start_collection(self): + self._start_flag = True + + def finish_collection(self): + self._start_flag = False + + def sample_memstats(self) -> None: + """ + Sampling memory statistics. + Record the current model data CUDA memory usage as well as system CUDA memory usage. + """ + if self._start_flag: + sampling_cnt = self._sampling_cnter.sampling_cnt + assert sampling_cnt == len(self._overall_cuda) + self._model_data_cuda.append(ModelDataTracer().cuda_usage) + self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}'))) + self._sampling_cnter.advance() + + def fetch_memstats(self) -> (int, int): + """ + returns cuda usage of model data and overall cuda usage. + """ + sampling_cnt = self._sampling_cnter.sampling_cnt + if len(self._model_data_cuda) < sampling_cnt: + raise RuntimeError + return (self._model_data_cuda[sampling_cnt], self._overall_cuda[sampling_cnt]) + + def reset_sampling_cnter(self) -> None: + self._sampling_cnter.reset() + + def clear(self) -> None: + self._model_data_cuda = [] + self._overall_cuda = [] + + self._model_data_cpu = [] + self._overall_cpu = [] + + self._start_flag = False + self._sampling_cnter.reset() diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py new file mode 100644 index 000000000..4a3062bb3 --- /dev/null +++ b/colossalai/utils/memory_tracer/model_data_memtracer.py @@ -0,0 +1,34 @@ +from colossalai.utils.commons.singleton_meta import SingletonMeta +from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage +import torch + + +class ModelDataTracer(metaclass=SingletonMeta): + """ + A singleton to trace model data usage during runtime. + We have to trigger our API (trace_tensor, detach_tensor) when do model-data memory operation, + including allocation, releasing and moving. 
diff --git a/colossalai/utils/memory_tracer/test_memstats_collector.py b/colossalai/utils/memory_tracer/test_memstats_collector.py
new file mode 100644
index 000000000..9c93600b7
--- /dev/null
+++ b/colossalai/utils/memory_tracer/test_memstats_collector.py
@@ -0,0 +1,43 @@
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
+import torch
+
+
+def test_mem_collector():
+    collector = MemStatsCollector()
+
+    collector.start_collection()
+
+    a = torch.randn(10).cuda()
+
+    # sampling at time 0
+    collector.sample_memstats()
+
+    m_a = torch.randn(10).cuda()
+    ModelDataTracer().add_tensor(m_a)
+    b = torch.randn(10).cuda()
+
+    # sampling at time 1
+    collector.sample_memstats()
+
+    a = b
+
+    # sampling at time 2
+    collector.sample_memstats()
+
+    collector.finish_collection()
+    collector.reset_sampling_cnter()
+
+    # after collection sampling does nothing but advance the sampling counter
+    collector.sample_memstats()
+    collector.sample_memstats()
+
+    cuda_use, overall_use = collector.fetch_memstats()
+    print(cuda_use, overall_use)
+
+    print(collector._model_data_cuda)
+    print(collector._overall_cuda)
+
+
+if __name__ == '__main__':
+    test_mem_collector()
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 17e89cbf7..64b14c644 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -3,10 +3,11 @@ import functools
 import torch
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
-
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
 
 # Inserts _post_init_method at the end of init method
+
+
 # for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
@@ -152,7 +153,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             if self.shard_param:
                 self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
-                GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._data_sharded_tensor.payload)
+                ModelDataTracer().add_tensor(param.col_attr._data_sharded_tensor.payload)
             if param.col_attr.grad and self.shard_grad:
                 self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-                GLOBAL_MODEL_DATA_TRACER.trace_tensor(param.col_attr._grad_sharded_tensor.payload)
+                ModelDataTracer().add_tensor(param.col_attr._grad_sharded_tensor.payload)
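Sketch (not part of the patch) of how sharded parameters created under `ZeroInitContext` end up in the tracer. It assumes the distributed environment has already been initialized via `colossalai.launch(...)`, as in the tests further below; the model is a placeholder.

```python
import torch

from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# assumes colossalai.launch(...) has already set up the process group
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device(f'cuda:{get_current_device()}'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = torch.nn.Linear(128, 128)    # placeholder model

# every sharded payload placed on CUDA was registered via ModelDataTracer().add_tensor()
print(f'cuda usage {ModelDataTracer().cuda_usage}')
```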
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 7510cb68e..87ddb9c63 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -17,7 +17,8 @@
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 shard_param: bool = True):
+                 shard_param: bool = True,
+                 use_memory_tracer: bool = False):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
                 if self.shard_param:
                     self.shard_strategy.shard([param.col_attr.data])
 
+        # Init Memory Statistics Collector
+        self._use_memory_tracer = use_memory_tracer
+        if self._use_memory_tracer:
+            self._memstats_collector = MemStatsCollector()
+        else:
+            self._memstats_collector = None
+        self._iter_cnter = 0
+
         # Register hooks
-        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy)])
+        register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
         return self._cpu_offload
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            # the operation will affect the flag in ZeroHook
+            self._memstats_collector.start_collection()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
 
     @torch.no_grad()
     def _final_backward_hook(self) -> None:
+        if self._iter_cnter == 0 and self._memstats_collector:
+            self._memstats_collector.finish_collection()
+        if self._memstats_collector:
+            self._memstats_collector.reset_sampling_cnter()
+        self._iter_cnter += 1
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
             reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
 
         # Maybe offload
+        # TODO() optimize GPU->CPU bandwidth utilization
         if self._cpu_offload:
-            reduced_grad.data = reduced_grad.data.cpu()
+            col_move_to_cpu(reduced_grad)
+            # reduced_grad.data = reduced_grad.data.cpu()
 
         if param.col_attr.grad is None:
             param.col_attr.grad = reduced_grad.data
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 47c0d26b7..d78ac3ecc 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import copy
 from functools import partial
 
@@ -143,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
                 # We have to use `copy_payload` instead of `reset_payload`
                 # Since p.data is fp32 and p.col_attr.data is fp16
-                # TODO() optimize this line
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.col_attr.data.copy_payload(p.data)
 
                 if not is_param_sharded:
diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py
index 619ab4bdc..237b77f06 100644
--- a/tests/test_utils/test_activation_checkpointing.py
+++ b/tests/test_utils/test_activation_checkpointing.py
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()
 
+
     # as seed manager is singleton
     # if we don't reset seeds here,
     # other tests will fail if running together with this test
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index a74e6959d..2e6fede05 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -9,12 +9,12 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 from common import CONFIG
+from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
 
 
 def run_dist(rank, world_size, port, init_device, shard_strategy):
@@ -37,13 +37,10 @@ def run_dist(rank, world_size, port, init_device, shard_strategy):
         assert param.col_attr.data.payload.device.type == init_device.type, \
             f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
 
-    print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
-    print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
+    print(f'cuda usage {ModelDataTracer().cuda_usage}')
     print(f'numel {model_numel_tensor}')
     if init_device.type == 'cuda':
-        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
-    elif init_device.type == 'cpu':
-        assert (GLOBAL_MODEL_DATA_TRACER.cpu_usage > 0)
+        assert (ModelDataTracer().cuda_usage > 0)
 
 
 @pytest.mark.dist
@@ -60,5 +57,5 @@ def test_zero_init_context(world_size, init_device, shard_strategy):
 
 
 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)
+    # test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(4, torch.device('cpu'), BucketTensorShardStrategy)
diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py
index 54ca5ad3c..a2cae3ee5 100644
--- a/tests/test_zero_data_parallel/test_shard_model_v2.py
+++ b/tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -18,6 +18,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from common import CONFIG, check_grads_padding, run_fwd_bwd
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 
 
 def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
@@ -33,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
 
     if use_zero_init_ctx:
         with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.device('cpu'),
+                             target_device=torch.device('cpu:0'),
                              shard_strategy=shard_strategy,
                              shard_param=True,
                              rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
             zero_model = model_builder(checkpoint=True)
-        zero_model = ShardedModelV2(zero_model, shard_strategy)
+        zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
 
         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -59,6 +60,9 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_s
 
     check_grads_padding(model, zero_model, loose=True)
 
+    print('overall cuda ', zero_model._memstats_collector._overall_cuda)
+    print('model cuda ', zero_model._memstats_collector._model_data_cuda)
+
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 9371cf66a..622df5693 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 import copy
 from functools import partial
 
@@ -82,4 +79,4 @@ def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
\ No newline at end of file
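End-to-end sketch (not part of the patch) of enabling the memory tracer on `ShardedModelV2`, mirroring `run_dist()` in `test_shard_model_v2.py` above. It assumes `colossalai.launch(...)` has already initialized the distributed environment and uses a tiny placeholder model; import paths follow the tests.

```python
import torch

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# assumes colossalai.launch(...) has already run in this process
shard_strategy = TensorShardStrategy()
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cpu:0'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    module = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 32))

zero_model = ShardedModelV2(module, shard_strategy, use_memory_tracer=True)

# After the first forward/backward pass (the collection phase) the samples can be read:
#   zero_model._memstats_collector._model_data_cuda  # model-data CUDA bytes per sample
#   zero_model._memstats_collector._overall_cuda     # overall CUDA usage per sample
```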