diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py index 98bd2927a..23b94c7d5 100644 --- a/colossalai/utils/memory_tracer/model_data_memtracer.py +++ b/colossalai/utils/memory_tracer/model_data_memtracer.py @@ -1,18 +1,11 @@ from colossalai.context.singleton_meta import SingletonMeta from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage import torch from typing import Union, Tuple, Optional from colossalai.logging import DistributedLogger -def _col_tensor_mem_usage(t: Union[torch.Tensor, ShardedTensor]) -> int: - if isinstance(t, ShardedTensor): - target = t.payload - else: - target = t - return target.numel() * target.element_size() - - def col_model_data_mem_usage(model: torch.nn.Module) -> Tuple[int, int]: """ Trace the model memory usage. diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py index 383f386fa..d5fe73933 100644 --- a/colossalai/utils/memory_utils/utils.py +++ b/colossalai/utils/memory_utils/utils.py @@ -1,12 +1,32 @@ +from psutil import cpu_count import torch from colossalai.utils import get_current_device from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from typing import Union +from typing import Tuple, Union _GLOBAL_CUDA_MEM_FRACTION = 1.0 +def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]: + if isinstance(tensor, ShardedTensor): + t = tensor.payload + elif isinstance(tensor, torch.Tensor): + t = tensor + else: + return 0, 0 + + cuda_use, cpu_use = 0, 0 + + mem_use = t.numel() * t.element_size() + if t.device.type == 'cuda': + cuda_use += mem_use + elif t.device.type == 'cpu': + cpu_use += mem_use + + return cuda_use, cpu_use + + def colo_set_process_memory_fraction(ratio: float) -> None: """colo_set_process_memory_fraction diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index 3ba5fa4bd..2c916434f 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import Dict, Optional +from os import stat +from typing import Dict, Optional, Tuple import torch import torch.distributed as dist @@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from torch.optim import Optimizer from colossalai.zero.sharded_optim._utils import has_inf_or_nan -from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move +from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage class OptimState(Enum): @@ -26,14 +27,20 @@ class OptimState(Enum): class ShardedOptimizerV2(ColossalaiOptimizer): """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO). + By default the ZeRO optimizer stage 3 offload Optimizer States on CPU. + We apply the Device-aware Operator Placement technique for OS placement from the following paper. + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management https://arxiv.org/abs/2108.05818 + GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory, which is detected by a runtime memory tracer. + We place as many OS chunks in the margin space as possible. - The size of margin space can be controlled by `gpu_margin_mem_ratio` + + The size of margin space can be controlled by `gpu_margin_mem_ratio`。 If it is set as 0.0, it is the same as classical ZeRO optimizer. NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`. @@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): hysteresis=hysteresis, max_scale=max_scale) self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device()) - self._logger = get_dist_logger() + self._logger = get_dist_logger("ShardedOptimizerV2") # Store fp32 param shards self.master_params: Dict[Parameter, Tensor] = {} @@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # So we gather here self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group) + self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!", + ranks=[0]) + + def get_memory_usage(self) -> Tuple[int, int]: + """ + Get the memory usage of the optimizer. Including master_params (param fp32), + momentum (self.state[p]['exp_avg']) variance (self.state[p]['exp_avg_sq']) + + Returns: + Tuple[int, int]: cuda/cpu memory usage in Byte. + """ + cuda_use = 0 + cpu_use = 0 + + def update_mem_use(t): + nonlocal cuda_use + nonlocal cpu_use + t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t) + cuda_use += t_cuda_use + cpu_use += t_cpu_use + + for _, p_fp32 in self.master_params.items(): + update_mem_use(p_fp32) + for group in self.optim.param_groups: + for p in group['params']: + state = self.optim.state[p] + for k, v in state.items(): + update_mem_use(v) + + return cuda_use, cpu_use + def step(self, *args, **kwargs): self._maybe_move_fp32_shards() @@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self.grad_scaler.update(found_inf) if found_inf: - self._logger.info('found inf during ShardedOptimV2 step') + self._logger.warning('found inf during ShardedOptimV2 step') self.zero_grad() return @@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # Now p.data is sharded # So optimizer states are sharded naturally + self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!", + ranks=[0]) + ret = self.optim.step(*args, **kwargs) + self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!", + ranks=[0]) # Copy master param data (fp32) to payload of col_attr (fp16) # TODO() improve efficiency by gathering tensors into a chunk and transfering # a chunk. diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index e73f6cc7c..64ef16555 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist from colossalai.zero.sharded_param import ShardedTensor from typing import Optional, Tuple +from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage class ShardedParamV2(object): @@ -55,10 +56,9 @@ class ShardedParamV2(object): assert isinstance(t, torch.Tensor) nonlocal cuda_mem_use nonlocal cpu_mem_use - if t.device.type == 'cpu': - cpu_mem_use += t.numel() * t.element_size() - elif t.device.type == 'cuda': - cuda_mem_use += t.numel() * t.element_size() + t_cuda, t_cpu = colo_tensor_mem_usage(t) + cuda_mem_use += t_cuda + cpu_mem_use += t_cpu address_set = set() _update_mem_use(self.sharded_data_tensor.payload) diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py index 5cb5ddae6..76669e94d 100644 --- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py +++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py @@ -1,6 +1,7 @@ from functools import partial import colossalai +from colossalai.utils.cuda import get_current_device import pytest import torch import torch.distributed as dist @@ -57,11 +58,12 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - with ZeroInitContext(convert_fp16=True, - target_device=torch.device(f'cpu:0'), - shard_strategy=shard_strategy, - shard_param=True, - rm_torch_payload_on_the_fly=False): + with ZeroInitContext( + convert_fp16=True, + target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'), + shard_strategy=shard_strategy, + shard_param=True, + rm_torch_payload_on_the_fly=False): zero_model = model_builder(checkpoint=True) zero_model = ShardedModelV2(zero_model, shard_strategy,