[NFC] update chunk manager API (#2119)

pull/2123/head
Jiarui Fang 2022-12-12 16:57:22 +08:00 committed by GitHub
parent e99edfcb51
commit e5aa8333e4
5 changed files with 100 additions and 96 deletions


@@ -294,7 +294,7 @@ class Chunk:
         self.chunk_temp = None
         self.__scatter()
-        # always gathered chunk does not have shard
+        # gathered chunk never have shard attribute
         if self.keep_gathered:
             return


@@ -17,13 +17,13 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """

-    def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
+    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
         self.device = init_device or get_current_device()
-        self.size_config: Dict[int, int] = dict()
+        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
-            self.size_config[k] = v.pop('chunk_size')
+            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device

         self.chunk_groups: Dict[str, Deque] = dict()
@@ -32,26 +32,28 @@ class ChunkManager:
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

-    def append_tensor(self,
-                      tensor: ColoTensor,
-                      group_type: str,
-                      config_key: int,
-                      cpu_offload: bool = False,
-                      pin_memory: bool = False) -> None:
-        """Append a tensor to a chunk.
+    def register_tensor(self,
+                        tensor: ColoTensor,
+                        group_type: str,
+                        config_key: int,
+                        cpu_offload: bool = False,
+                        pin_memory: bool = False) -> None:
+        """
+        Register a tensor to the chunk manager.
+        Then, the tensor should be accessed by `get_chunks`.

         Args:
             tensor: the tensor appended to the chunk
-            group_type: the data type of the group
-            config_key: the key of the group's name, usually the size of the dp world
+            group_type: the data type of the group.
+            config_key: the key of the group's name, the size of the dp world
             cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
         assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
-        assert config_key in self.size_config
-        chunk_size = self.size_config[config_key]
+        assert config_key in self.dp_degree_chunk_size_dict
+        chunk_size = self.dp_degree_chunk_size_dict[config_key]
         chunk_kwargs = self.kwargs_config[config_key]
         group_name = "{}_{}".format(group_type, config_key)
         chunk_group = self.__get_chunk_group(group_name)

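Taken together, the renamed entry point is used roughly as below. This is a minimal sketch, not part of the commit, distilled from the test at the end of this diff; it assumes colossalai.launch has already set up the distributed environment, and the chunk size of 128 is illustrative.

import torch

from colossalai.gemini.chunk import ChunkManager
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup

# chunk config: dp_degree -> chunk init args ('chunk_size' is popped, the rest become chunk kwargs)
config = {2: dict(chunk_size=128, keep_gathered=False)}
chunk_manager = ChunkManager(config)

pg = ProcessGroup()
param = ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg))

# formerly append_tensor; the tensor joins the group named 'param_2'
chunk_manager.register_tensor(param, 'param', 2, pin_memory=False)
chunk_manager.close_all_groups()

# registered tensors are then reached through their chunks
chunks = chunk_manager.get_chunks([param])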

@@ -83,7 +83,7 @@ def search_chunk_configuration(
        filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.

    Returns:
-        Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
+        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
    """

    param_order = OrderedParamGenerator()

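To make the reworded return description concrete: the first element of the returned tuple is keyed by dp degree, and each value is the kwargs dict that ChunkManager.__init__ above consumes (it pops 'chunk_size' and forwards the rest). The values below are purely illustrative, not taken from the commit.

# hypothetical output of search_chunk_configuration for a model whose
# parameters all belong to a data-parallel group of degree 2
config_dict = {
    2: dict(chunk_size=32 * 1024**2, keep_gathered=False),
}
wasted_bytes = 1024    # illustrative second element of the returned tuple

# this dict has exactly the shape ChunkManager.__init__ expects:
# chunk_manager = ChunkManager(config_dict)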

@@ -228,16 +228,16 @@ class ZeroDDP(ColoDDP):
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             p.data = p.data.half()
             dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.append_tensor(tensor=p,
-                                             group_type='fp16_param',
-                                             config_key=dp_world_size,
-                                             cpu_offload=cpu_offload,
-                                             pin_memory=pin_memory)
-            self.chunk_manager.append_tensor(tensor=fp32_p,
-                                             group_type='fp32_param',
-                                             config_key=dp_world_size,
-                                             cpu_offload=cpu_offload,
-                                             pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=p,
+                                               group_type='fp16_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=fp32_p,
+                                               group_type='fp32_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()

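A small, purely illustrative note on the group naming used by these calls: given the "{}_{}".format(group_type, config_key) format in ChunkManager above, a parameter whose dp world size is 2 lands in the following two chunk groups.

# illustrative only; the format string is the one shown in the manager diff
fp16_group_name = "{}_{}".format('fp16_param', 2)
fp32_group_name = "{}_{}".format('fp32_param', 2)
assert fp16_group_name == 'fp16_param_2' and fp32_group_name == 'fp32_param_2'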

@@ -1,70 +1,72 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
-from functools import partial
-from colossalai.gemini.chunk import ChunkManager
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
-from colossalai.utils import free_port
-from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
-from tests.test_tensor.common_utils import debug_print
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_tensor.common_utils import debug_print

 CUDA_MEM_0 = {False: 512, True: 1024}
 CUDA_MEM_1 = {False: 0, True: 1024}
 CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}}


 @parameterize('keep_gathered', [True, False])
 @parameterize('pin_memory', [True, False])
 def exam_chunk_memory(keep_gathered, pin_memory):
     pg = ProcessGroup()

     debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory))

     params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)]
     config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)}

     chunk_manager = ChunkManager(config)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0

     for p in params:
-        chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
+        chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
     chunk_manager.close_all_groups()
     assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
     assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

     chunks = chunk_manager.get_chunks(params)

     for chunk in chunks:
         chunk_manager.access_chunk(chunk)
     assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
     assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True]

     for chunk in chunks:
         chunk_manager.release_chunk(chunk)
     assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
     assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]

     for chunk in chunks:
         chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True]
     assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered]


 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_chunk_memory()


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_chunk_manager(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
     test_chunk_manager(2)