[zero] cuda margin space for OS (#418)

pull/419/head
Jiarui Fang 2022-03-15 12:02:19 +08:00 committed by GitHub
parent 56bb412e72
commit adebb3e041
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 4 deletions

View File

@ -0,0 +1,10 @@
import torch

from colossalai.utils import get_current_device


def col_cuda_memory_capacity() -> int:
    """Return the total memory capacity, in bytes, of the current CUDA device.

    The device is the one reported by ``get_current_device``; the capacity is
    read from ``torch.cuda.get_device_properties(...).total_memory``.
    """
    return torch.cuda.get_device_properties(get_current_device()).total_memory

View File

@ -21,6 +21,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
get_gradient_predivide_factor)
from colossalai.utils.commons.memory import col_cuda_memory_capacity
class ShardedModelV2(nn.Module):
@ -89,6 +90,12 @@ class ShardedModelV2(nn.Module):
self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
self._require_backward_grad_sync: bool = True
self._cuda_margin_space = 0
@property
def cuda_margin_space(self):
    """Read-only view of ``_cuda_margin_space``: the leftover CUDA memory
    (bytes) computed after backward, usable for optimizer states (OS)."""
    margin = self._cuda_margin_space
    return margin
@property
def cpu_offload(self):
    """Read-only view of ``_cpu_offload``: whether CPU offloading is enabled."""
    offload_enabled = self._cpu_offload
    return offload_enabled
@ -103,18 +110,27 @@ class ShardedModelV2(nn.Module):
def backward(self, loss):
    """Run autograd on ``loss``, then the post-backward bookkeeping.

    ``_post_backward_operations`` handles the work that must follow every
    backward pass (memory-stat collection, margin-space computation, grad sync).
    """
    loss.backward()
    self._post_backward_operations()
self._post_backward_operations()
def backward_by_grad(self, tensor, grad):
    """Backward from ``tensor`` using an externally supplied gradient ``grad``,
    then run the post-backward bookkeeping (same tail as ``backward``)."""
    torch.autograd.backward(tensor, grad)
    self._post_backward_operations()
self._post_backward_operations()
@torch.no_grad()
def _final_backward_hook(self) -> None:
def _post_backward_operations(self) -> None:
"""
The method includes operations required to be processed after backward
"""
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
if self._memstats_collector:
self._memstats_collector.reset_sampling_cnter()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = col_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
self._iter_cnter += 1
if self._require_backward_grad_sync:

View File

@ -12,7 +12,6 @@ from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from common import CONFIG, check_sharded_params_padding