[hotfix] fix memory leak in zero (#781)

pull/785/head
HELSON 2022-04-18 13:57:03 +08:00 committed by GitHub
parent 4b01da24cd
commit 4c4388c46e
6 changed files with 32 additions and 36 deletions


@@ -12,7 +12,7 @@ __all__ = ['BaseGradScaler']
 class BaseGradScaler(ABC):
 
-    def __init__(self, initial_scale: int, verbose: bool):
+    def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
         self._scale = torch.cuda.FloatTensor([initial_scale])
         self._verbose = verbose
@@ -31,6 +31,7 @@ class BaseGradScaler(ABC):
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale
+        return state_dict
 
     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
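
The second hunk's `return state_dict` matters for checkpointing: without it the method returns `None`. A minimal round-trip sketch under that reading, using a toy stand-in rather than the real `BaseGradScaler` (a CPU tensor replaces `torch.cuda.FloatTensor` so the example runs anywhere):

import torch

class ToyScaler:
    def __init__(self, initial_scale: float):
        assert initial_scale > 0
        self._scale = torch.tensor([float(initial_scale)])  # stand-in for torch.cuda.FloatTensor

    @property
    def scale(self):
        return self._scale

    def state_dict(self):
        state_dict = dict()
        state_dict['scale'] = self.scale
        return state_dict          # without this return, saving the scaler silently stores None

    def load_state_dict(self, state_dict):
        self._scale = state_dict['scale']

a, b = ToyScaler(2**16), ToyScaler(1.0)
b.load_state_dict(a.state_dict())
assert float(b.scale) == 2**16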


@@ -3,6 +3,7 @@
 import torch
 from .base_grad_scaler import BaseGradScaler
+from typing import Optional
 
 __all__ = ['DynamicGradScaler']
@@ -10,12 +11,12 @@ __all__ = ['DynamicGradScaler']
 class DynamicGradScaler(BaseGradScaler):
 
     def __init__(self,
-                 initial_scale: int = 2**16,
-                 growth_factor: int = 2,
+                 initial_scale: float = 2**16,
+                 growth_factor: float = 2,
                  backoff_factor: float = 0.5,
                  growth_interval: int = 1000,
-                 min_scale: int = None,
-                 max_scale: int = None,
+                 min_scale: Optional[float] = None,
+                 max_scale: Optional[float] = None,
                  hysteresis: int = 2,
                  verbose: bool = False):
         super().__init__(initial_scale, verbose)
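
The annotation changes above only widen the declared types (scales are float-valued quantities, and a `None` default needs `Optional`); runtime behaviour is unchanged. A standalone sketch of why the stricter hints matter, built around a hypothetical `make_scaler` helper that is not part of the library:

from typing import Optional

def make_scaler(initial_scale: float = 2**16,
                min_scale: Optional[float] = None) -> float:
    # With `min_scale: int = None`, a strict type checker rejects the default;
    # Optional[float] makes both None and fractional floors legal.
    scale = float(initial_scale)
    if min_scale is not None:
        scale = max(scale, min_scale)
    return scale

print(make_scaler())                 # 65536.0
print(make_scaler(min_scale=0.5))    # 65536.0 (the floor only binds when the scale shrinks)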


@@ -358,8 +358,8 @@ class ShardedModelV2(nn.Module):
                 assert param.colo_attr.saved_grad.is_null(
                 ), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
-                param.colo_attr.reset_grad_payload(grad)
-                param.colo_attr.reset_data_payload(grad)    # release the memory of param
+                param.colo_attr.reset_grad_payload(grad.data)
+                param.colo_attr.reset_data_payload(grad.data)    # release the memory of param
 
                 if param.colo_attr.is_replicated:
                     param.colo_attr.sharded_data_tensor.is_sharded = True
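
One plausible reading of the `.data` change (not stated in the commit, so treat it as an assumption): storing `grad` directly can keep the tensor's autograd graph reachable, while `grad.data` is a detached view of the same storage. A standalone PyTorch illustration of that difference:

import torch

x = torch.randn(4, requires_grad=True)
# create_graph=True so the produced gradient (2 * x) carries a backward graph
g, = torch.autograd.grad((x ** 2).sum(), x, create_graph=True)

print(g.grad_fn is not None)               # True: holding g keeps the backward graph alive
print(g.data.grad_fn)                      # None: detached view, no graph reference
print(g.data.data_ptr() == g.data_ptr())   # True: .data does not copy the payload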


@@ -83,11 +83,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  min_scale: float = 1,
                  growth_factor: float = 2,
                  backoff_factor: float = 0.5,
-                 growth_interval: float = 1000,
-                 hysteresis: float = 2,
-                 max_scale: int = 2**32,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 verbose: bool = False) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
         super().__init__(optimizer)
@@ -115,14 +116,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                                      max_scale=max_scale)
         self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
+        self._verbose = verbose
 
         # Store fp32 param shards
         self._register_master_weight()
 
         if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
                                                                AutoTensorPlacementPolicy):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"')
-        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
-                           ranks=[0])
+        if self._verbose:
+            self._logger.debug(
+                f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])
 
         self._use_memory_tracer = self.model.use_memory_tracer
         if self._use_memory_tracer:
@@ -193,15 +197,20 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._point_param_fp16_to_master_param()
-        self._logger.debug(
-            f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
-            ranks=[0])
+        if self._verbose:
+            gpu_mem, cpu_mem = self.get_memory_usage()
+            self._logger.debug(
+                f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
+                ranks=[0])
         ret = self.optim.step(*args, **kwargs)
-        self._logger.debug(
-            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
-            ranks=[0])
+        if self._verbose:
+            gpu_mem, cpu_mem = self.get_memory_usage()
+            self._logger.debug(
+                f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
+                ranks=[0])
         self._copy_master_model_to_model_fp16()
         return ret
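
The two hunks above put the memory reports behind the new `verbose` flag and call `get_memory_usage()` once per report instead of twice. A standalone sketch of the same gating pattern, with stand-in names (`step_fn`, `get_memory_usage`) rather than the optimizer's internals:

import logging

logger = logging.getLogger("sketch")

def step_with_optional_report(step_fn, get_memory_usage, verbose: bool):
    # Measure and format only when verbose, so non-verbose runs pay no logging cost.
    if verbose:
        gpu_mem, cpu_mem = get_memory_usage()
        logger.debug("Before step: %.1f MB CUDA, %.1f MB CPU", gpu_mem / 1e6, cpu_mem / 1e6)
    ret = step_fn()
    if verbose:
        gpu_mem, cpu_mem = get_memory_usage()
        logger.debug("After step: %.1f MB CUDA, %.1f MB CPU", gpu_mem / 1e6, cpu_mem / 1e6)
    return ret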


@@ -5,18 +5,13 @@ from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
 from .tensorful_state import StatefulTensor, TensorState
 from typing import List
 
-# use this tensor as empty data point for parameters
-# we do not want users use param.data when its torch payload is removed
-# empty tensor is expected to raise error when get used
-FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
-
 EMPTY_TENSOR_DICT = {}
 
 
 def get_empty_tensor(device: torch.device, dtype: torch.dtype):
     key = (device, dtype)
     if key not in EMPTY_TENSOR_DICT:
-        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
+        EMPTY_TENSOR_DICT[key] = torch.empty(0, dtype=dtype, device=device)
     return EMPTY_TENSOR_DICT[key]
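
The placeholder cache keeps the same shape after the change: one zero-element tensor per `(device, dtype)` pair, now built directly with `torch.empty` instead of converting a module-level `BoolTensor`. A self-contained rendering of that pattern for illustration (an example, not the library's module):

import torch

_EMPTY_TENSORS = {}

def get_empty_tensor(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # One cached zero-element placeholder per (device, dtype); it owns no payload.
    key = (device, dtype)
    if key not in _EMPTY_TENSORS:
        _EMPTY_TENSORS[key] = torch.empty(0, dtype=dtype, device=device)
    return _EMPTY_TENSORS[key]

t = get_empty_tensor(torch.device('cpu'), torch.float16)
print(t.numel(), t.numel() * t.element_size())                    # 0 elements, 0 bytes
assert t is get_empty_tensor(torch.device('cpu'), torch.float16)  # returned from the cache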


@@ -72,23 +72,13 @@ def run_stm():
     # warmup done
     # only 2 params can be on CUDA
-    limit_cuda_memory(0.26)
+    limit_cuda_memory(0.26 / tensor_placement_policy._steady_cuda_cap_ratio)
     # use OPT-like eviction strategy
     apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.finish_collection()
 
 
 def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
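
A hedged note on the new memory limit in the test: if `_steady_cuda_cap_ratio` is the fraction of the given limit that the placement policy actually fills in steady state, dividing by that ratio keeps the effective budget at 0.26 of device memory. The ratio value below is an assumption for illustration, not taken from the repository:

# Hypothetical arithmetic behind limit_cuda_memory(0.26 / _steady_cuda_cap_ratio)
steady_cuda_cap_ratio = 0.8                     # assumed value
requested = 0.26 / steady_cuda_cap_ratio        # what limit_cuda_memory receives
effective = requested * steady_cuda_cap_ratio   # what the policy actually fills
assert abs(effective - 0.26) < 1e-12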