From adebb3e04134393120ff32f6adda9b5ff477a993 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 15 Mar 2022 12:02:19 +0800
Subject: [PATCH] [zero] cuda margin space for OS (#418)

---
 colossalai/utils/commons/memory.py            |  9 +++++++++
 .../zero/sharded_model/sharded_model_v2.py    | 22 +++++++++++++++++++---
 .../test_sharded_optim_v2.py                  |  1 -
 3 files changed, 28 insertions(+), 4 deletions(-)
 create mode 100644 colossalai/utils/commons/memory.py

diff --git a/colossalai/utils/commons/memory.py b/colossalai/utils/commons/memory.py
new file mode 100644
index 000000000..9754ae6d2
--- /dev/null
+++ b/colossalai/utils/commons/memory.py
@@ -0,0 +1,9 @@
+import torch
+from colossalai.utils import get_current_device
+
+
+def col_cuda_memory_capacity():
+    """
+    Get cuda memory capacity of the current cuda.
+    """
+    return torch.cuda.get_device_properties(get_current_device()).total_memory
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 87ddb9c63..f92107e6c 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -21,6 +21,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
+from colossalai.utils.commons.memory import col_cuda_memory_capacity
 
 
 class ShardedModelV2(nn.Module):
@@ -89,6 +90,12 @@ class ShardedModelV2(nn.Module):
         self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
         self._require_backward_grad_sync: bool = True
+        self._cuda_margin_space = 0
+
+    @property
+    def cuda_margin_space(self):
+        return self._cuda_margin_space
+
     @property
     def cpu_offload(self):
         return self._cpu_offload
 
@@ -103,18 +110,27 @@ class ShardedModelV2(nn.Module):
 
     def backward(self, loss):
         loss.backward()
-        self._final_backward_hook()
+        self._post_backward_operations()
 
     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-        self._final_backward_hook()
+        self._post_backward_operations()
 
     @torch.no_grad()
-    def _final_backward_hook(self) -> None:
+    def _post_backward_operations(self) -> None:
+        """
+        The method includes operations required to be processed after backward
+        """
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
         if self._memstats_collector:
             self._memstats_collector.reset_sampling_cnter()
+            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
+            # the way to calculate margin space is based on the assumption that
+            # model data is fixed in cuda during training.
+            # cuda margin space can be used to store OS.
+            self._cuda_margin_space = col_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
+
         self._iter_cnter += 1
 
         if self._require_backward_grad_sync:
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index aeaa7afaf..bded5084e 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -12,7 +12,6 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Adam
 
 from common import CONFIG, check_sharded_params_padding
 