[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)

* polish chunk manager

* polish unit test

* impl add_extern_static_tensor for chunk mgr

* add mem stats collector v2

* polish code

* polish unit test

* polish code

* polish get chunks
ver217 2022-06-09 20:56:34 +08:00 committed by GitHub
parent b3a03e4bfd
commit be01db37c8
6 changed files with 68 additions and 31 deletions


@@ -1,6 +1,7 @@
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.utils.memory import colo_device_memory_used
from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.tensor import ChunkManager
import torch
import time
@@ -128,3 +129,19 @@ class MemStatsCollector:
self._start_flag = False
self._step_idx = 0
self._step_total = 0
+class MemStatsCollectorV2(MemStatsCollector):
+def __init__(self, chunk_manager: ChunkManager) -> None:
+super().__init__()
+self._chunk_manager = chunk_manager
+def sample_model_data(self) -> None:
+"""Sampling model data statistics.
+"""
+if self._start_flag:
+cuda_mem = self._chunk_manager.total_mem['cuda']
+cpu_mem = self._chunk_manager.total_mem['cpu']
+self._model_data_cuda_list.append(cuda_mem)
+self._model_data_cpu_list.append(cpu_mem)
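For orientation, here is a minimal self-contained sketch of the idea behind MemStatsCollectorV2: model-data usage is read directly from the chunk manager's per-device byte totals instead of being recomputed from StatefulTensor instances. The Toy* classes below are illustrative stand-ins, not the library API.

class ToyChunkManager:
    """Stand-in for ChunkManager: tracks chunk memory in bytes per device type."""

    def __init__(self) -> None:
        self.total_mem = {'cuda': 0, 'cpu': 0}


class ToyMemStatsCollectorV2:
    """Stand-in for MemStatsCollectorV2: samples model data from the chunk manager."""

    def __init__(self, chunk_manager: ToyChunkManager) -> None:
        self._chunk_manager = chunk_manager
        self._start_flag = True
        self._model_data_cuda_list = []
        self._model_data_cpu_list = []

    def sample_model_data(self) -> None:
        # One dict lookup per device type replaces tensor-by-tensor bookkeeping.
        if self._start_flag:
            self._model_data_cuda_list.append(self._chunk_manager.total_mem['cuda'])
            self._model_data_cpu_list.append(self._chunk_manager.total_mem['cpu'])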


@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
def _post_backward(self):
self.chunk_manager.exec_lazy_release()
for p in self.module.parameters():
-if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
p.grad = None
else:
p.grad = p.data
@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
grad = grad / self.dp_world_size
self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
chunk = self.chunk_manager.get_chunk(p)
-reduced = self.chunk_manager.reduce_chunk(p)
-self.chunk_manager.release_chunk(p)
+reduced = self.chunk_manager.reduce_chunk(chunk)
+self.chunk_manager.release_chunk(chunk)
if reduced and not chunk.is_free:
self.overflow_counter += chunk.has_inf_or_nan
return empty_grad
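The substance of these hunks is largely an API shift rather than a behavioral change: ChunkManager operations that used to take a tensor and resolve its chunk internally now take the Chunk itself, so the caller looks the chunk up once and reuses it. Schematically (p and chunk_manager as in the code above):

# before: every call re-resolved the chunk from the tensor
reduced = chunk_manager.reduce_chunk(p)
chunk_manager.release_chunk(p)

# after: resolve once, then operate on the Chunk object
chunk = chunk_manager.get_chunk(p)
reduced = chunk_manager.reduce_chunk(chunk)
chunk_manager.release_chunk(chunk)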


@@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
def device_type(self) -> str:
return self.data.device.type
+def __hash__(self) -> int:
+return hash(id(self))
+def __eq__(self, __o: object) -> bool:
+return self is __o
class ChunkManager:
@@ -226,8 +232,7 @@ class ChunkManager:
src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
return src_rank
-def access_chunk(self, tensor: torch.Tensor) -> None:
-chunk = self.tensor_chunk_map[tensor]
+def access_chunk(self, chunk: Chunk) -> None:
if chunk in self.accessed_chunks:
return
if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
self.accessed_chunks.add(chunk)
self.total_mem[chunk.device_type] += chunk.mem
-def release_chunk(self, tensor: torch.Tensor) -> None:
+def release_chunk(self, chunk: Chunk) -> None:
if not self.enable_distributed_storage:
return
-chunk = self.tensor_chunk_map[tensor]
if chunk not in self.accessed_chunks:
return
if chunk.can_release:
@@ -248,8 +252,7 @@ class ChunkManager:
if chunk.is_free:
self.total_mem[chunk.device_type] -= chunk.mem
-def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-chunk = self.tensor_chunk_map[tensor]
+def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
if chunk.data.device == device:
return
if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor]
chunk.tensor_trans_state(tensor, state)
-def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-chunk = self.tensor_chunk_map[tensor]
+def reduce_chunk(self, chunk: Chunk) -> bool:
if not chunk.can_reduce:
return False
self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
chunk = self.tensor_chunk_map[tensor]
chunk.copy_tensor_to_chunk_slice(tensor, data)
-def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-chunk = self.tensor_chunk_map[tensor]
-return chunk.is_free
def get_chunk(self, tensor: torch.Tensor) -> Chunk:
return self.tensor_chunk_map[tensor]
@@ -285,8 +283,8 @@ class ChunkManager:
self.lazy_release_tensors.extend(tensors)
def exec_lazy_release(self) -> None:
-for tensor in self.lazy_release_tensors:
-self.release_chunk(tensor)
+for chunk in self.get_chunks(self.lazy_release_tensors):
+self.release_chunk(chunk)
self.lazy_release_tensors.clear()
def __repr__(self) -> str:
@@ -340,3 +338,23 @@ class ChunkManager:
for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
if not dest_chunk.is_free:
dest_chunk.copy_(src_chunk)
+def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+chunks = []
+for tensor in tensors:
+chunk = self.get_chunk(tensor)
+if chunk not in chunks:
+chunks.append(chunk)
+return tuple(chunks)
+def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+"""Add an extern static tensor to the chunk manager.
+Such tensors are not managed by the chunk manager, but we still want to monitor their memory usage.
+They are "static": their shape, dtype and device never change, so their memory usage is constant.
+Args:
+tensor (torch.Tensor): An extern static tensor, e.g. an optimizer state.
+"""
+assert tensor not in self.tensor_chunk_map
+self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
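A brief usage sketch of the two new helpers; chunk_manager, params, and fp32_state are placeholder names for an existing ChunkManager, a list of registered parameters, and an unmanaged tensor such as an fp32 optimizer state:

# get_chunks de-duplicates: parameters that share a chunk yield that chunk only once,
# in first-encounter order, so each chunk is touched a single time.
chunks = chunk_manager.get_chunks(params)
for chunk in chunks:
    chunk_manager.access_chunk(chunk)

# Register a tensor the chunk manager does not manage so that its constant
# footprint is still counted in total_mem.
chunk_manager.add_extern_static_tensor(fp32_state)
# total_mem[fp32_state.device.type] grows by fp32_state.numel() * fp32_state.element_size()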


@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
+chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
-for p in params:
-self._chunk_manager.access_chunk(p)
+for chunk in chunks:
+self._chunk_manager.access_chunk(chunk)
def post_op(self, params):
for p in params:


@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def _update_params_ptr(self):
for group in self.optim.param_groups:
for p in group['params']:
-if not self.module.chunk_manager.is_chunk_free(p):
+if not self.module.chunk_manager.get_chunk(p).is_free:
p.data = self.fp16_param_to_fp32_param[p]
else:
assert p.grad is None


@@ -32,7 +32,7 @@ HAS_TENSORS = {
}
}
-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
@parameterize('use_chunk', [False, True])
@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
rank = gpc.get_local_rank(ParallelMode.DATA)
if rank == 0:
print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-params = [torch.rand(32, 32) for _ in range(3)]
-chunk_size = 2048 if use_chunk else None
+params = [torch.rand(8, 8) for _ in range(3)]
+chunk_size = 128 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-for p in params:
-chunk_manager.access_chunk(p)
+chunks = chunk_manager.get_chunks(params)
+for chunk in chunks:
+chunk_manager.access_chunk(chunk)
check_has_params(params, [True, True, True])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-for p in params:
-chunk_manager.release_chunk(p)
+for chunk in chunks:
+chunk_manager.release_chunk(chunk)
check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-for p in params:
-chunk_manager.move_chunk(p, torch.device('cpu'))
+for chunk in chunks:
+chunk_manager.move_chunk(chunk, torch.device('cpu'))
assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
assert chunk_manager.total_mem['cuda'] == 0
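The new TOTAL_MEM expectations follow from the smaller test tensors. With torch.rand (fp32, 4 bytes per element) and the two data-parallel ranks that each list indexes, the accounting below reproduces the table; the rank placement assumes the chunk_idx % world_size rule shown earlier in the diff:

# 3 params of 8*8 fp32 -> 3 * 256 B = 768 B of raw parameter data
# use_chunk: chunk_size = 128 elements -> 512 B per chunk, 2 chunks in total
#   use_zero: the 2 chunks land on different ranks -> 512 B per rank        -> [512, 512]
#   no zero : every rank keeps both chunks -> 1024 B per rank               -> [1024, 1024]
# no chunk: each 256 B tensor is stored on its own
#   use_zero: rank 0 holds tensors 0 and 2 (512 B), rank 1 holds tensor 1   -> [512, 256]
#   no zero : every rank holds all 3 tensors -> 768 B per rank              -> [768, 768]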