mirror of https://github.com/hpcaitech/ColossalAI
[zero] cuda margin space for OS (#418)
parent 56bb412e72
commit adebb3e041
@@ -0,0 +1,9 @@
+import torch
+from colossalai.utils import get_current_device
+
+
+def col_cuda_memory_capacity():
+    """
+    Get the CUDA memory capacity of the current device.
+    """
+    return torch.cuda.get_device_properties(get_current_device()).total_memory
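Note: a minimal usage sketch of the new helper (not part of the commit), assuming torch.cuda.memory_reserved is used to query the memory currently held by PyTorch's caching allocator:

# Hedged sketch: estimate remaining headroom on the current device.
import torch
from colossalai.utils import get_current_device
from colossalai.utils.commons.memory import col_cuda_memory_capacity

def estimate_cuda_headroom():
    # total device capacity minus memory currently reserved by the caching allocator
    return col_cuda_memory_capacity() - torch.cuda.memory_reserved(get_current_device())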
@@ -21,6 +21,7 @@ from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
+from colossalai.utils.commons.memory import col_cuda_memory_capacity


 class ShardedModelV2(nn.Module):
@@ -89,6 +90,12 @@ class ShardedModelV2(nn.Module):
         self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
         self._require_backward_grad_sync: bool = True

+        self._cuda_margin_space = 0
+
+    @property
+    def cuda_margin_space(self):
+        return self._cuda_margin_space
+
     @property
     def cpu_offload(self):
         return self._cpu_offload
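A brief sketch of how the new cuda_margin_space property might be consumed; the helper below and its parameters are hypothetical and are not part of ShardedOptimizerV2 in this commit:

# Hypothetical consumer of ShardedModelV2.cuda_margin_space.
def choose_os_device(sharded_model, os_bytes: int) -> str:
    # Place optimizer states (OS) on CUDA only if they fit into the margin
    # measured during the first forward/backward iteration.
    return 'cuda' if sharded_model.cuda_margin_space >= os_bytes else 'cpu'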
@@ -103,18 +110,27 @@ class ShardedModelV2(nn.Module):

     def backward(self, loss):
         loss.backward()
-        self._final_backward_hook()
+        self._post_backward_operations()

     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
-        self._final_backward_hook()
+        self._post_backward_operations()

     @torch.no_grad()
-    def _final_backward_hook(self) -> None:
+    def _post_backward_operations(self) -> None:
+        """
+        The operations that must run after the backward pass.
+        """
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
         if self._memstats_collector:
             self._memstats_collector.reset_sampling_cnter()
+            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
+            # The calculation assumes that model data stays fixed in CUDA
+            # memory during training.
+            # The cuda margin space can be used to store optimizer states (OS).
+            self._cuda_margin_space = col_cuda_memory_capacity() - max(self._memstats_collector._overall_cuda)
+
         self._iter_cnter += 1

         if self._require_backward_grad_sync:
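For illustration only (the numbers and the stand-alone function are assumptions, not part of the commit), the margin-space formula from the comment above reduces to:

# Hypothetical, self-contained restatement of the margin-space calculation.
def compute_cuda_margin_space(capacity_bytes: int, overall_cuda_samples: list) -> int:
    # margin = device capacity - peak CUDA memory sampled during fwd/bwd
    return capacity_bytes - max(overall_cuda_samples)

GiB = 1024 ** 3
# e.g. a 16 GiB device whose sampled peak usage was 11 GiB leaves 5 GiB for OS
assert compute_cuda_margin_space(16 * GiB, [9 * GiB, 11 * GiB, 10 * GiB]) == 5 * GiB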
@@ -12,7 +12,6 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Adam

 from common import CONFIG, check_sharded_params_padding