[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitor os

* polish unit test

* polish unit test

* add assert
pull/1098/head
ver217 2022-06-10 14:48:28 +08:00 committed by GitHub
parent 2b2dc1c86b
commit 1f894e033f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 366 additions and 12 deletions

View File

@ -0,0 +1,113 @@
import functools
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
from colossalai.tensor.chunk import Chunk, ChunkManager
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
class GeminiManager:
"""
Stateful Tensor Manager, inspired from PatrickStar
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
# TODO: remove assert
assert placement_policy == 'cuda', 'placement_policy can only be "cuda" now'
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1
self._cpu_gpu_move_volume = 0
self._layout_time = 0
self._evict_time = 0
self._warmup = True
def pre_iter(self):
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.start_collection()
def post_iter(self):
"""This function must be called when each iteration finishes
"""
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection()
self._warmup = False
self._compute_idx = -1
self._cpu_gpu_move_volume = 0
self._layout_time = 0
self._evict_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
self._layout_time += time() - start
vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list,
cuda_demand=cuda_demand,
warmup=self._warmup,
compute_list=self._compute_list,
compute_idx=self._compute_idx)
self._cpu_gpu_move_volume += vol
self._evict_time += evict_time
# move COMPUTE tensors to CUDA
self._cpu_gpu_move_volume += cuda_demand
@property
def cpu_gpu_move_volume(self):
return self._cpu_gpu_move_volume
# @functools.lru_cache(maxsize=None)
# TODO: test lru
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
cuda_demand = 0
for chunk in chunks:
if chunk.device_type == 'cpu' or chunk.is_free:
cuda_demand += chunk.mem
can_evict_chunks = []
for chunk in self._chunk_manager.chunk_groups[group_name]:
if not chunk.is_free and chunk.device_type == 'cuda' and chunk.can_move_device:
can_evict_chunks.append(chunk)
return cuda_demand, can_evict_chunks
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
self._compute_idx += 1
if self._warmup and self._placement_policy.need_mem_stats:
self._compute_list.append(chunks)
@property
def default_device(self):
return self._placement_policy.get_default_device()
def sample_overall_data(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data()
def sample_model_data(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_model_data()
@property
def chunk_manager(self):
return self._chunk_manager
@property
def cuda_margin_mem(self) -> Optional[float]:
if self._mem_stats_collector:
return self._mem_stats_collector.cuda_margin_mem
return None
@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats

View File

@ -1,5 +1,6 @@
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.tensor import ChunkManager
@ -145,3 +146,7 @@ class MemStatsCollectorV2(MemStatsCollector):
cpu_mem = self._chunk_manager.total_mem['cpu']
self._model_data_cuda_list.append(cuda_mem)
self._model_data_cpu_list.append(cpu_mem)
@property
def cuda_margin_mem(self) -> float:
return colo_device_memory_capacity(get_current_device()) - max(self.overall_mem_stats('cuda'))

View File

@ -0,0 +1,157 @@
from abc import ABC, abstractmethod
from time import time
from typing import List, Optional, Tuple, Dict
import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import Type
import functools
from colossalai.tensor.chunk import Chunk, ChunkManager
class PlacementPolicy(ABC):
need_mem_stats: bool = False
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector
@abstractmethod
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> None:
raise NotImplementedError
@staticmethod
def get_default_device() -> torch.device:
return torch.device('cpu')
class CPUPlacementPolicy(PlacementPolicy):
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
volume = 0
for chunk in can_evict_chunks:
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.mem
return volume, 0
class CUDAPlacementPolicy(PlacementPolicy):
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
return 0, 0
@staticmethod
def get_default_device() -> torch.device:
return get_current_device()
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
# model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
# TODO(ver217): make these args configurable
self._warmup_non_model_data_ratio: float = 0.8
self._steady_cuda_cap_ratio: float = 0.9
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: List[Tuple[Chunk, ...]] = [],
compute_idx: int = 0,
**kwargs) -> int:
"""
Evict tensors from CUDA device.
Args:
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
compute_idx (int, optional): the idx of computing device. Defaults to 0.
Raises:
RuntimeError:
Returns:
int: the volume of memory that is evicted
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device())
used_cuda_model_data = self.chunk_manager.total_mem['cuda']
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0
end = time()
if avail_cuda_model_data < cuda_demand:
# Move cuda_demand - avail_cuda_model_data volume of tensors
# to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
to_free_chunks = can_evict_chunks
if not warmup:
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
# print(self._sort_can_evict_chunks.cache_info())
end = time()
for chunk in to_free_chunks:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
freed_cuda_model_data += chunk.mem
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
)
return freed_cuda_model_data, end - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
class PlacementPolicyFactory:
policies: Dict[str, PlacementPolicy] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy
}
@staticmethod
def create(policy_name: str) -> Type[PlacementPolicy]:
if policy_name not in PlacementPolicyFactory.policies:
raise TypeError(f"Unknown tensor placement policy {policy_name}")
return PlacementPolicyFactory.policies[policy_name]
@staticmethod
def get_polocy_names():
return tuple(PlacementPolicyFactory.policies.keys())
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
policy_cls = PlacementPolicyFactory.create(policy_name)
return policy_cls.get_default_device()

View File

@ -4,8 +4,11 @@ from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
from colossalai.tensor.param_op_hook import use_param_op_hooks
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict
from colossalai.logging import get_dist_logger
def free_storage(data: torch.Tensor) -> None:
@ -89,12 +92,14 @@ class ColoDDP(torch.nn.Module):
class ColoDDPV2(ColoDDP):
def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
super().__init__(module)
self.chunk_manager = chunk_manager
self.param_op_hook = ZeROHookV2(chunk_manager)
self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager
self.param_op_hook = ZeROHookV2(gemini_manager)
self.fp32_params = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
# TODO: get param order and filter unused params
for p in module.parameters():
assert p.dtype == torch.half
@ -102,22 +107,32 @@ class ColoDDPV2(ColoDDP):
self.chunk_manager.append_tensor(p, 'fp16_param')
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter()
with use_param_op_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
self.chunk_manager.exec_lazy_release()
return outputs
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
def _setup_grads_ptr(self):
for p in self.module.parameters():
if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
p.grad = None
else:
p.grad = p.data
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
self._setup_grads_ptr()
self._logger.info(
f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
)
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
loss.backward()
@ -141,7 +156,12 @@ class ColoDDPV2(ColoDDP):
self.chunk_manager.release_chunk(chunk)
if reduced and not chunk.is_free:
self.overflow_counter += chunk.has_inf_or_nan
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device

View File

@ -178,6 +178,9 @@ class Chunk:
def __eq__(self, __o: object) -> bool:
return self is __o
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
class ChunkManager:
@ -234,6 +237,10 @@ class ChunkManager:
def access_chunk(self, chunk: Chunk) -> None:
if chunk in self.accessed_chunks:
if chunk.device_type != 'cuda':
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(get_current_device())
self.total_mem[chunk.device_type] += chunk.mem
return
if not chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem

View File

@ -5,6 +5,7 @@ from enum import Enum
from typing import List
from contextlib import contextmanager
from functools import partial
from colossalai.gemini.gemini_mgr import GeminiManager
class TrainingPhase(Enum):
@ -14,9 +15,10 @@ class TrainingPhase(Enum):
class ZeROHookV2(ParamOpHook):
def __init__(self, chunk_manager: ChunkManager) -> None:
def __init__(self, gemini_manager: GeminiManager) -> None:
super().__init__()
self._chunk_manager = chunk_manager
self._gemini_manager = gemini_manager
self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
@ -24,9 +26,11 @@ class ZeROHookV2(ParamOpHook):
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks, 'fp16_param')
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
self._gemini_manager.sample_model_data()
def post_op(self, params):
for p in params:

View File

@ -7,6 +7,7 @@ from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import get_current_device, disposable
class OptimState(Enum):
@ -19,6 +20,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def __init__(self,
optim: Optimizer,
module: ColoDDPV2,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
@ -29,6 +31,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
super().__init__(optim)
assert isinstance(module, ColoDDPV2)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager = self.gemini_manager.chunk_manager
self.optim_state = OptimState.UNSCALED
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
for p, fp32_p in zip(module.parameters(), module.fp32_params):
@ -45,6 +49,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
# and it must set `num_fp32_shards_per_param` correctly
self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
optim, 'num_fp32_shards_per_param', 0) >= 2
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
self._register_states = disposable(self._register_states_)
def _update_params_ptr(self):
for group in self.optim.param_groups:
for p in group['params']:
@ -82,6 +98,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
@ -94,6 +111,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
return
self._update_params_ptr()
ret = self.optim.step(*args, **kwargs)
self._register_states()
self._update_fp16_params()
return ret
@ -109,3 +127,29 @@ class ZeroOptimizer(ColossalaiOptimizer):
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
self.module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self):
if self._should_move_fp32_params_h2d:
self._should_move_fp32_params_h2d = False
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
fp32_params_used_cuda_margin_mem = 0
for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'],
self.chunk_manager.chunk_groups['fp32_param']):
if fp32_param_chunk.is_free:
continue
if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem:
self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device())
# stores grad now
self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device())
self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device())
fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem
self.module._setup_grads_ptr()
def _register_states_(self):
for group in self.optim.param_groups:
for p in group['params']:
state = self.optim.state[p]
for val in state.values():
if isinstance(val, torch.Tensor):
self.chunk_manager.add_extern_static_tensor(val)

View File

@ -14,6 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
def check_param_equal(model, torch_model):
@ -44,7 +45,8 @@ def run_gpt(use_chunk, use_zero):
model = model.half()
chunk_size = 38 * 1024**2 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
model = ColoDDPV2(model, chunk_manager)
gemini_manager = GeminiManager('cuda', chunk_manager)
model = ColoDDPV2(model, gemini_manager)
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
print(chunk_manager)
check_param_equal(model, torch_model)

View File

@ -18,6 +18,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
def check_param_equal(model, torch_model):
@ -53,7 +54,8 @@ def run_gpt(use_chunk, use_zero):
chunk_size = 38 * 1024**2 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
model = ColoDDPV2(model, chunk_manager)
gemini_manager = GeminiManager('cuda', chunk_manager)
model = ColoDDPV2(model, gemini_manager)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=32)