mirror of https://github.com/hpcaitech/ColossalAI
[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)
* polish chunk manager
* polish unit test
* impl add_extern_static_tensor for chunk mgr
* add mem stats collector v2
* polish code
* polish unit test
* polish code
* polish get chunks
parent b3a03e4bfd
commit be01db37c8
@@ -1,6 +1,7 @@
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.tensor import ChunkManager
 import torch
 import time
@@ -128,3 +129,19 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
+
+
+class MemStatsCollectorV2(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
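Note: the new collector reads model-data memory straight from the chunk manager's running total_mem bookkeeping instead of walking StatefulTensors. Below is a minimal, self-contained sketch of that sampling pattern; _StubChunkManager and _StubCollectorV2 are hypothetical stand-ins for illustration only, not the real classes.

from typing import Dict, List


class _StubChunkManager:
    """Stand-in: the real ChunkManager adjusts total_mem in access/release/move/reduce."""

    def __init__(self) -> None:
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}


class _StubCollectorV2:
    """Stand-in mirroring the MemStatsCollectorV2.sample_model_data logic above."""

    def __init__(self, chunk_manager: _StubChunkManager) -> None:
        self._chunk_manager = chunk_manager
        self._start_flag = True
        self._model_data_cuda_list: List[int] = []
        self._model_data_cpu_list: List[int] = []

    def sample_model_data(self) -> None:
        # model data is just a snapshot of the chunk manager's per-device totals
        if self._start_flag:
            self._model_data_cuda_list.append(self._chunk_manager.total_mem['cuda'])
            self._model_data_cpu_list.append(self._chunk_manager.total_mem['cpu'])


mgr = _StubChunkManager()
collector = _StubCollectorV2(mgr)
mgr.total_mem['cuda'] += 1024           # pretend a 1 KiB chunk landed on the GPU
collector.sample_model_data()
print(collector._model_data_cuda_list)  # [1024]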
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
-            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data
@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
             grad = grad / self.dp_world_size
         self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
-        reduced = self.chunk_manager.reduce_chunk(p)
-        self.chunk_manager.release_chunk(p)
+        chunk = self.chunk_manager.get_chunk(p)
+        reduced = self.chunk_manager.reduce_chunk(chunk)
+        self.chunk_manager.release_chunk(chunk)
         if reduced and not chunk.is_free:
             self.overflow_counter += chunk.has_inf_or_nan
         return empty_grad
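Note: the reduce path now fetches the chunk once and tracks overflow at chunk granularity via chunk.has_inf_or_nan. The kind of check that property implies can be sketched with plain PyTorch; this is only an illustration of the idea, not the Chunk implementation:

import torch


def has_inf_or_nan(payload: torch.Tensor) -> bool:
    # True if any element of the chunk payload is inf or NaN
    return bool(torch.logical_or(torch.isinf(payload), torch.isnan(payload)).any().item())


print(has_inf_or_nan(torch.tensor([1.0, 2.0, float('inf')])))  # True
print(has_inf_or_nan(torch.tensor([1.0, 2.0, 3.0])))           # False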
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
 
 class ChunkManager:
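Note: Chunk now hashes and compares by identity. One plausible motivation, illustrated below with toy classes (not the real Chunk), is that the manager keeps chunks in the accessed_chunks set and deduplicates them with `in` checks, which requires instances to stay hashable and distinct even when they hold equal data; a value-based __eq__ without __hash__ would break that.

class ValueEq:
    def __init__(self, size: int) -> None:
        self.size = size

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ValueEq) and other.size == self.size
    # defining __eq__ without __hash__ makes the class unhashable


class IdentityEq:
    def __init__(self, size: int) -> None:
        self.size = size

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other


try:
    {ValueEq(4)}
except TypeError as err:
    print('unusable as a set member:', err)

a, b = IdentityEq(4), IdentityEq(4)
accessed = {a}
print(a in accessed, b in accessed)  # True False: equal-sized but distinct chunks stay distinct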
@@ -226,8 +232,7 @@ class ChunkManager:
         src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
@@ -248,8 +252,7 @@ class ChunkManager:
         if chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add extern static tensor to chunk manager.
+        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
+        They are "static", which means their shape, dtype, device never change.
+        Thus, their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
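Note: add_extern_static_tensor is pure bookkeeping; the tensor never enters a chunk, only its byte footprint is added to total_mem for the device it lives on, so it shows up in MemStatsCollectorV2 samples. A small runnable sketch of that accounting with plain PyTorch (the optimizer-state tensor here is hypothetical):

import torch

total_mem = {'cpu': 0, 'cuda': 0}                 # mirrors ChunkManager.total_mem
exp_avg = torch.zeros(1024, dtype=torch.float32)  # e.g. an Adam state tensor (illustrative)
total_mem[exp_avg.device.type] += exp_avg.numel() * exp_avg.element_size()
print(total_mem)  # {'cpu': 4096, 'cuda': 0} -> 1024 elements * 4 bytes each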
@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._chunk_manager.exec_lazy_release()
         # TODO: evict chunks
-        for p in params:
-            self._chunk_manager.access_chunk(p)
+        for chunk in chunks:
+            self._chunk_manager.access_chunk(chunk)
 
     def post_op(self, params):
         for p in params:
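Note: pre_op now resolves the parameters to their deduplicated chunks up front and calls access_chunk once per chunk instead of once per parameter, so a chunk shared by several parameters is only touched once. A self-contained toy sketch of that pattern (the _Fake* names are illustrative stand-ins, not ColossalAI classes):

from typing import Iterable, List, Tuple


class _FakeChunk:
    pass


class _FakeChunkManager:
    def __init__(self) -> None:
        self.tensor_chunk_map = {}
        self.access_calls = 0

    def get_chunk(self, tensor) -> _FakeChunk:
        return self.tensor_chunk_map[tensor]

    def get_chunks(self, tensors: Iterable) -> Tuple[_FakeChunk, ...]:
        chunks: List[_FakeChunk] = []
        for t in tensors:
            c = self.get_chunk(t)
            if c not in chunks:   # deduplicate, mirroring ChunkManager.get_chunks
                chunks.append(c)
        return tuple(chunks)

    def access_chunk(self, chunk: _FakeChunk) -> None:
        self.access_calls += 1


mgr = _FakeChunkManager()
shared = _FakeChunk()
params = ['p0', 'p1', 'p2']                  # three params packed into one shared chunk
mgr.tensor_chunk_map = {p: shared for p in params}

for chunk in mgr.get_chunks(params):         # one unique chunk -> one access, not three
    mgr.access_chunk(chunk)
print(mgr.access_calls)                      # 1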
@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
+                if not self.module.chunk_manager.get_chunk(p).is_free:
                     p.data = self.fp16_param_to_fp32_param[p]
                 else:
                     assert p.grad is None
@@ -32,7 +32,7 @@ HAS_TENSORS = {
     }
 }
 
-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
 
 
 @parameterize('use_chunk', [False, True])
@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
     rank = gpc.get_local_rank(ParallelMode.DATA)
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(32, 32) for _ in range(3)]
-    chunk_size = 2048 if use_chunk else None
+    params = [torch.rand(8, 8) for _ in range(3)]
+    chunk_size = 128 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
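Note: the shrunken constants are consistent with the new sizes, assuming fp32 elements and chunks that account for their full chunk_size whether or not they are completely filled: each 8x8 parameter is 64 elements (256 bytes), a 128-element chunk is 512 bytes, and the three parameters pack into two chunks. A quick arithmetic check:

elem_size = 4                        # bytes per fp32 element
param_bytes = 8 * 8 * elem_size      # 256 bytes per parameter
chunk_bytes = 128 * elem_size        # 512 bytes per chunk
print(3 * param_bytes)               # 768  -> TOTAL_MEM[False][False]: no chunking, no ZeRO
print(2 * chunk_bytes)               # 1024 -> TOTAL_MEM[True][False]: two chunks, no ZeRO
print(chunk_bytes)                   # 512  -> TOTAL_MEM[True][True]: one chunk per rank with ZeRO
print(2 * param_bytes, param_bytes)  # 512 256 -> TOTAL_MEM[False][True]: two params on rank 0, one on rank 1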
@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    for p in params:
-        chunk_manager.access_chunk(p)
+    chunks = chunk_manager.get_chunks(params)
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
     check_has_params(params, [True, True, True])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for p in params:
-        chunk_manager.release_chunk(p)
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for p in params:
-        chunk_manager.move_chunk(p, torch.device('cpu'))
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
     assert chunk_manager.total_mem['cuda'] == 0