mirror of https://github.com/hpcaitech/ColossalAI
[zero] add chunk_managerV2 for all-gather chunk (#1441)
parent
3b26516c69
commit
b80340168e
|
@ -1,2 +1,3 @@
|
||||||
from .chunkv2 import ChunkV2
|
from .chunkv2 import ChunkV2
|
||||||
|
from .chunk_mgrv2 import ChunkManagerV2
|
||||||
from .search_utils import clasify_params, search_chunk_configuration
|
from .search_utils import clasify_params, search_chunk_configuration
|
||||||
|
|
|
@ -0,0 +1,221 @@
|
||||||
|
import torch
|
||||||
|
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.tensor import ColoTensor
|
||||||
|
from colossalai.gemini.chunk import ChunkFullError, TensorState
|
||||||
|
from colossalai.gemini.update import ChunkV2 as Chunk
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkManagerV2:
    """
    A manager class to manipulate the tensors in chunks.

    Chunks are organized in named groups (one deque per group). The manager also keeps
    a running total of the memory reported by its chunks, per device type, in ``total_mem``.

    Args:
        chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager.
            Every value must contain a 'chunk_size' entry; the remaining entries are forwarded
            as keyword arguments whenever a chunk is created for that configuration key.
            The dictionary passed in is not modified.
        init_device (torch.device): optional, the device on which the chunk is initialized. The default is None,
            in which case the result of ``get_current_device()`` is used.
        pin_memory (bool): if true, all chunks have a piece of pinned memory in CPU.
    """

    def __init__(self,
                 chunk_configuration: Dict[int, Dict],
                 init_device: Optional[torch.device] = None,
                 pin_memory: bool = False) -> None:

        self.device = init_device or get_current_device()
        self.size_config: Dict[int, int] = dict()
        # Work on shallow copies: popping 'chunk_size' from the caller's dictionaries
        # would make the same configuration unusable for a second manager.
        self.kwargs_config: Dict[int, Dict] = {k: dict(v) for k, v in chunk_configuration.items()}
        for k, v in self.kwargs_config.items():
            self.size_config[k] = v.pop('chunk_size')
            v['init_device'] = self.device
            v['pin_memory'] = pin_memory

        # group name -> chunks of that group, in creation order
        self.chunk_groups: Dict[str, Deque] = dict()
        # tensor -> the chunk it was appended to
        self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
        # chunks on which access_chunk() was called and that were not released since
        self.accessed_chunks: Set[Chunk] = set()
        # tensors whose chunks are released on the next exec_lazy_release()
        self.lazy_release_tensors: List[torch.Tensor] = list()
        # accumulated memory usage per device type
        self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

    def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None:
        """Append a tensor to a chunk.

        The tensor goes into the last chunk of the group named ``"{group_type}_{config_key}"``.
        When the group has no chunk yet, or its last chunk is full, the last chunk (if any)
        is closed and a new chunk is created for the tensor.

        Args:
            tensor (ColoTensor): the tensor to be appended; it must not be managed already
            group_type (str): the prefix of the group name
            config_key (int): the key selecting the chunk size and chunk keyword arguments
        """
        assert tensor not in self.tensor_chunk_map
        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
        assert config_key in self.size_config

        chunk_size = self.size_config[config_key]
        chunk_kwargs = self.kwargs_config[config_key]
        group_name = "{}_{}".format(group_type, config_key)
        chunk_group = self.__get_chunk_group(group_name)

        try:
            # append the tensor to the last chunk
            chunk_group[-1].append_tensor(tensor)
        except (IndexError, ChunkFullError):
            # the except statement will be triggered when there is no chunk or
            # the last chunk in the chunk group is full
            # this will create a new chunk and allocate this chunk to its corresponding process
            if chunk_group:
                # the chunk group is not empty
                # close the last chunk
                self.__close_one_chunk(chunk_group[-1])

            # a tensor larger than the configured size gets a chunk of exactly its own size
            if tensor.numel() > chunk_size:
                chunk_size = tensor.numel()
            chunk = Chunk(
                chunk_size=chunk_size,
                process_group=tensor.process_group,
                dtype=tensor.dtype,
                **chunk_kwargs,
            )

            chunk_group.append(chunk)
            chunk.append_tensor(tensor)
            self.__add_memory_usage(chunk.memory_usage)

        self.tensor_chunk_map[tensor] = chunk_group[-1]

    def close_all_groups(self):
        """Close the last chunk of every group.

        Earlier chunks of a group are closed by ``append_tensor`` when a new chunk is created.
        """
        for chunk_group in self.chunk_groups.values():
            # skip groups that never received a chunk instead of failing on [-1]
            if chunk_group:
                self.__close_one_chunk(chunk_group[-1])

    def access_chunk(self, chunk: Chunk) -> None:
        """Make the chunk can be used for calculation.

        Does nothing when the chunk is already recorded as accessed.
        """
        if chunk in self.accessed_chunks:
            return
        # the chunk's memory usage may change during the operation: subtract before, add after
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.access_chunk()
        self.__add_memory_usage(chunk.memory_usage)
        self.accessed_chunks.add(chunk)

    def release_chunk(self, chunk: Chunk) -> None:
        """Scatter the chunk in CUDA.

        Does nothing when the chunk is not recorded as accessed or reports that it
        cannot be released.
        """
        if chunk not in self.accessed_chunks:
            return
        if chunk.can_release:
            self.__sub_memory_usage(chunk.memory_usage)
            chunk.release_chunk()
            self.__add_memory_usage(chunk.memory_usage)
            self.accessed_chunks.remove(chunk)

    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
        """Move the shard of the chunk to the target device.

        Does nothing when the chunk cannot be moved or already has the target device type.
        """
        if not chunk.can_move or chunk.device_type == device.type:
            return
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.shard_move(device)
        self.__add_memory_usage(chunk.memory_usage)

    def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
        """Transit tensor state according to pre-defined state machine.
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.tensor_trans_state(tensor, state)

    def reduce_chunk(self, chunk: Chunk) -> bool:
        """Reduce or all reduce the chunk.

        Returns:
            bool: False when the chunk reports that it cannot be reduced, True otherwise.
        """
        if not chunk.can_reduce:
            return False
        self.__sub_memory_usage(chunk.memory_usage)
        # NOTE(review): this calls release_chunk() although the method is named reduce_chunk and
        # guarded by can_reduce; it also leaves the chunk in accessed_chunks. Presumably the
        # chunk's reduce operation was intended -- confirm against the Chunk API before changing.
        chunk.release_chunk()
        self.__add_memory_usage(chunk.memory_usage)
        return True

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
        """
        Copy data to the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data (torch.Tensor): the tensor to be copied to the chunk
        """
        chunk = self.tensor_chunk_map[tensor]
        chunk.copy_tensor_to_chunk_slice(tensor, data)

    def get_chunk(self, tensor: torch.Tensor) -> Chunk:
        """
        Return the chunk owning the tensor.

        Args:
            tensor (torch.Tensor): a torch tensor object

        Raises:
            KeyError: if the tensor was never appended to this manager
        """
        return self.tensor_chunk_map[tensor]

    def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None:
        """
        Add tensors to the buffer for lazy release.

        Args:
            tensors (List[torch.Tensor]): the tensors to be released lazily
        """
        self.lazy_release_tensors.extend(tensors)

    def exec_lazy_release(self) -> None:
        """
        Execute release for tensors added to the lazy release buffer.

        The buffer is emptied afterwards.
        """

        for chunk in self.get_chunks(self.lazy_release_tensors):
            self.release_chunk(chunk)
        self.lazy_release_tensors.clear()

    def __repr__(self) -> str:
        msg = ['Chunk Manager Information:\n',
               'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n']
        for group_name, group in self.chunk_groups.items():
            msg.append(f'Group {group_name}:\n')
            for i, chunk in enumerate(group):
                msg.append(f'[{i}] {chunk}\n')
        return ''.join(msg)

    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
        """
        Get all chunks owning the input tensors.

        Each chunk appears once, in the order in which it is first encountered.

        Args:
            tensors (Iterable[torch.Tensor]): the tensors used to look for chunks
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, without the quadratic
        # list membership test. Chunks are already stored in a set (accessed_chunks), so they
        # are hashable; this assumes their hashing is consistent with ==.
        return tuple(dict.fromkeys(self.get_chunk(tensor) for tensor in tensors))

    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
        """Add extern static tensor to chunk manager.
        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
        They are "static", which means their shape, dtype, device never change.
        Thus, their memory usage never changes.

        Args:
            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
        """
        assert tensor not in self.tensor_chunk_map
        # size in bytes, accounted under the tensor's device type ('cpu' or 'cuda')
        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()

    def __get_chunk_group(self, group_name: str) -> Deque:
        """Return the chunk group with the given name, registering an empty one if needed.
        """
        if group_name not in self.chunk_groups:
            self.chunk_groups[group_name] = deque()
        return self.chunk_groups[group_name]

    def __close_one_chunk(self, chunk: Chunk):
        """Close a chunk on the manager's device and refresh the memory accounting.
        """
        self.__sub_memory_usage(chunk.memory_usage)
        chunk.close_chunk(self.device)
        self.__add_memory_usage(chunk.memory_usage)

    def __sub_memory_usage(self, usage: Dict[str, int]):
        """Subtract a per-device-type usage dictionary from the running totals.
        """
        for k, v in usage.items():
            self.total_mem[k] -= v

    def __add_memory_usage(self, usage: Dict[str, int]):
        """Add a per-device-type usage dictionary to the running totals.
        """
        for k, v in usage.items():
            self.total_mem[k] += v
|
|
@ -0,0 +1,76 @@
|
||||||
|
import torch
|
||||||
|
import colossalai
|
||||||
|
import pytest
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from functools import partial
|
||||||
|
from colossalai.gemini.update import ChunkManagerV2
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
|
||||||
|
from tests.test_tensor.common_utils import debug_print
|
||||||
|
|
||||||
|
# Expected totals of ChunkManagerV2.total_mem checked by exam_chunk_memory
# (presumably in bytes -- verify against the chunk's memory_usage).

# expected 'cuda' total before any chunk has been moved to CPU, keyed by keep_gathered
CUDA_MEM_0 = {
    False: 512,
    True: 1024,
}

# expected 'cuda' total after move_chunk(chunk, cpu) was called on all chunks, keyed by keep_gathered
CUDA_MEM_1 = {
    False: 0,
    True: 1024,
}

# expected 'cpu' total, indexed as CPU_MEM[keep_gathered][pin_memory]
CPU_MEM = {
    True: {True: 0, False: 0},
    False: {True: 512, False: 0},
}
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('keep_gathered', [True, False])
@parameterize('pin_memory', [True, False])
def exam_chunk_memory(keep_gathered, pin_memory):
    """Check the memory totals tracked by ChunkManagerV2 through a chunk life cycle.

    Three 8x8 tensors are appended with a chunk size of 128 elements; the 'cpu' and 'cuda'
    totals are then compared with the expected tables after closing, accessing, releasing
    and moving the chunks to CPU.
    """
    pg = ProcessGroup()

    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))

    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}

    chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory)

    def check_total_mem(expected_cpu, expected_cuda):
        # compare both tracked totals with the expected values
        assert chunk_manager.total_mem['cpu'] == expected_cpu
        assert chunk_manager.total_mem['cuda'] == expected_cuda

    # nothing is tracked before the first tensor is appended
    check_total_mem(0, 0)

    # build the chunks
    for param in params:
        chunk_manager.append_tensor(param, 'param', 2)
    chunk_manager.close_all_groups()
    check_total_mem(CPU_MEM[keep_gathered][pin_memory], CUDA_MEM_0[keep_gathered])

    chunks = chunk_manager.get_chunks(params)

    # after access the cuda total equals the keep_gathered=True figure
    for chunk in chunks:
        chunk_manager.access_chunk(chunk)
    check_total_mem(CPU_MEM[keep_gathered][pin_memory], CUDA_MEM_0[True])

    # releasing restores the totals seen after closing
    for chunk in chunks:
        chunk_manager.release_chunk(chunk)
    check_total_mem(CPU_MEM[keep_gathered][pin_memory], CUDA_MEM_0[keep_gathered])

    # after the move the cpu total is compared with the pin_memory=True figure
    cpu_device = torch.device('cpu')
    for chunk in chunks:
        chunk_manager.move_chunk(chunk, cpu_device)
    check_total_mem(CPU_MEM[keep_gathered][True], CUDA_MEM_1[keep_gathered])
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
    """Entry point of one spawned worker: launch colossalai, then run the memory check.

    Args:
        rank: rank of this worker process
        world_size: total number of worker processes
        port: port on localhost used for the launch
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_kwargs)
    # the parameterize decorators supply keep_gathered / pin_memory
    exam_chunk_memory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_chunk_manager(world_size):
    """Spawn ``world_size`` worker processes that each execute ``run_dist``."""
    port = free_port()
    # mp.spawn passes the process index as the first positional argument (rank)
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
# allow running this test module directly as a script, with a world size of 2
if __name__ == '__main__':
    test_chunk_manager(2)
|
Loading…
Reference in New Issue