[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)

* polish chunk manager

* polish unit test

* impl add_extern_static_tensor for chunk mgr

* add mem stats collector v2

* polish code

* polish unit test

* polish code

* polish get chunks
Author: ver217 · Committed via GitHub on 2022-06-09 20:56:34 +08:00
Commit: be01db37c8 (parent b3a03e4bfd)
6 changed files with 68 additions and 31 deletions


@@ -1,6 +1,7 @@
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.tensor import ChunkManager
 import torch
 import time
@@ -128,3 +129,19 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
+
+
+class MemStatsCollectorV2(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
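For orientation, a minimal usage sketch of the new collector (hedged: start_collection()/finish_collection() are assumed from the existing MemStatsCollector interface, the import path is assumed, and ChunkManager is expected to be constructed inside an initialized colossalai context):

import torch
from colossalai.tensor import ChunkManager
from colossalai.gemini.memory_tracer import MemStatsCollectorV2   # import path assumed

# The V2 collector reads model data from the chunk manager's per-device byte
# counters (total_mem) instead of from StatefulTensor bookkeeping.
chunk_manager = ChunkManager(None, enable_distributed_storage=False)   # None disables chunking, as in the unit test
collector = MemStatsCollectorV2(chunk_manager)

collector.start_collection()    # assumed base-class method that sets the sampling flag
collector.sample_model_data()   # appends total_mem['cuda'] and total_mem['cpu'] to the model-data lists
collector.finish_collection()   # assumed base-class counterpart of start_collection()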


@@ -60,7 +60,7 @@ class ColoDDP(torch.nn.Module):
             else:
                 ColoDDP._save_grad(p, grad)
             return empty_grad
         else:
             group = gpc.get_cpu_group(ParallelMode.DATA)
             dist.all_reduce(grad, group=group)
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
-            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data
@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
             grad = grad / self.dp_world_size
             self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
             chunk = self.chunk_manager.get_chunk(p)
-            reduced = self.chunk_manager.reduce_chunk(p)
-            self.chunk_manager.release_chunk(p)
+            reduced = self.chunk_manager.reduce_chunk(chunk)
+            self.chunk_manager.release_chunk(chunk)
             if reduced and not chunk.is_free:
                 self.overflow_counter += chunk.has_inf_or_nan
         return empty_grad


@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
 
 class ChunkManager:
@@ -226,8 +232,7 @@ class ChunkManager:
         src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
@@ -248,8 +252,7 @@ class ChunkManager:
         if chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add extern static tensor to chunk manager.
+        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
+        They are "static", which means their shape, dtype, device never change.
+        Thus, their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
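Two notes on the new helpers: Chunk now hashes and compares by identity, so chunks can be deduplicated in sets and lists (accessed_chunks, get_chunks) even though they wrap tensors, and get_chunks(tensors) returns each backing chunk exactly once, which is what exec_lazy_release and ZeROHookV2.pre_op rely on. A minimal sketch of add_extern_static_tensor (hedged: assumes ChunkManager is constructed inside an initialized colossalai context, as in the unit test below):

import torch
from colossalai.tensor import ChunkManager

chunk_manager = ChunkManager(None, enable_distributed_storage=False)   # None disables chunking, as in the unit test

# An "extern static" tensor: its shape/dtype/device never change, so its footprint
# is constant. It is not chunk-managed, only counted in total_mem so that
# MemStatsCollectorV2 reports it as model data.
exp_avg = torch.zeros(1024)   # hypothetical fp32 optimizer state buffer
chunk_manager.add_extern_static_tensor(exp_avg)

# only the bookkeeping changes; the tensor itself is untouched
assert chunk_manager.total_mem['cpu'] == exp_avg.numel() * exp_avg.element_size()   # 1024 * 4 = 4096 bytes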


@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._chunk_manager.exec_lazy_release()
         # TODO: evict chunks
-        for p in params:
-            self._chunk_manager.access_chunk(p)
+        for chunk in chunks:
+            self._chunk_manager.access_chunk(chunk)
 
     def post_op(self, params):
         for p in params:


@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
+                if not self.module.chunk_manager.get_chunk(p).is_free:
                     p.data = self.fp16_param_to_fp32_param[p]
                 else:
                     assert p.grad is None


@@ -32,7 +32,7 @@ HAS_TENSORS = {
     }
 }
 
-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
 
 
 @parameterize('use_chunk', [False, True])
@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
     rank = gpc.get_local_rank(ParallelMode.DATA)
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(32, 32) for _ in range(3)]
-    chunk_size = 2048 if use_chunk else None
+    params = [torch.rand(8, 8) for _ in range(3)]
+    chunk_size = 128 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    for p in params:
-        chunk_manager.access_chunk(p)
+    chunks = chunk_manager.get_chunks(params)
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
     check_has_params(params, [True, True, True])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for p in params:
-        chunk_manager.release_chunk(p)
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for p in params:
-        chunk_manager.move_chunk(p, torch.device('cpu'))
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
     assert chunk_manager.total_mem['cuda'] == 0
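For reference, the new TOTAL_MEM values follow directly from the smaller test tensors (assuming 2 data-parallel ranks, as implied by the two entries per config, and fp32 params): each 8x8 param is 64 * 4 B = 256 B, and a 128-element chunk is 512 B.

  use_chunk=True:  3 params (192 elements) fill 2 chunks -> 2 * 512 B = 1024 B per rank without ZeRO;
                   with ZeRO each rank stores one chunk -> 512 B per rank.
  use_chunk=False: each param is its own chunk -> 3 * 256 B = 768 B per rank without ZeRO;
                   with ZeRO the 3 params are split by index across the 2 ranks -> 2 * 256 B = 512 B and 1 * 256 B = 256 B.

These match TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}.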