mirror of https://github.com/hpcaitech/ColossalAI
[NFC] update chunk manager API (#2119)
parent e99edfcb51
commit e5aa8333e4
@@ -294,7 +294,7 @@ class Chunk:
         self.chunk_temp = None

         self.__scatter()
-        # always gathered chunk does not have shard
+        # a gathered chunk never has a shard attribute
         if self.keep_gathered:
             return

@@ -17,13 +17,13 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """

-    def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None:
+    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:

         self.device = init_device or get_current_device()
-        self.size_config: Dict[int, int] = dict()
+        self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
         for k, v in self.kwargs_config.items():
-            self.size_config[k] = v.pop('chunk_size')
+            self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device

         self.chunk_groups: Dict[str, Deque] = dict()
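
For context, a minimal sketch of the configuration shape this constructor expects after the rename: a dict keyed by dp degree, where each value carries the chunk init kwargs, including the 'chunk_size' entry that `__init__` pops out. The concrete sizes and the `keep_gathered` flag below are hypothetical examples, not values from this commit.

    from colossalai.gemini.chunk import ChunkManager

    # hypothetical values; only the shape (dp_degree -> chunk init kwargs) comes from the code above
    chunk_configuration = {
        2: dict(chunk_size=1024, keep_gathered=False),    # for tensors whose dp world size is 2
        4: dict(chunk_size=2048, keep_gathered=False),    # for tensors whose dp world size is 4
    }
    chunk_manager = ChunkManager(chunk_configuration)     # init_device defaults to the current device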
@@ -32,26 +32,28 @@ class ChunkManager:
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

-    def append_tensor(self,
+    def register_tensor(self,
                         tensor: ColoTensor,
                         group_type: str,
                         config_key: int,
                         cpu_offload: bool = False,
                         pin_memory: bool = False) -> None:
-        """Append a tensor to a chunk.
+        """
+        Register a tensor to the chunk manager.
+        Then, the tensor should be accessed by `get_chunks`.

         Args:
             tensor: the tensor appended to the chunk
-            group_type: the data type of the group
-            config_key: the key of the group's name, usually the size of the dp world
+            group_type: the data type of the group.
+            config_key: the key of the group's name, the size of the dp world
             cpu_offload: if True, the chunk will be closed on CPU
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
         assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
-        assert config_key in self.size_config
+        assert config_key in self.dp_degree_chunk_size_dict

-        chunk_size = self.size_config[config_key]
+        chunk_size = self.dp_degree_chunk_size_dict[config_key]
         chunk_kwargs = self.kwargs_config[config_key]
         group_name = "{}_{}".format(group_type, config_key)
         chunk_group = self.__get_chunk_group(group_name)
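
The renamed entry point keeps the same signature, so call sites only swap the method name. A migration sketch; `chunk_manager`, `t`, and `dp_world_size` are placeholders:

    # before this commit:
    #   chunk_manager.append_tensor(tensor=t, group_type='fp16_param',
    #                               config_key=dp_world_size, pin_memory=True)
    # after:
    chunk_manager.register_tensor(tensor=t,
                                  group_type='fp16_param',
                                  config_key=dp_world_size,    # must be a key of dp_degree_chunk_size_dict
                                  cpu_offload=False,
                                  pin_memory=True)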
@@ -83,7 +83,7 @@ def search_chunk_configuration(
         filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.

     Returns:
-        Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
+        Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
    """

     param_order = OrderedParamGenerator()
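
Since the returned dict is keyed by dp degree, it can feed `ChunkManager` directly. A sketch, assuming this version exports `search_chunk_configuration` next to `ChunkManager`; the argument names `search_range_mb` and `search_interval_byte` are assumptions not shown in this hunk, and `model` is a placeholder `nn.Module`:

    from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration

    config_dict, wasted_bytes = search_chunk_configuration(model,
                                                           search_range_mb=32,
                                                           search_interval_byte=1024,
                                                           filter_exlarge_params=True)
    chunk_manager = ChunkManager(config_dict)    # config_dict: dp_degree -> chunk init args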
@@ -228,12 +228,12 @@ class ZeroDDP(ColoDDP):
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             p.data = p.data.half()
             dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.append_tensor(tensor=p,
+            self.chunk_manager.register_tensor(tensor=p,
                                                group_type='fp16_param',
                                                config_key=dp_world_size,
                                                cpu_offload=cpu_offload,
                                                pin_memory=pin_memory)
-            self.chunk_manager.append_tensor(tensor=fp32_p,
+            self.chunk_manager.register_tensor(tensor=fp32_p,
                                                group_type='fp32_param',
                                                config_key=dp_world_size,
                                                cpu_offload=cpu_offload,
@@ -1,12 +1,14 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
 from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+import colossalai
 from colossalai.gemini.chunk import ChunkManager
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec
 from tests.test_tensor.common_utils import debug_print

 CUDA_MEM_0 = {False: 512, True: 1024}
@@ -29,7 +31,7 @@ def exam_chunk_memory(keep_gathered, pin_memory):
     assert chunk_manager.total_mem['cuda'] == 0

     for p in params:
-        chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory)
+        chunk_manager.register_tensor(p, 'param', 2, pin_memory=pin_memory)
     chunk_manager.close_all_groups()
     assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory]
     assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered]
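
A hedged reading of the expected totals in this test, assuming it runs with a dp world size of 2 (matching the `config_key=2` passed to `register_tensor`): a scattered chunk keeps only its shard on CUDA, while `keep_gathered=True` keeps the whole chunk resident, doubling the usage.

    # illustration only; the 1024-byte total is inferred from CUDA_MEM_0 above
    dp_world_size = 2
    gathered_cuda_mem = 1024                                   # CUDA_MEM_0[True]: full chunks stay on CUDA
    scattered_cuda_mem = gathered_cuda_mem // dp_world_size    # 512 == CUDA_MEM_0[False]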