mirror of https://github.com/hpcaitech/ColossalAI
[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)
* polish chunk manager
* polish unit test
* impl add_extern_static_tensor for chunk mgr
* add mem stats collector v2
* polish code
* polish unit test
* polish code
* polish get chunks

pull/1094/head

parent b3a03e4bfd
commit be01db37c8
@@ -1,6 +1,7 @@
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.tensor import ChunkManager
 import torch
 import time

@@ -128,3 +129,19 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
+
+
+class MemStatsCollectorV2(MemStatsCollector):
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
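Note: the new collector simply reads the chunk manager's running per-device byte totals when it samples model data. Below is a minimal, self-contained sketch of that pattern; the Toy* names and the manual flag/total updates are illustrative stand-ins, not ColossalAI APIs.

from typing import Dict, List


class ToyChunkManager:
    """Stand-in for ChunkManager: keeps running per-device byte totals."""

    def __init__(self) -> None:
        self.total_mem: Dict[str, int] = {'cuda': 0, 'cpu': 0}


class ToyMemStatsCollector:
    """Stand-in for the base collector's bookkeeping."""

    def __init__(self) -> None:
        self._start_flag = False
        self._model_data_cuda_list: List[int] = []
        self._model_data_cpu_list: List[int] = []


class ToyMemStatsCollectorV2(ToyMemStatsCollector):
    """Mirrors the MemStatsCollectorV2 idea: sampling reads total_mem instead of walking tensors."""

    def __init__(self, chunk_manager: ToyChunkManager) -> None:
        super().__init__()
        self._chunk_manager = chunk_manager

    def sample_model_data(self) -> None:
        if self._start_flag:
            self._model_data_cuda_list.append(self._chunk_manager.total_mem['cuda'])
            self._model_data_cpu_list.append(self._chunk_manager.total_mem['cpu'])


mgr = ToyChunkManager()
collector = ToyMemStatsCollectorV2(mgr)
collector._start_flag = True          # the real collector sets this when collection starts
mgr.total_mem['cuda'] += 1024         # e.g. a 1 KiB chunk became resident on GPU
collector.sample_model_data()
assert collector._model_data_cuda_list == [1024]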
@@ -60,7 +60,7 @@ class ColoDDP(torch.nn.Module):
             else:
                 ColoDDP._save_grad(p, grad)
             return empty_grad
 
         else:
             group = gpc.get_cpu_group(ParallelMode.DATA)
             dist.all_reduce(grad, group=group)

@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
-            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data

@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
             grad = grad / self.dp_world_size
         self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
         chunk = self.chunk_manager.get_chunk(p)
-        reduced = self.chunk_manager.reduce_chunk(p)
-        self.chunk_manager.release_chunk(p)
+        reduced = self.chunk_manager.reduce_chunk(chunk)
+        self.chunk_manager.release_chunk(chunk)
         if reduced and not chunk.is_free:
             self.overflow_counter += chunk.has_inf_or_nan
         return empty_grad
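The grad_handle change above is purely an API shift: the chunk is looked up once and then passed to reduce_chunk and release_chunk instead of re-resolving it from the tensor. A hedged sketch of that sequence, with everything not shown in the hunk (streams, tensor-state transitions, the exact scaling guard) omitted as an assumption:

import torch


def reduce_param_grad(chunk_manager, p: torch.Tensor, grad: torch.Tensor, dp_world_size: int) -> bool:
    """Simplified sketch of the chunk-keyed reduction path in ColoDDPV2.grad_handle."""
    if dp_world_size > 1:
        grad = grad / dp_world_size                       # average before the reduce
    chunk_manager.copy_tensor_to_chunk_slice(p, grad)     # write the local grad into its chunk slice
    chunk = chunk_manager.get_chunk(p)                    # resolve the Chunk once...
    reduced = chunk_manager.reduce_chunk(chunk)           # ...and pass the object, not the tensor
    chunk_manager.release_chunk(chunk)
    # Overflow is only inspected on ranks that still own the reduced chunk's storage.
    return reduced and not chunk.is_free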
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode

@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
+        return hash(id(self))
+
+    def __eq__(self, __o: object) -> bool:
+        return self is __o
+
 
 class ChunkManager:
 

@@ -226,8 +232,7 @@ class ChunkManager:
         src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:

@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:

@@ -248,8 +252,7 @@ class ChunkManager:
         if chunk.is_free:
             self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:

@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem

@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
 

@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:

@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add extern static tensor to chunk manager.
+        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
+        They are "static", which means their shape, dtype, device never change.
+        Thus, their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
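The __hash__/__eq__ additions and the new get_chunks helper go together: Chunk objects must be deduplicated and kept in sets by object identity even when several tensors map to the same chunk. A self-contained illustration of that behaviour (ToyChunk and the string-keyed map are hypothetical stand-ins, not ColossalAI code):

class ToyChunk:
    """Stand-in for Chunk: hashing and equality by object identity, as in the diff."""

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, other: object) -> bool:
        return self is other


c0, c1 = ToyChunk(), ToyChunk()
tensor_chunk_map = {'t0': c0, 't1': c0, 't2': c1}    # two tensors share chunk c0


def get_chunks(tensors):
    """Mirrors ChunkManager.get_chunks: map tensors to chunks and drop duplicates."""
    chunks = []
    for t in tensors:
        chunk = tensor_chunk_map[t]
        if chunk not in chunks:          # membership test relies on identity-based __eq__
            chunks.append(chunk)
    return tuple(chunks)


assert get_chunks(['t0', 't1', 't2']) == (c0, c1)    # c0 is reported only once
assert len({c0, c1}) == 2                            # chunks also behave in sets (accessed_chunks)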
@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._chunk_manager.exec_lazy_release()
         # TODO: evict chunks
-        for p in params:
-            self._chunk_manager.access_chunk(p)
+        for chunk in chunks:
+            self._chunk_manager.access_chunk(chunk)
 
     def post_op(self, params):
         for p in params:
@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
+                if not self.module.chunk_manager.get_chunk(p).is_free:
                     p.data = self.fp16_param_to_fp32_param[p]
                 else:
                     assert p.grad is None
@@ -32,7 +32,7 @@ HAS_TENSORS = {
     }
 }
 
-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
 
 
 @parameterize('use_chunk', [False, True])

@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
     rank = gpc.get_local_rank(ParallelMode.DATA)
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(32, 32) for _ in range(3)]
-    chunk_size = 2048 if use_chunk else None
+    params = [torch.rand(8, 8) for _ in range(3)]
+    chunk_size = 128 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0

@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    for p in params:
-        chunk_manager.access_chunk(p)
+    chunks = chunk_manager.get_chunks(params)
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
     check_has_params(params, [True, True, True])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for p in params:
-        chunk_manager.release_chunk(p)
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for p in params:
-        chunk_manager.move_chunk(p, torch.device('cpu'))
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
     assert chunk_manager.total_mem['cuda'] == 0
 
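The shrunken TOTAL_MEM table follows from the smaller test tensors, assuming fp32 parameters, a data-parallel world size of 2, and round-robin chunk placement (chunk_idx % world_size, as in the hunk above). A quick check of that arithmetic:

elem_size = 4                     # bytes per fp32 element
param_elems = 8 * 8               # each of the 3 params is 8x8 = 64 elements

# use_chunk=True: chunk_size=128 elements -> params 0+1 fill chunk 0, param 2 goes to chunk 1
chunk_bytes = 128 * elem_size                          # 512 bytes per chunk
assert [chunk_bytes, chunk_bytes] == [512, 512]        # use_zero=True: one chunk per rank
assert 2 * chunk_bytes == 1024                         # use_zero=False: both chunks on each rank

# use_chunk=False: every tensor is its own 256-byte chunk
tensor_bytes = param_elems * elem_size                 # 256 bytes
assert [2 * tensor_bytes, tensor_bytes] == [512, 256]  # use_zero=True: rank 0 holds chunks 0 and 2
assert 3 * tensor_bytes == 768                         # use_zero=False: all three chunks everywhere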