mirror of https://github.com/hpcaitech/ColossalAI
[NFC] update chunk manager API (#2119)
parent: e99edfcb51
commit: e5aa8333e4
@@ -294,7 +294,7 @@ class Chunk:
         self.chunk_temp = None

         self.__scatter()
-        # always gathered chunk does not have shard
+        # gathered chunk never have shard attribute
         if self.keep_gathered:
             return

@@ -17,13 +17,13 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """

-    def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
+    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:

         self.device = init_device or get_current_device()
-        self.size_config: Dict[int, int] = dict()
+        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
-            self.size_config[k] = v.pop('chunk_size')
+            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device

         self.chunk_groups: Dict[str, Deque] = dict()
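For orientation (not part of the diff): a minimal sketch of the configuration dict that __init__ consumes, using the same dp degree and chunk_size as the test at the bottom of this commit. 'chunk_size' is popped into the renamed dp_degree_chunk_size_dict; the remaining kwargs, plus init_device, stay in kwargs_config for chunk creation.

    import torch
    from colossalai.gemini.chunk import ChunkManager

    # dp_degree -> chunk init kwargs; 'chunk_size' is popped out per the loop above,
    # everything else (e.g. keep_gathered, init_device) is kept in kwargs_config.
    config = {2: dict(chunk_size=128, keep_gathered=False)}
    chunk_manager = ChunkManager(config, init_device=torch.device('cuda'))  # init_device is optional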
@@ -32,26 +32,28 @@ class ChunkManager:
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

-    def append_tensor(self,
-                      tensor: ColoTensor,
-                      group_type: str,
-                      config_key: int,
-                      cpu_offload: bool = False,
-                      pin_memory: bool = False) -> None:
-        """Append a tensor to a chunk.
+    def register_tensor(self,
+                        tensor: ColoTensor,
+                        group_type: str,
+                        config_key: int,
+                        cpu_offload: bool = False,
+                        pin_memory: bool = False) -> None:
+        """
+        Register a tensor to the chunk manager.
+        Then, the tensor should be accessed by `get_chunks`.

         Args:
             tensor: the tensor appended to the chunk
-            group_type: the data type of the group
-            config_key: the key of the group's name, usually the size of the dp world
+            group_type: the data type of the group.
+            config_key: the key of the group's name, the size of the dp world
             cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
         assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
-        assert config_key in self.size_config
+        assert config_key in self.dp_degree_chunk_size_dict

-        chunk_size = self.size_config[config_key]
+        chunk_size = self.dp_degree_chunk_size_dict[config_key]
         chunk_kwargs = self.kwargs_config[config_key]
         group_name = "{}_{}".format(group_type, config_key)
         chunk_group = self.__get_chunk_group(group_name)
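A usage sketch of the renamed method, condensed from the updated test at the end of this commit; like the test, it assumes it runs inside a process set up by colossalai.launch, and the shapes, dp degree and chunk_size are just the test's values.

    import torch
    from colossalai.gemini.chunk import ChunkManager
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

    # assumes colossalai.launch(...) has already initialized the distributed environment
    pg = ProcessGroup()
    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]

    chunk_manager = ChunkManager({2: dict(chunk_size=128, keep_gathered=False)})
    for p in params:
        # formerly chunk_manager.append_tensor(p, 'param', 2, ...)
        chunk_manager.register_tensor(p, group_type='param', config_key=2, pin_memory=False)
    chunk_manager.close_all_groups()

    # per the new docstring, registered tensors are then reached through get_chunks
    chunks = chunk_manager.get_chunks(params)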
@@ -83,7 +83,7 @@ def search_chunk_configuration(
        filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.

    Returns:
-        Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
+        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
    """

    param_order = OrderedParamGenerator()
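To make the reworded return description concrete, a small sketch of the tuple's shape and how its dict half feeds ChunkManager; the values below are illustrative placeholders, not the output of an actual search_chunk_configuration call.

    from typing import Dict, Tuple
    from colossalai.gemini.chunk import ChunkManager

    # Shape described by the updated docstring: (dp_degree -> chunk init args, wasted bytes).
    chunk_config: Dict[int, Dict] = {2: dict(chunk_size=1024)}   # illustrative values
    waste_in_bytes: int = 0
    search_result: Tuple[Dict, int] = (chunk_config, waste_in_bytes)

    # The dict half is exactly what ChunkManager.__init__ consumes.
    chunk_manager = ChunkManager(search_result[0])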
@@ -228,16 +228,16 @@ class ZeroDDP(ColoDDP):
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             p.data = p.data.half()
             dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.append_tensor(tensor=p,
-                                             group_type='fp16_param',
-                                             config_key=dp_world_size,
-                                             cpu_offload=cpu_offload,
-                                             pin_memory=pin_memory)
-            self.chunk_manager.append_tensor(tensor=fp32_p,
-                                             group_type='fp32_param',
-                                             config_key=dp_world_size,
-                                             cpu_offload=cpu_offload,
-                                             pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=p,
+                                               group_type='fp16_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=fp32_p,
+                                               group_type='fp32_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()
@@ -1,70 +1,72 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
-from functools import partial
-from colossalai.gemini.chunk import ChunkManager
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
-from colossalai.utils import free_port
-from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
-from tests.test_tensor.common_utils import debug_print
-
-CUDA_MEM_0 = {False: 512, True: 1024}
-CUDA_MEM_1 = {False: 0, True: 1024}
-CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
-
-
-@parameterize('keep_gathered', [True, False])
-@parameterize('pin_memory', [True, False])
-def exam_chunk_memory(keep_gathered, pin_memory):
-    pg = ProcessGroup()
-
-    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
-
-    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
-    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
-
-    chunk_manager = ChunkManager(config)
-    assert chunk_manager.total_mem['cpu'] == 0
-    assert chunk_manager.total_mem['cuda'] == 0
-
-    for p in params:
-        chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
-    chunk_manager.close_all_groups()
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
-
-    chunks = chunk_manager.get_chunks(params)
-
-    for chunk in chunks:
-        chunk_manager.access_chunk(chunk)
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]
-
-    for chunk in chunks:
-        chunk_manager.release_chunk(chunk)
-
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
-
-    for chunk in chunks:
-        chunk_manager.move_chunk(chunk, torch.device('cpu'))
-    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
-    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_chunk_memory()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@rerun_if_address_is_in_use()
-def test_chunk_manager(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_chunk_manager(2)
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_tensor.common_utils import debug_print
+
+CUDA_MEM_0 = {False: 512, True: 1024}
+CUDA_MEM_1 = {False: 0, True: 1024}
+CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}
+
+
+@parameterize('keep_gathered', [True, False])
+@parameterize('pin_memory', [True, False])
+def exam_chunk_memory(keep_gathered, pin_memory):
+    pg = ProcessGroup()
+
+    debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))
+
+    params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
+    config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}
+
+    chunk_manager = ChunkManager(config)
+    assert chunk_manager.total_mem['cpu'] == 0
+    assert chunk_manager.total_mem['cuda'] == 0
+
+    for p in params:
+        chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
+    chunk_manager.close_all_groups()
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
+
+    chunks = chunk_manager.get_chunks(params)
+
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]
+
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
+
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
+
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
+    assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
+    assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_chunk_memory()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_chunk_manager(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_chunk_manager(2)