
[zero] add constant placement policy (#1705)

* fixes a memory leak in ZeroDDP initialization when a parameter is already in fp16.
* forbids releasing chunks that stay in CUDA; a chunk may only be released when it is about to be offloaded.
* adds a constant placement policy, which lets users reserve a fixed CUDA caching memory budget for parameters (a usage sketch follows below).
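For context, a rough usage sketch of the new policy. Only `ConstPlacementPolicy.set_const_memory_boundary()` and the `'const'` factory key appear in this diff; the `ChunkManager`, `GeminiManager` and `ZeroDDP` constructor calls below are assumptions inferred from the touched tests and may not match the exact API.

```python
# Hedged sketch, not code from this PR: wiring the new 'const' policy into Gemini.
# `model` is assumed to be a torch.nn.Module whose parameters are ColoParameters.
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.placement_policy import ConstPlacementPolicy
from colossalai.nn.parallel import ZeroDDP

# Reserve a fixed 1 GiB CUDA cache for parameter chunks (the default boundary is 512 MiB).
ConstPlacementPolicy.set_const_memory_boundary(1024)

chunk_config = ...                                       # e.g. produced by search_chunk_configuration(...)
chunk_manager = ChunkManager(chunk_config)               # assumed constructor signature
gemini_manager = GeminiManager('const', chunk_manager)   # 'const' joins 'cpu', 'cuda' and 'auto'
model = ZeroDDP(model, gemini_manager, pin_memory=True)  # assumed constructor signature
```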
HELSON, 2 years ago (committed by GitHub), commit 1468e4bcfc
Changed files:

1. colossalai/context/moe_context.py (5 changes)
2. colossalai/gemini/chunk/chunk.py (4 changes)
3. colossalai/gemini/chunk/manager.py (47 changes)
4. colossalai/gemini/gemini_mgr.py (8 changes)
5. colossalai/gemini/placement_policy.py (71 changes)
6. colossalai/nn/parallel/data_parallel.py (13 changes)
7. colossalai/utils/model/colo_init_context.py (9 changes)
8. colossalai/zero/utils/zero_hook_v2.py (4 changes)
9. tests/test_gemini/update/test_fwd_bwd.py (10 changes)
10. tests/test_gemini/update/test_optim.py (2 changes)
11. tests/test_gemini/update/test_zeroddp_state_dict.py (1 change)

colossalai/context/moe_context.py (5 changes)

@@ -55,7 +55,6 @@ class MoeContext(metaclass=SingletonMeta):
return self.has_setup
def setup(self, seed: int, use_kernel_optim: bool = True):
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
_check_sanity()
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
@@ -93,8 +92,8 @@ class MoeContext(metaclass=SingletonMeta):
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number"\
" is not a multiple of ep size or vice versa."
assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
" is not a multiple of ep size or vice versa."
# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts

colossalai/gemini/chunk/chunk.py (4 changes)

@@ -51,6 +51,8 @@ def alloc_storage(tensor: torch.Tensor) -> None:
class Chunk:
_total_number = 0
def __init__(self,
chunk_size: int,
process_group: ColoProcessGroup,
@@ -73,6 +75,8 @@ class Chunk:
keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
"""
self.count_id = Chunk._total_number
Chunk._total_number += 1
self.chunk_size = chunk_size
self.utilized_size = 0

colossalai/gemini/chunk/manager.py (47 changes)

@@ -28,7 +28,7 @@ class ChunkManager:
self.chunk_groups: Dict[str, Deque] = dict()
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
self.accessed_chunks: Set[Chunk] = set()
self.lazy_release_tensors: List[torch.Tensor] = list()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None:
@@ -91,9 +91,8 @@ class ChunkManager:
self.__sub_memroy_usage(chunk.memory_usage)
if chunk.device_type == 'cpu':
chunk.shard_move(get_current_device())
chunk.access_chunk()
self.__add_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
self.accessed_chunks.add(chunk)
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA.
@@ -102,9 +101,8 @@ class ChunkManager:
return
if chunk.can_release:
self.__sub_memroy_usage(chunk.memory_usage)
chunk.release_chunk()
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
self.accessed_chunks.remove(chunk)
def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None:
"""Move the shard of the chunk to the target device.
@@ -128,6 +126,7 @@ class ChunkManager:
return False
self.__sub_memroy_usage(chunk.memory_usage)
chunk.reduce()
self.__sub_accessed_chunk(chunk)
self.__add_memory_usage(chunk.memory_usage)
return True
@@ -151,31 +150,15 @@ class ChunkManager:
"""
return self.tensor_chunk_map[tensor]
def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
"""
Add tensors to the buffer for lazy release.
Args:
tensors (List[torch.Tensor]): the tensors to be released lazily
"""
self.lazy_release_tensors.extend(tensors)
def exec_lazy_release(self) -> None:
def get_cuda_movable_chunks(self) -> List[Chunk]:
"""
Execute release for tensors added to the lazy release buffer.
Get all chunks that can be moved.
"""
for chunk in self.get_chunks(self.lazy_release_tensors):
self.release_chunk(chunk)
self.lazy_release_tensors.clear()
def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]:
chunk_list = []
for group_name in self.chunk_groups:
if group_type in group_name:
for chunk in self.chunk_groups[group_name]:
if chunk.device_type == 'cuda' and chunk.can_move:
chunk_list.append(chunk)
for chunk in self.accessed_chunks:
if chunk.can_release:
chunk_list.append(chunk)
chunk_list.sort(key=lambda x: x.count_id)
return chunk_list
def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
@@ -235,3 +218,13 @@ class ChunkManager:
def __add_memory_usage(self, usage: Dict[str, int]):
for k, v in usage.items():
self.total_mem[k] += v
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
self.accessed_chunks.remove(chunk)
self.accessed_mem -= chunk.chunk_mem
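The lazy-release buffer removed above is replaced by explicit accounting: `__add_accessed_chunk` and `__sub_accessed_chunk` keep `accessed_mem` equal to the total memory of currently gathered chunks, which is the invariant the new `assert self.chunk_manager.accessed_mem == 0` in `ZeroDDP._post_backward` checks. A toy counter (illustrative only, not ColossalAI code) showing the idea:

```python
from collections import namedtuple

# Toy stand-ins for Chunk objects; only chunk_mem matters for the bookkeeping.
ToyChunk = namedtuple('ToyChunk', ['name', 'chunk_mem'])

class ToyAccessTracker:
    """Mimics the accessed_chunks / accessed_mem bookkeeping added in this PR."""

    def __init__(self):
        self.accessed_chunks = set()
        self.accessed_mem = 0

    def access(self, chunk):
        if chunk not in self.accessed_chunks:
            self.accessed_chunks.add(chunk)
            self.accessed_mem += chunk.chunk_mem

    def release(self, chunk):
        if chunk in self.accessed_chunks:
            self.accessed_chunks.remove(chunk)
            self.accessed_mem -= chunk.chunk_mem

tracker = ToyAccessTracker()
a, b = ToyChunk('a', 64), ToyChunk('b', 128)
tracker.access(a); tracker.access(b)
tracker.release(a); tracker.release(b)
assert tracker.accessed_mem == 0   # every accessed chunk was released after "backward"
```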

colossalai/gemini/gemini_mgr.py (8 changes)

@@ -56,14 +56,14 @@ class GeminiManager:
self._evict_time = 0
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None:
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._layout_time += time() - start
vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
@@ -78,7 +78,7 @@ class GeminiManager:
self._h2d_volume += cuda_demand
@functools.lru_cache(maxsize=None)
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str):
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
start = time()
cuda_demand = 0
for chunk in chunks:
@@ -93,7 +93,7 @@ class GeminiManager:
raise RuntimeError
self._comp_cuda_demand_time += time() - start
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type)
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
return cuda_demand, can_evict_chunks
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:

colossalai/gemini/placement_policy.py (71 changes)

@@ -36,8 +36,9 @@ class CPUPlacementPolicy(PlacementPolicy):
volume = 0
start = time()
for chunk in can_evict_chunks:
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.shard_mem
volume += chunk.chunk_mem
return volume, time() - start
@@ -116,8 +117,9 @@ class AutoPlacementPolicy(PlacementPolicy):
if freed_cuda_model_data >= to_free_cuda_model_data:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_cuda_model_data += chunk.shard_mem
freed_cuda_model_data += chunk.chunk_mem
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}")
@@ -147,11 +149,74 @@ class AutoPlacementPolicy(PlacementPolicy):
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
class ConstPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = False
_accessed_memory_boundary = 512 * 1024**2
def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
"""
See the docstrings in the class `AutoPlacementPolicy`.
"""
start = time()
used_accessed_memory = self.chunk_manager.accessed_mem
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
freed_accessed_memory = 0
if avail_accessed_memory < cuda_demand:
to_free_memory = cuda_demand - avail_accessed_memory
to_free_chunks = can_evict_chunks
if not warmup:
# sort all chunks
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
for chunk in to_free_chunks:
if freed_accessed_memory >= to_free_memory:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_accessed_memory += chunk.chunk_mem
if freed_accessed_memory < to_free_memory:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_memory}, freed {freed_accessed_memory}")
return freed_accessed_memory, time() - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
boundary = int(cuda_memory_mb * 1024**2)
assert boundary > 0
ConstPlacementPolicy._accessed_memory_boundary = boundary
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy
'auto': AutoPlacementPolicy,
'const': ConstPlacementPolicy
}
@staticmethod
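The eviction order used by `_sort_can_evict_chunks` above is a Belady-style heuristic: among evictable chunks, the one whose next use in the recorded compute order lies furthest in the future is evicted first (chunks with no future use at all sort to the very front). A small standalone illustration with toy chunk names (not ColossalAI objects):

```python
def sort_by_next_use(can_evict, compute_idx, compute_list):
    # Chunks never used again keep the sentinel len(compute_list) ("infinitely far away").
    next_use = {chunk: len(compute_list) for chunk in can_evict}
    # Walk the recorded compute order backwards, stopping above the current step.
    for i in range(len(compute_list) - 1, compute_idx, -1):
        for chunk in compute_list[i]:
            if chunk in next_use:
                next_use[chunk] = i
    # Largest next-use index first: those chunks are evicted first.
    return [chunk for chunk, _ in sorted(next_use.items(), key=lambda p: p[1], reverse=True)]

compute_list = [('A',), ('B',), ('A', 'C'), ('B',)]        # chunks touched at each compute step
print(sort_by_next_use(['A', 'B', 'C'], 1, compute_list))  # ['B', 'A', 'C']: 'B' is needed latest
```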

colossalai/nn/parallel/data_parallel.py (13 changes)

@@ -224,14 +224,16 @@ class ZeroDDP(ColoDDP):
# TODO: get param order and filter unused params
for p in module.parameters():
assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False):
p.data = p.half()
p.data = p.data.half()
continue
dp_world_size = p.process_group.dp_world_size()
fp32_data = p.float().data
p.data = p.half()
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
p.data = p.data.half()
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory)
self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory)
self.fp32_params.append(fp32_p)
@@ -253,7 +255,6 @@ class ZeroDDP(ColoDDP):
self.gemini_manager.pre_iter()
with ParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
self.chunk_manager.exec_lazy_release()
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
@@ -265,7 +266,7 @@ class ZeroDDP(ColoDDP):
p.grad = None
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
assert self.chunk_manager.accessed_mem == 0
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
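On the fp16 initialization fix above: when a parameter is already fp16, `Tensor.half()` returns the tensor itself, so the old `p.data = p.half()` effectively assigned the parameter to its own `.data`, and `p.float()` produced an autograd-tracked copy. Routing both casts through `.data` avoids that. A small plain-PyTorch check (my reading of the fix; the commit message only states that fp16 parameters leaked):

```python
import torch

p = torch.nn.Parameter(torch.randn(4).half())   # a parameter that is already fp16

print(p.half() is p)                    # True  -> old code effectively did `p.data = p`
print(p.float().grad_fn is None)        # False -> `.float()` on the parameter is tracked by autograd
print(p.data.float().grad_fn is None)   # True  -> the new fp32 master copy is a plain detached tensor
```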

colossalai/utils/model/colo_init_context.py (9 changes)

@@ -33,7 +33,10 @@ def ColoModulize(module):
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
def __init__(self,
lazy_memory_allocate: bool = False,
device: torch.device = torch.device('cpu'),
dtype: torch.dtype = torch.float):
"""
Args:
lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
@@ -42,6 +45,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate
self._device = device
self._dtype = dtype
self._register_colo_modules()
@@ -87,7 +91,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# TODO(jiaruifang) we initialize a Default PG memory
colo_param = ColoParameter(param.to(self._device), requires_grad=requires_grad)
colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
requires_grad=requires_grad)
# add mapping record
replaced_tensors[param] = colo_param
delattr(submodule, param_name)
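The new `dtype` argument lets a model be materialized directly in half precision under the context, which is exactly the fp16-parameter case the ZeroDDP changes above now handle. A minimal usage sketch (assuming a CUDA device is available; the linear layer is just a placeholder module):

```python
import torch
import torch.nn as nn
from colossalai.utils.model.colo_init_context import ColoInitContext

# Every parameter is created as a ColoParameter on CUDA and cast to fp16 at init time.
with ColoInitContext(device=torch.device('cuda'), dtype=torch.half):
    model = nn.Linear(1024, 1024)

print(next(model.parameters()).dtype)   # torch.float16
```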

colossalai/zero/utils/zero_hook_v2.py (4 changes)

@@ -26,9 +26,8 @@ class ZeROHookV2(ParamOpHook):
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks, 'fp16_param')
self._gemini_manager.adjust_layout(chunks)
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
self._gemini_manager.sample_model_data()
@@ -38,7 +37,6 @@ class ZeROHookV2(ParamOpHook):
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)
self._chunk_manager.add_lazy_release_tensors(params)
def pre_forward(self, params: List[torch.Tensor]) -> None:
self.pre_op(params)

tests/test_gemini/update/test_fwd_bwd.py (10 changes)

@@ -11,18 +11,14 @@ from functools import partial
from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
from colossalai.tensor import ProcessGroup
from tests.test_tensor.common_utils import debug_print
from time import time
from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
@@ -44,7 +40,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')

tests/test_gemini/update/test_optim.py (2 changes)

@@ -48,7 +48,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
return logits
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
def exam_gpt_fwd_bwd(placement_policy):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2')

tests/test_gemini/update/test_zeroddp_state_dict.py (1 change)

@@ -12,7 +12,6 @@ from functools import partial
from tests.test_tensor.common_utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.gemini.gemini_mgr import GeminiManager
from tests.test_tensor.common_utils import debug_print
