diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py index 746b3e02a..a82640d67 100644 --- a/colossalai/gemini/__init__.py +++ b/colossalai/gemini/__init__.py @@ -1,10 +1,6 @@ -from .chunk import TensorInfo, Chunk, TensorState -from .chunk_mgr import ChunkManager +from .chunk import TensorInfo, TensorState from .stateful_tensor_mgr import StatefulTensorMgr from .tensor_placement_policy import TensorPlacementPolicyFactory from .gemini_mgr import GeminiManager -__all__ = [ - 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk', - 'TensorState' -] +__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState'] diff --git a/colossalai/gemini/chunk.py b/colossalai/gemini/chunk.py deleted file mode 100644 index b454fc988..000000000 --- a/colossalai/gemini/chunk.py +++ /dev/null @@ -1,316 +0,0 @@ -import torch -import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Dict, List - -from colossalai.utils import get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup - - -class TensorState(Enum): - FREE = 0 - COMPUTE = 1 - HOLD = 2 - HOLD_AFTER_BWD = 3 - READY_FOR_REDUCE = 4 - - -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), - (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), - (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) - - -@dataclass -class TensorInfo: - state: TensorState - offset: int - end: int - - -class ChunkFullError(Exception): - pass - - -def is_storage_empty(tensor: torch.Tensor) -> bool: - return tensor.storage().size() == 0 - - -def free_storage(tensor: torch.Tensor) -> None: - if not is_storage_empty(tensor): - tensor.storage().resize_(0) - - -def alloc_storage(tensor: torch.Tensor) -> None: - if is_storage_empty(tensor): - tensor.storage().resize_(tensor.numel()) - - -class Chunk: - """ - A chunk is a contiguous memory space which contains multiple tensors. - - Args: - chunk_size (int): the number of elements in a chunk - src_rank (int): the process which owns the chunk - dtype (torch.dtype): the data type of the chunk - init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU. - force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False. 
- """ - - def __init__(self, - chunk_size: int, - src_rank: int, - process_group: ColoProcessGroup, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - force_data_on_cuda: bool = False) -> None: - self.size = chunk_size - self.utilized_size = 0 - self.src_rank = src_rank - self.process_group = process_group - self.is_src_rank = process_group.dp_local_rank() == src_rank - self.global_src_rank = process_group.get_ranks_in_dp()[src_rank] - self.dtype = dtype - device = init_device or get_current_device() - if force_data_on_cuda: - self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device()) - self._cpu_data = torch.empty(chunk_size, dtype=dtype) - if device.type == 'cuda': - free_storage(self._cpu_data) - else: - free_storage(self.data) - else: - self.data = torch.empty(chunk_size, dtype=dtype, device=device) - self._cpu_data = None - - # we only keep the chunk in full in the process by which the tensor is owned - if not self.is_src_rank: - free_storage(self._payload) - - # each tensor is associated with a TensorInfo to track meta info - self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} - self.mem = self.size * self.data.element_size() - - def append(self, tensor: torch.Tensor) -> None: - """ - Add a tensor to the chunk. - - Args: - tensor (torch.Tensor): a tensor to be added to the chunk - """ - assert tensor.dtype == self.dtype - new_utilized_size = self.utilized_size + tensor.numel() - - # raise exception when the chunk size is exceeded - if new_utilized_size > self.size: - raise ChunkFullError - - # set tensor state - tensor_state = TensorState.FREE - - # if the process owns the rank, then copy the tensor to its chunk buffer - # otherwise set its storage size to 0 to reduce memory consumption - if self.is_src_rank: - self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten()) - tensor_state = TensorState.HOLD - assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape) - else: - tensor.storage().resize_(0) - self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) - self.utilized_size = new_utilized_size - - def release(self) -> None: - """ - Release the memory space on processes which do not own the chunk. - """ - if not self.is_src_rank: - free_storage(self._payload) - self._update_tensors_state(TensorState.FREE) - - def _update_tensors_ptr(self) -> None: - assert type(self._payload) == torch.Tensor - for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) - - def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): - for tensor_info in self.tensors_info.values(): - if prev_state is None or tensor_info.state == prev_state: - tensor_info.state = next_state - - def access(self) -> None: - """ - Broadcast the chunk to synchronize the tensors across data parallel processes. 
- """ - # recover the chunk on non-owner processes - # and broadcast the chunk from the source to all processes - if not self.is_src_rank: - alloc_storage(self._payload) - self.move_device(get_current_device(), update_ptr=False) - dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group()) - - # update tensor meta info - self._update_tensors_ptr() - if not self.is_src_rank: - self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE) - - def move_device(self, device: torch.device, update_ptr: bool = True) -> None: - """ - Move the chunk to a target device. - - Args: - device (torch.device): the target device for data movement. - """ - if self._payload.device == device: - return - if self._cpu_data is None: - self.data.data = self.data.to(device) - else: - if device.type == 'cuda': - # cpu -> cuda - src = self._cpu_data - dest = self.data - else: - # cuda -> cpu - src = self.data - dest = self._cpu_data - alloc_storage(dest) - dest.copy_(src) - free_storage(src) - - if update_ptr: - self._update_tensors_ptr() - - def reduce(self, is_all_reduce: bool = False) -> None: - """ - Reduce or all-reduce the chunk. - - Args: - is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false. - """ - self.move_device(get_current_device(), update_ptr=False) - if is_all_reduce: - dist.all_reduce(self.data, group=self.process_group.dp_process_group()) - else: - dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group()) - self._update_tensors_ptr() - self._update_tensors_state(TensorState.HOLD) - - def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: - """ - Make a transition of the tensor into the next state. - - Args: - tensor (torch.Tensor): a torch Tensor object. - tensor_state (TensorState): the target state for transition. - """ - - # As the gradient hook can be triggered either before or after post-backward - # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce - # or compute -> ready_for_reduce -> hold_after_bwd - # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd - # this function only apply valid state transformation - # invalid calls will be ignored and nothing changes - if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: - # print( - # f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' - # ) - return - self.tensors_info[tensor].state = tensor_state - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: - """ - Copy data slice to the memory space indexed by the input tensor in the chunk. - - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data_slice (torch.Tensor): the tensor to be copied to the chunk - """ - tensor_info = self.tensors_info[tensor] - self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten()) - tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape) - - @property - def can_release(self) -> bool: - """ - Check whether the chunk can be released. - """ - for tensor_info in self.tensors_info.values(): - if tensor_info.state != TensorState.HOLD: - return False - return True - - @property - def can_move_device(self) -> bool: - """ - Check whether the chunk can be moved across devices. 
- """ - for tensor_info in self.tensors_info.values(): - if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE): - return False - return True - - @property - def can_reduce(self) -> bool: - """ - Check whether the chunk can be reduced. - """ - for tensor_info in self.tensors_info.values(): - if tensor_info.state != TensorState.READY_FOR_REDUCE: - return False - return True - - @property - def is_empty(self) -> bool: - """ - Check whether the chunk is empty. - """ - return is_storage_empty(self._payload) - - def __repr__(self) -> str: - return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}' - - @property - def has_inf_or_nan(self) -> bool: - """ - Check if the chunk has inf or nan values. - """ - return torch.isinf(self._payload[:self.utilized_size]).any().item() or \ - torch.isnan(self._payload[:self.utilized_size]).any().item() - - def copy_(self, dest_chunk: 'Chunk'): - """ - Copy the data of this chunk to a destination chunk. - """ - assert not self.is_empty - assert not dest_chunk.is_empty - assert self.size == dest_chunk.size - assert self.utilized_size == dest_chunk.utilized_size - self._payload.copy_(dest_chunk._payload) - self._update_tensors_ptr() - - @property - def device_type(self) -> str: - """ - Get the device type of the chunk. - """ - return self._payload.device.type - - def __hash__(self) -> int: - return hash(id(self)) - - def __eq__(self, __o: object) -> bool: - return self is __o - - def get_tensors(self) -> List[torch.Tensor]: - return list(self.tensors_info.keys()) - - @property - def _payload(self) -> torch.Tensor: - if self._cpu_data is None or is_storage_empty(self._cpu_data): - return self.data - return self._cpu_data diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/gemini/chunk/__init__.py new file mode 100644 index 000000000..8468a6815 --- /dev/null +++ b/colossalai/gemini/chunk/__init__.py @@ -0,0 +1,3 @@ +from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk +from .manager import ChunkManager +from .search_utils import clasify_params, search_chunk_configuration diff --git a/colossalai/gemini/update/chunkv2.py b/colossalai/gemini/chunk/chunk.py similarity index 74% rename from colossalai/gemini/update/chunkv2.py rename to colossalai/gemini/chunk/chunk.py index 25f7858ea..e02d14055 100644 --- a/colossalai/gemini/update/chunkv2.py +++ b/colossalai/gemini/chunk/chunk.py @@ -1,14 +1,55 @@ import torch import torch.distributed as dist +from dataclasses import dataclass +from enum import Enum from typing import Optional, Dict, List from colossalai.utils import get_current_device from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \ - free_storage, alloc_storage -class ChunkV2: +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + 
+@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: def __init__(self, chunk_size: int, @@ -19,18 +60,18 @@ class ChunkV2: pin_memory: bool = False) -> None: """ Chunk: A container owning a piece of contiguous memory space for tensors - AgChunk is a kind of chunk, which uses all-gather operation to gather the whole chunk. - This kind of chunk is exclusively used for DDP and ZeRO DDP. + Here we use all-gather operation to gather the whole chunk. + Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. It is designed to make the full use of communication and PCIE bandwidth. Args: - chunk_size (int): the number of elements in a chunk + chunk_size (int): the number of elements in the chunk process_group (ColoProcessGroup): the process group of this chunk dtype (torch.dtype): the data type of the chunk init_device (torch.device): optional, the device where the tensor is initialized The default value is None, which is the current GPU keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory - pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory + pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory """ self.chunk_size = chunk_size @@ -42,7 +83,8 @@ class ChunkV2: self.pg_rank = dist.get_rank(self.torch_pg) # the chunk size should be able to be divied by the size of GPU - assert chunk_size % self.pg_size == 0 + if not keep_gathered: + assert chunk_size % self.pg_size == 0 self.shard_size = chunk_size // self.pg_size self.shard_begin = self.shard_size * self.pg_rank self.shard_end = self.shard_begin + self.shard_size @@ -80,18 +122,15 @@ class ChunkV2: # we introduce the paired chunk here # it refers to another chunk having the same parameters - # but with different dtype(such as fp16_chunk.mapping_chunk -> fp32_chunk + # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk self.paired_chunk = None - # if the the gradient of this chunk is reduced, the flag is True - # so the flag is False for unused parameters - self.grad_reduced_flag = False # if this chunk is synchronized with the optimizer, the flag is True self.optim_sync_flag = True # if the cpu_shard has been visited during the training step, the flag is True self.cpu_vis_flag = False @property - def memory_usage(self): + def memory_usage(self) -> Dict[str, int]: cuda_memory = 0 cpu_memory = 0 @@ -112,7 +151,7 @@ class ChunkV2: return dict(cuda=cuda_memory, cpu=cpu_memory) @property - def device_type(self): + def device_type(self) -> str: if self.chunk_temp is not None: return self.chunk_temp.device.type else: @@ -123,6 +162,56 @@ class ChunkV2: else: return 'cpu' + @property + def payload(self) -> torch.Tensor: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_total + elif self.cuda_shard is not None: + return self.cuda_shard + else: + return self.cpu_shard + + @property + def payload_mem(self) -> int: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_mem + 
else: + return self.shard_mem + + @property + def can_move(self) -> bool: + return not self.is_gathered + + @property + def can_release(self) -> bool: + if self.keep_gathered: + return False + else: + return self.tensors_state_monitor[TensorState.HOLD] + \ + self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors + + @property + def can_reduce(self): + return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors + + @property + def has_inf_or_nan(self) -> bool: + """Check if the chunk has inf or nan values in CUDA. + """ + if self.is_gathered: + valid_tensor = self.chunk_total[:self.utilized_size] + else: + assert self.cuda_shard is not None # only check in CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + + return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + def append_tensor(self, tensor: torch.Tensor): """Add a tensor to the chunk. @@ -150,7 +239,10 @@ class ChunkV2: self.utilized_size = new_utilized_size def close_chunk(self, shard_dev: Optional[torch.device] = None): - """Close the chunk. Any tensor can't be appended to a closed chunk. + """Close the chunk. Any tensor can't be appended to a closed chunk later. + + Args: + shard_dev: the device where the shard locates """ # sanity check assert self.chunk_temp is not None @@ -163,6 +255,7 @@ class ChunkV2: if self.chunk_temp.device.type == 'cpu': self.chunk_total = self.chunk_temp.to(get_current_device()) + self.__update_tensors_ptr() else: self.chunk_total = self.chunk_temp self.chunk_temp = None @@ -186,6 +279,12 @@ class ChunkV2: self.cuda_shard = None def shard_move(self, device: torch.device, force_copy: bool = False): + """Move the shard tensor in the chunk. + + Args: + device: the device to which the shard will move + force_copy: if True, copy function is called mandatorily + """ # sanity check assert not self.is_gathered # when the current chunk is not synchronized with the optimizer @@ -223,8 +322,7 @@ class ChunkV2: raise NotImplementedError def access_chunk(self): - """Make the chunk usable for the parameters inside it. - It is an operation done in CUDA. + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. """ # sanity check assert self.chunk_temp is None @@ -234,8 +332,7 @@ class ChunkV2: self.__update_tensors_ptr() def release_chunk(self): - """Release the usable chunk. - It is an operation done in CUDA. + """Release the usable chunk. It's an operation done in CUDA. """ # sanity check assert self.chunk_temp is None @@ -244,8 +341,7 @@ class ChunkV2: self.__scatter() def reduce(self): - """Reduce scatter all the gradients. - It is an operation done in CUDA. + """Reduce scatter all the gradients. It's an operation done in CUDA. 
""" # sanity check assert self.is_gathered @@ -267,7 +363,6 @@ class ChunkV2: free_storage(self.chunk_total) self.is_gathered = False self.__update_tensors_state(TensorState.HOLD) - self.grad_reduced_flag = True def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: """ @@ -285,9 +380,6 @@ class ChunkV2: # this function only apply valid state transformation # invalid calls will be ignored and nothing changes if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: - # print( - # f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' - # ) return self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) @@ -306,46 +398,58 @@ class ChunkV2: self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) - @property - def can_move(self) -> bool: - return not self.is_gathered - - @property - def can_release(self) -> bool: + def get_valid_length(self) -> int: + """Get the valid length of the chunk's payload. + """ if self.keep_gathered: - return False + return self.utilized_size else: - return self.tensors_state_monitor[TensorState.HOLD] + \ - self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors + return self.valid_end - @property - def can_reduce(self): - return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors - - @property - def has_inf_or_nan(self) -> bool: + def init_pair(self, friend_chunk: 'Chunk') -> None: + """Initialize the paired chunk. """ - Check if the chunk has inf or nan values in CUDA. - """ - if self.is_gathered: - valid_tensor = self.chunk_total[:self.utilized_size] + if self.paired_chunk is None and friend_chunk.paired_chunk is None: + self.paired_chunk = friend_chunk + friend_chunk.paired_chunk = self else: - assert self.cuda_shard is not None # only check in CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.paired_chunk is friend_chunk + assert friend_chunk.paired_chunk is self - return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + def optim_update(self) -> None: + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. 
+ """ + # sanity check + assert self.paired_chunk is not None + + friend_chunk = self.paired_chunk + if self.is_gathered is True: + assert friend_chunk.is_gathered is True + self.chunk_total.copy_(friend_chunk.chunk_total) + self.optim_sync_flag = True + elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + self.cuda_shard.copy_(friend_chunk.cuda_shard) + self.optim_sync_flag = True + self.cpu_vis_flag = False + else: + # optim_sync_flag is set to False + # see shard_move function for more details + assert friend_chunk.device_type == 'cpu' + assert self.device_type == 'cpu' + self.optim_sync_flag = False + self.cpu_vis_flag = False + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) def __gather(self): if not self.is_gathered: # sanity check assert self.cuda_shard is not None - if self.pg_size == 1: - self.chunk_total = self.cuda_shard - else: - alloc_storage(self.chunk_total) - gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + alloc_storage(self.chunk_total) + gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) + dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) self.cuda_shard = None self.is_gathered = True @@ -404,9 +508,9 @@ class ChunkV2: def __eq__(self, __o: object) -> bool: return self is __o - def __repr__(self, detailed: bool = False): + def __repr__(self, detailed: bool = True): output = [ - "AgChunk Information:\n", + "Chunk Information:\n", "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, self.pg_size), "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( @@ -442,6 +546,3 @@ class ChunkV2: output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st])) return ''.join(output) - - def get_tensors(self) -> List[torch.Tensor]: - return list(self.tensors_info.keys()) diff --git a/colossalai/gemini/update/chunk_mgrv2.py b/colossalai/gemini/chunk/manager.py similarity index 80% rename from colossalai/gemini/update/chunk_mgrv2.py rename to colossalai/gemini/chunk/manager.py index d6cd0745c..2d75dcce5 100644 --- a/colossalai/gemini/update/chunk_mgrv2.py +++ b/colossalai/gemini/chunk/manager.py @@ -4,23 +4,19 @@ from collections import deque from colossalai.utils import get_current_device from colossalai.tensor import ColoTensor -from colossalai.gemini.chunk import ChunkFullError, TensorState -from colossalai.gemini.update import ChunkV2 as Chunk +from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk -class ChunkManagerV2: +class ChunkManager: """ A manager class to manipulate the tensors in chunks. Args: chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. - pin_memory (bool): if ture, all chunks have a piece of pinned memory in CPU. 
""" - def __init__(self, chunk_configuration: Dict[int, Dict], - init_device: Optional[torch.device] = None, - pin_memory: bool = False) -> None: + def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None: self.device = init_device or get_current_device() self.size_config: Dict[int, int] = dict() @@ -28,7 +24,6 @@ class ChunkManagerV2: for k, v in self.kwargs_config.items(): self.size_config[k] = v.pop('chunk_size') v['init_device'] = self.device - v['pin_memory'] = pin_memory self.chunk_groups: Dict[str, Deque] = dict() self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() @@ -36,8 +31,14 @@ class ChunkManagerV2: self.lazy_release_tensors: List[torch.Tensor] = list() self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int) -> None: + def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None: """Append a tensor to a chunk. + + Args: + tensor: the tensor appended to the chunk + group_type: the data type of the group + config_key: the key of the group's name, usually the size of the dp world + pin_memory: whether the chunk is pinned in the cpu memory """ assert tensor not in self.tensor_chunk_map assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" @@ -66,7 +67,8 @@ class ChunkManagerV2: chunk_size=chunk_size, process_group=tensor.process_group, dtype=tensor.dtype, - **chunk_kwargs + pin_memory=pin_memory, + **chunk_kwargs, ) chunk_group.append(chunk) @@ -87,6 +89,8 @@ class ChunkManagerV2: if chunk in self.accessed_chunks: return self.__sub_memroy_usage(chunk.memory_usage) + if chunk.device_type == 'cpu': + chunk.shard_move(get_current_device()) chunk.access_chunk() self.__add_memory_usage(chunk.memory_usage) self.accessed_chunks.add(chunk) @@ -102,13 +106,13 @@ class ChunkManagerV2: self.__add_memory_usage(chunk.memory_usage) self.accessed_chunks.remove(chunk) - def move_chunk(self, chunk: Chunk, device: torch.device) -> None: + def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: """Move the shard of the chunk to the target device. 
""" if not chunk.can_move or chunk.device_type == device.type: return self.__sub_memroy_usage(chunk.memory_usage) - chunk.shard_move(device) + chunk.shard_move(device, force_copy) self.__add_memory_usage(chunk.memory_usage) def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: @@ -123,7 +127,7 @@ class ChunkManagerV2: if not chunk.can_reduce: return False self.__sub_memroy_usage(chunk.memory_usage) - chunk.release_chunk() + chunk.reduce() self.__add_memory_usage(chunk.memory_usage) return True @@ -165,14 +169,14 @@ class ChunkManagerV2: self.release_chunk(chunk) self.lazy_release_tensors.clear() - def __repr__(self) -> str: - msg = ['Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'] - for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') - for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) + def get_cuda_movable_chunks(self, group_type: str) -> List[Chunk]: + chunk_list = [] + for group_name in self.chunk_groups: + if group_type in group_name: + for chunk in self.chunk_groups[group_name]: + if chunk.device_type == 'cuda' and chunk.can_move: + chunk_list.append(chunk) + return chunk_list def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: """ @@ -200,6 +204,17 @@ class ChunkManagerV2: assert tensor not in self.tensor_chunk_map self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + def __repr__(self) -> str: + msg = [ + 'Chunk Manager Information:\n', + 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + ] + for group_name, group in self.chunk_groups.items(): + msg.append(f'Group {group_name}:\n') + for i, chunk in enumerate(group): + msg.append(f'[{i}] {chunk}\n') + return ''.join(msg) + def __get_chunk_group(self, group_name: str) -> Deque: """Register a chunk group. """ @@ -208,8 +223,9 @@ class ChunkManagerV2: return self.chunk_groups[group_name] def __close_one_chunk(self, chunk: Chunk): + device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda self.__sub_memroy_usage(chunk.memory_usage) - chunk.close_chunk(self.device) + chunk.close_chunk(device) self.__add_memory_usage(chunk.memory_usage) def __sub_memroy_usage(self, usage: Dict[str, int]): diff --git a/colossalai/gemini/update/search_utils.py b/colossalai/gemini/chunk/search_utils.py similarity index 77% rename from colossalai/gemini/update/search_utils.py rename to colossalai/gemini/chunk/search_utils.py index fdbbf0817..f309872a4 100644 --- a/colossalai/gemini/update/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -1,3 +1,4 @@ +import math from typing import Dict, List import numpy as np import torch.nn as nn @@ -7,7 +8,7 @@ from colossalai.tensor import ColoParameter def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: """Filter those parameters whose size is too large from others. 
""" - params_size = [p.numel() for p in model.parameters()] + params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)] params_size_arr = np.array(params_size) std = np.std(params_size_arr) @@ -36,6 +37,9 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: params_dict: Dict[int, List[ColoParameter]] = dict() for param in model.parameters(): assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" + if getattr(param, '_ddp_to_ignore', False): + continue + param_key = param.process_group.dp_world_size() if param_key not in params_dict: @@ -47,13 +51,13 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: def search_chunk_configuration( model: nn.Module, - search_range_mb: int, + search_range_mb: float, search_interval_byte: int, # hidden size is the best value for the interval - min_chunk_size_mb: int = 32, - filter_exlarge_params: bool = True): - search_range_byte = search_range_mb * 1024**2 - min_chunk_size_byte = min_chunk_size_mb * 1024**2 - assert search_range_byte % search_interval_byte == 0 + min_chunk_size_mb: float = 32, + filter_exlarge_params: bool = True) -> Dict: + search_range_byte = round(search_range_mb * 1024**2) + min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) + assert search_range_byte >= 0 params_dict = clasify_params(model) config_dict: Dict[int, Dict] = dict() @@ -75,11 +79,12 @@ def search_chunk_configuration( max_size = min_chunk_size_byte for key in size_dict: max_size = max(max_size, max(size_dict[key])) + start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) min_chunk_waste = float('+inf') - best_chunk_size = max_size + best_chunk_size = start_size - for chunk_size in range(max_size, max_size + search_range_byte + 1, search_interval_byte): + for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): temp_waste = 0 for key in size_dict: temp_waste += _get_unused_byte(size_dict[key], chunk_size) diff --git a/colossalai/gemini/chunk_mgr.py b/colossalai/gemini/chunk_mgr.py deleted file mode 100644 index 4e236e5cd..000000000 --- a/colossalai/gemini/chunk_mgr.py +++ /dev/null @@ -1,344 +0,0 @@ -import torch -import numpy as np -from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable -from collections import deque - -from colossalai.utils import get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup, ColoTensor -from .chunk import Chunk, ChunkFullError, TensorState - - -class ChunkManager: - """ - A manager class to manipulate the tensors in chunks. - - Args: - chunk_size (int): the size of a chunk. - process_group (ColoProcessGroup): process group of the chunk. - enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false. - init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. 
- """ - - def __init__(self, - chunk_size: Optional[int], - process_group: ColoProcessGroup, - enable_distributed_storage: bool = False, - init_device: Optional[torch.device] = None) -> None: - assert chunk_size is None or chunk_size > 0 - assert isinstance(process_group, ColoProcessGroup) - self.chunk_size = chunk_size - self.process_group = process_group - self.enable_distributed_storage = enable_distributed_storage - self.device = init_device or get_current_device() - self.chunk_groups: Dict[str, Deque[Chunk]] = {} - self.groups_force_data_on_cuda: Dict[str, bool] = {} - self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {} - self.accessed_chunks: Set[Chunk] = set() - self.lazy_release_tensors: List[torch.Tensor] = [] - if enable_distributed_storage and chunk_size is None: - self.rank_load: Dict[str, torch.Tensor] = {} - self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - - def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None: - """Create a chunk group. - - Args: - group_name (str): group name - force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda.. Defaults to False. - """ - assert group_name not in self.chunk_groups - self.chunk_groups[group_name] = deque() - self.groups_force_data_on_cuda[group_name] = force_data_on_cuda - - def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None: - """ - Append a tensor to a chunk. - - Args: - tensor (torch.Tensor): a tensor to append to the chunk. - group_name (str): the name of the chunk group. - """ - assert tensor not in self.tensor_chunk_map - if isinstance(tensor, ColoTensor): - assert tensor.get_process_group().dp_process_group() == self.process_group.dp_process_group( - ), f"Chunk Manager can only manage ColoTensor with the same DP process group" - try: - # append the tensor to the last chunk - self.chunk_groups[group_name][-1].append(tensor) - except (IndexError, ChunkFullError): - # the except statement will be triggered when there is no chunk or - # the last chunk in the chunk group is full - # this will create a new chunk and allocate this chunk to its corresponding process - if self.chunk_size is not None and tensor.numel() > self.chunk_size: - chunk_size = tensor.numel() - else: - chunk_size = self.chunk_size or tensor.numel() - src_rank = self._get_next_src_rank(group_name) - chunk = Chunk(chunk_size, - src_rank, - self.process_group, - tensor.dtype, - self.device, - force_data_on_cuda=self.groups_force_data_on_cuda[group_name]) - - if self.enable_distributed_storage and self.chunk_size is None: - self.rank_load[group_name][src_rank] += chunk_size - - self.chunk_groups[group_name].append(chunk) - chunk.append(tensor) - if not chunk.is_empty: - self.total_mem[chunk.device_type] += chunk.mem - self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1] - if not self.enable_distributed_storage: - # as distributed storage is not enabled, there is no need to broadcast - # chunks, thus we set these chunks as accessed - self.accessed_chunks.add(self.chunk_groups[group_name][-1]) - - def _get_next_src_rank(self, group_name: str) -> int: - if not self.enable_distributed_storage: - # the chunk is owned by the current rank if no distributed storage is enabled - return self.process_group.dp_local_rank() - if self.chunk_size is None: - if group_name not in self.rank_load: - self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64) - - # the process owning the tensor will be the process with the smallest number of 
elements - src_rank = torch.argmin(self.rank_load[group_name]).item() - else: - # chunk is owned by processes in a round-robin fashion - chunk_idx = len(self.chunk_groups[group_name]) - src_rank = chunk_idx % self.process_group.dp_world_size() - return src_rank - - def access_chunk(self, chunk: Chunk) -> None: - """ - Synchronize the chunks via broadcast. - - Args: - chunk (Chunk): the chunk to synchronize. - """ - if chunk in self.accessed_chunks: - if chunk.device_type != 'cuda': - self.total_mem[chunk.device_type] -= chunk.mem - chunk.move_device(get_current_device()) - self.total_mem[chunk.device_type] += chunk.mem - return - if not chunk.is_empty: - # as tensor is moved to the target device - # the memory consumption of the original device is reduced - self.total_mem[chunk.device_type] -= chunk.mem - chunk.access() - self.accessed_chunks.add(chunk) - self.total_mem[chunk.device_type] += chunk.mem - - def release_chunk(self, chunk: Chunk) -> None: - """ - Release the memory space of a chunk. - - Args: - chunk (Chunk): the chunk to release memory space - """ - - if not self.enable_distributed_storage: - return - if chunk not in self.accessed_chunks: - return - if chunk.can_release: - chunk.release() - self.accessed_chunks.remove(chunk) - if chunk.is_empty: - # update the memory consumption after releasing - self.total_mem[chunk.device_type] -= chunk.mem - - def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None: - """ - Move the chunk to the target device. - - Args: - chunk (Chunk): the chunk to move to target device - device (torch.device): target device - """ - if chunk.device_type == device.type: - return - if chunk.can_move_device and not chunk.is_empty: - self.total_mem[chunk.device_type] -= chunk.mem - chunk.move_device(device, update_ptr=update_ptr) - self.total_mem[chunk.device_type] += chunk.mem - - def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: - """ - Transit tensor state according to pre-defined state machine. - - Args: - tensor (torch.Tensor): the tensor for state transititon - state (TensorState): next tensor state for transtition - """ - chunk = self.tensor_chunk_map[tensor] - chunk.tensor_trans_state(tensor, state) - - def reduce_chunk(self, chunk: Chunk) -> bool: - """ - Reduce or all reduce the chunk. If enable_distributed_storage is true, all-reduce is used. - Otherwise, this method uses reduce. - - Args: - chunk (Chunk): the chunk for reduction. - """ - if not chunk.can_reduce: - return False - self.total_mem[chunk.device_type] -= chunk.mem - chunk.reduce(is_all_reduce=not self.enable_distributed_storage) - self.total_mem[chunk.device_type] += chunk.mem - return True - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: - """ - Copy data to the chunk. - - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data (torch.Tensor): the tensor to be copied to the chunk - """ - chunk = self.tensor_chunk_map[tensor] - chunk.copy_tensor_to_chunk_slice(tensor, data) - - def get_chunk(self, tensor: torch.Tensor) -> Chunk: - """ - Return the chunk owning the tensor. - - Args: - tensor (torch.Tensor): a torch tensor object - """ - return self.tensor_chunk_map[tensor] - - def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None: - """ - Add tensors to the buffer for lazy release. 
- - Args: - tensors (List[torch.Tensor]): the tensors to be released lazily - """ - self.lazy_release_tensors.extend(tensors) - - def exec_lazy_release(self) -> None: - """ - Execute release for tensors added to the lazy release buffer. - """ - - for chunk in self.get_chunks(self.lazy_release_tensors): - self.release_chunk(chunk) - self.lazy_release_tensors.clear() - - def __repr__(self) -> str: - msg = f'Rank {self.process_group.dp_local_rank()}:\n' - msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' - for group_name, group in self.chunk_groups.items(): - msg += f'Group {group_name}:\n' - for i, chunk in enumerate(group): - msg += f'[{i}] {chunk}\n' - return msg - - @staticmethod - def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float: - """ - Calculate the utilization rate of a chunk. - - Args: - chunk_size (int): the size of a chunk - params_numel (List[int]): the list of integers representing the number of elements of parameters - """ - assert len(params_numel) > 0 - total_size = 0 - total_utilized_size = 0 - cur_chunk_utilized_size = 0 - for size in params_numel: - assert chunk_size >= size - total_utilized_size += size - if total_size == 0 or cur_chunk_utilized_size + size > chunk_size: - total_size += chunk_size - cur_chunk_utilized_size = 0 - cur_chunk_utilized_size += size - return total_utilized_size / total_size - - @staticmethod - def search_chunk_size(module: torch.nn.Module, - search_range: int, - n_grids: int, - min_chunk_size: Optional[int] = None, - filter_exlarge_params: bool = True) -> int: - """ - Search for the chunk size for optimal chunk utilization. - - Args: - module (torch.nn.Module): a torch module object - search_range (int): the range of chunk size to search. The actual search range will be from - max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range. - n_grids (int): the number of intervals in the search range - min_chunk_size (int): optional, the minimum size for a chunk. The default is None. - - """ - assert search_range % n_grids == 0 - # TODO(ver217): sort params and filter unused ones - params_numel = [p.numel() for p in module.parameters()] - if filter_exlarge_params: - params_numel = _filter_exlarge_params(params_numel) - max_param_numel = max(params_numel) - if min_chunk_size is not None: - assert min_chunk_size >= max_param_numel - else: - min_chunk_size = max_param_numel - step_size = search_range // n_grids - max_chunk_util = -1 - best_chunk_size = -1 - for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size): - chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel) - if chunk_util > max_chunk_util: - max_chunk_util = chunk_util - best_chunk_size = chunk_size - return best_chunk_size - - def copy_chunk_group(self, dest_group_name: str, src_group_name: str): - """ - Copy chunk data from one group to another group. - - Args: - dest_group_name (str): the destination group which receives the copied data - src_group_name (str): the source group which provides the data to copy - """ - for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]): - if not dest_chunk.is_empty: - dest_chunk.copy_(src_chunk) - - def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: - """ - Get all chunks owning the input tensors. 
- - Args: - tensors (Iterable[torch.Tensor]): the tensors used to look for chunks - """ - chunks = [] - for tensor in tensors: - chunk = self.get_chunk(tensor) - if chunk not in chunks: - chunks.append(chunk) - return tuple(chunks) - - def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: - """Add extern static tensor to chunk manager. - Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. - They are "static", which means their shape, dtype, device never change. - Thus, their memory usage never changes. - - Args: - tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. - """ - assert tensor not in self.tensor_chunk_map - self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() - - -def _filter_exlarge_params(params_numel: List[int]) -> List[int]: - params_numel_arr = np.array(params_numel) - std = np.std(params_numel_arr) - mean = np.mean(params_numel_arr) - upper_limit = mean + 3 * std - return list(filter(lambda x: x <= upper_limit, params_numel)) diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 4717e6f24..0bdddd9a7 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -3,7 +3,7 @@ import functools from .memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import List, Optional, Tuple from time import time -from colossalai.gemini import Chunk, ChunkManager +from colossalai.gemini.chunk import Chunk, ChunkManager from .placement_policy import PlacementPolicyFactory @@ -56,37 +56,44 @@ class GeminiManager: self._evict_time = 0 self._comp_cuda_demand_time = 0 - def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None: + def adjust_layout(self, chunks: Tuple[Chunk, ...], group_type: str) -> None: """ Adjust the layout of statefuil tensor according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. 
""" # find stateful tensor in state COMPUTE start = time() self._record_chunks_order(chunks) - cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name) + cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_type) self._layout_time += time() - start - vol, evict_time = self._placement_policy.evict_tensors(hold_cuda_tensor_list, + + vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, cuda_demand=cuda_demand, warmup=self._warmup, compute_list=self._compute_list, compute_idx=self._compute_idx) + self._d2h_volume += vol self._evict_time += evict_time # move COMPUTE tensors to CUDA self._h2d_volume += cuda_demand @functools.lru_cache(maxsize=None) - def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str): + def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_type: str): start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == 'cpu' or chunk.is_empty: - cuda_demand += chunk.mem + if chunk.device_type == 'cuda': + if chunk.is_gathered: + pass + else: + cuda_demand += chunk.chunk_mem - chunk.shard_mem + elif chunk.device_type == 'cpu': + cuda_demand += chunk.chunk_mem + else: + raise RuntimeError self._comp_cuda_demand_time += time() - start - can_evict_chunks = [] - for chunk in self._chunk_manager.chunk_groups[group_name]: - if not chunk.is_empty and chunk.device_type == 'cuda' and chunk.can_move_device: - can_evict_chunks.append(chunk) + + can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks(group_type) return cuda_demand, can_evict_chunks def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index 1ff88bd3f..4366956fe 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity from colossalai.utils import get_current_device from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini import ChunkManager +from colossalai.gemini.chunk import ChunkManager import torch import time diff --git a/colossalai/gemini/placement_policy.py b/colossalai/gemini/placement_policy.py index ec6afbc07..1a7e172ed 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/gemini/placement_policy.py @@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2 from typing import Type import functools -from colossalai.gemini import Chunk, ChunkManager +from colossalai.gemini.chunk import Chunk, ChunkManager class PlacementPolicy(ABC): @@ -19,7 +19,7 @@ class PlacementPolicy(ABC): self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector @abstractmethod - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> None: + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: raise NotImplementedError @staticmethod @@ -32,12 +32,12 @@ class CPUPlacementPolicy(PlacementPolicy): def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None: 
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int: + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: volume = 0 start = time() for chunk in can_evict_chunks: - self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False) - volume += chunk.mem + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + volume += chunk.shard_mem return volume, time() - start @@ -47,7 +47,7 @@ class CUDAPlacementPolicy(PlacementPolicy): assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) - def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int: + def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: return 0, 0 @staticmethod @@ -59,7 +59,8 @@ class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase - # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() and AutoPlacementPolicy.set_steady_cuda_cap_ratio() + # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() + # and AutoPlacementPolicy.set_steady_cuda_cap_ratio() _warmup_non_model_data_ratio: float = 0.8 _steady_cuda_cap_ratio: float = 0.9 @@ -70,14 +71,14 @@ class AutoPlacementPolicy(PlacementPolicy): can_evict_chunks: List[Chunk], cuda_demand: int = 0, warmup: bool = True, - compute_list: List[Tuple[Chunk, ...]] = [], + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, compute_idx: int = 0, - **kwargs) -> int: + **kwargs) -> Tuple[int, float]: """ Evict tensors from CUDA device. Args: - hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like + can_evict_chunks (List[StatefulTensor]): the list of tensors that can be evicted. cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0. warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True. compute_list (List[StatefulTensor], optional): TODO. Defaults to []. @@ -114,12 +115,12 @@ class AutoPlacementPolicy(PlacementPolicy): for chunk in to_free_chunks: if freed_cuda_model_data >= to_free_cuda_model_data: break - freed_cuda_model_data += chunk.mem - self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False) + + self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + freed_cuda_model_data += chunk.shard_mem if freed_cuda_model_data < to_free_cuda_model_data: - raise RuntimeError( - f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" - ) + raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! 
" + f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") return freed_cuda_model_data, time() - start @staticmethod @@ -147,7 +148,7 @@ class AutoPlacementPolicy(PlacementPolicy): class PlacementPolicyFactory: - policies: Dict[str, PlacementPolicy] = { + policies: Dict[str, Type[PlacementPolicy]] = { 'cpu': CPUPlacementPolicy, 'cuda': CUDAPlacementPolicy, 'auto': AutoPlacementPolicy diff --git a/colossalai/gemini/stateful_tensor_container.py b/colossalai/gemini/stateful_tensor_container.py deleted file mode 100644 index c82113028..000000000 --- a/colossalai/gemini/stateful_tensor_container.py +++ /dev/null @@ -1,131 +0,0 @@ -import queue -import heapq -from abc import ABC, abstractmethod -from typing import Optional, List, Dict -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState - - -def evict_check(st: StatefulTensor) -> bool: - if st.state is not TensorState.COMPUTE and st.device.type == 'cuda': - return True - return False - - -# Here ST means Stateful Tensor -class BaseSTContainer(ABC): - """A type of container that store all potential stateful tensors which can be evicted from - CUDA. This kind of stateful tensor should satisfy two conditions. One is that it hasn't been - evicted, meaning the type of its device is CUDA, the other is that it isn't pinned in CUDA - memory, meaning its state isn't COMPUTE. - - This container should get a stateful tensor when it become HOLD_LIKE from COMPUTE. - And it pops stateful tensors in function, `evict_tensors`. - - In order to acquire an optimal eviction policy, users may need to offer computation step - index of each stateful tensor. So we can use a heap to maintain all potential evictable - statefule tensors. When poping, we can get the stateful tensor that used furthest in - current computation step. - """ - - def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int): - self.compute_step_dict = compute_step_dict - self.total_step = total_step - - @abstractmethod - def empty(self) -> bool: - pass - - @abstractmethod - def create(self, stateful_tensor_list: List[StatefulTensor]) -> None: - pass - - @abstractmethod - def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None: - pass - - @abstractmethod - def pop(self) -> Optional[StatefulTensor]: - pass - - -class QueueSTContainer(BaseSTContainer): - """Queue type stateful tensor container. This is used in 'cpu' tensor placement policy. - It pops potential evictable stateful tensors in FIFO. - """ - - def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int): - super().__init__(compute_step_dict, total_step) - self.container = None - - def empty(self) -> bool: - assert self.container is not None - return self.container.empty() - - def create(self, stateful_tensor_list: List[StatefulTensor]) -> None: - self.container = queue.SimpleQueue() - for stateful_tensor in stateful_tensor_list: - self.container.put(stateful_tensor) - - def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None: - self.container.put(stateful_tensor) - - def pop(self) -> Optional[StatefulTensor]: - ret = None - while not self.empty(): - out_tensor = self.container.get() - if evict_check(out_tensor): - ret = out_tensor - break - - return ret - - -class HeapSTContainer(BaseSTContainer): - """Heap type stateful tensor container. This is used in 'auto' tensor placement policy. - It pops potential evictable stateful tensors in the order of the distance between current - step and next used step. 
- """ - - def __init__(self, compute_step_dict: Dict[StatefulTensor, List[int]], total_step: int): - super().__init__(compute_step_dict, total_step) - self.container = None - - def empty(self) -> bool: - assert self.container is not None - return self.container == [] - - def create(self, stateful_tensor_list: List[StatefulTensor]) -> None: - self.container = [] - for stateful_tensor in stateful_tensor_list: - # we want to pop the tensor which has the greatest next_step - # so the weight is next_step multiplied by -1 - weight = -self.__get_next_compute_step(stateful_tensor, -1) - self.container.append((weight, stateful_tensor)) - heapq.heapify(self.container) - - def push(self, stateful_tensor: StatefulTensor, cur_step: int) -> None: - # we want to pop the tensor which has the greatest next_step - # so the weight is next_step multiplied by -1 - weight = -self.__get_next_compute_step(stateful_tensor, cur_step) - heapq.heappush(self.container, (weight, stateful_tensor)) - - def pop(self) -> Optional[StatefulTensor]: - ret = None - while not self.empty(): - _, out_tensor = heapq.heappop(self.container) - if evict_check(out_tensor): - ret = out_tensor - break - return ret - - def __get_next_compute_step(self, stateful_tensor: StatefulTensor, cur_step: int): - # compute the id of next step - # if the tensor is not used in the furture - # next_step is set to the maximum - next_step = self.total_step - step_list = self.compute_step_dict[stateful_tensor] - for step in step_list: - if step > cur_step: - next_step = step - break - return next_step diff --git a/colossalai/gemini/update/__init__.py b/colossalai/gemini/update/__init__.py deleted file mode 100644 index 20e3abccb..000000000 --- a/colossalai/gemini/update/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .chunkv2 import ChunkV2 -from .chunk_mgrv2 import ChunkManagerV2 -from .search_utils import clasify_params, search_chunk_configuration diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 378f186a8..daa4cb15e 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -3,16 +3,18 @@ import itertools import torch.distributed as dist from functools import partial from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 -from colossalai.gemini.chunk import TensorState, Chunk from colossalai.tensor.param_op_hook import ParamOpHookManager from colossalai.gemini.gemini_mgr import GeminiManager from typing import Dict, Iterable, List, Optional, Set from colossalai.logging import get_dist_logger from collections import OrderedDict -from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec from colossalai.tensor import ProcessGroup as ColoProcessGroup from .reducer import Reducer +from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager +from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: @@ -208,29 +210,41 @@ class ZeroDDP(ColoDDP): def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager, + pin_memory: bool = False, force_outputs_fp32: bool = False) -> None: - super().__init__(module, process_group=gemini_manager.chunk_manager.process_group) + super().__init__(module, process_group=ColoProcessGroup()) self.gemini_manager = gemini_manager - self.chunk_manager = gemini_manager.chunk_manager + self.chunk_manager: 
ChunkManager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = ZeROHookV2(gemini_manager) - self.fp32_params: List[ColoParameter] = [] + self.fp32_params: List[ColoTensor] = [] self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = {} - self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True) - self.chunk_manager.create_group('fp32_param') + # TODO: get param order and filter unused params for p in module.parameters(): + assert isinstance(p, ColoParameter) if getattr(p, '_ddp_to_ignore', False): p.data = p.half() continue - fp32_p = p.float().detach() + + dp_world_size = p.process_group.dp_world_size() + fp32_data = p.float().data p.data = p.half() - self.chunk_manager.append_tensor(p, 'fp16_param') - self.chunk_manager.append_tensor(fp32_p, 'fp32_param') + fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) + self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) self.fp32_params.append(fp32_p) self.grads_device[p] = self.gemini_manager.default_device + self.chunk_manager.close_all_groups() self._cast_buffers() + + params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] + for p, fp32_p in zip(params_list, self.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + chunk_32 = self.chunk_manager.get_chunk(fp32_p) + chunk_32.init_pair(chunk_16) + self._logger = get_dist_logger() def forward(self, *args, **kwargs): @@ -248,10 +262,7 @@ class ZeroDDP(ColoDDP): for p in self.module.parameters(): if getattr(p, '_ddp_to_ignore', False): continue - if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad: - p.grad = None - else: - p.grad = p.data + p.grad = None def _post_backward(self): self.chunk_manager.exec_lazy_release() @@ -276,21 +287,22 @@ class ZeroDDP(ColoDDP): free_storage(empty_grad) with torch._C.DisableTorchFunction(): self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) - if self.dp_world_size > 1: - grad = grad / self.dp_world_size - self.chunk_manager.copy_tensor_to_chunk_slice(p, grad) chunk = self.chunk_manager.get_chunk(p) + chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) - self.chunk_manager.release_chunk(chunk) - if reduced and not chunk.is_empty: + if reduced: + if chunk.is_gathered: + chunk.chunk_total.div_(chunk.pg_size) + else: + chunk.cuda_shard.div_(chunk.pg_size) self.overflow_counter += chunk.has_inf_or_nan - self.chunk_manager.move_chunk(chunk, self.grads_device[p]) + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) - def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: + def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device @@ -311,14 +323,11 @@ class ZeroDDP(ColoDDP): ['bias', 'weight'] """ - is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 - record_flag = (not only_rank_0) or is_rank_0 - if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, record_flag) + self._save_to_state_dict(destination, prefix, 
keep_vars, only_rank_0) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) @@ -326,7 +335,7 @@ class ZeroDDP(ColoDDP): destination = hook_result return destination - def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True): + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`. @@ -339,30 +348,30 @@ class ZeroDDP(ColoDDP): prefix (str): the prefix for parameters and buffers used in this module """ + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + # save parameters param_to_save_data = dict() chunk_list = self.chunk_manager.get_chunks(self.fp32_params) for chunk in chunk_list: - # record the original device of the chunk - org_chunk_dev_typ = chunk.device_type - self.chunk_manager.access_chunk(chunk) + temp_chunk = get_temp_total_chunk_on_cuda(chunk) - for tensor in chunk.get_tensors(): - rec_p = torch.empty([0]) + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - rec_p = tensor.cpu() # move the whole tensor to CPU mem + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + assert tensor not in param_to_save_data - param_to_save_data[tensor] = rec_p - # release the actual memory of the chunk - self.chunk_manager.release_chunk(chunk) - if not chunk.is_empty and org_chunk_dev_typ == 'cpu': - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + param_to_save_data[tensor] = record_tensor + + del temp_chunk for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - rec_p = param_to_save_data[fp32_p] - destination[prefix + name] = rec_p if keep_vars else rec_p.detach() + record_parameter = param_to_save_data[fp32_p] + destination[prefix + name] = record_parameter # save all buffers for name, buf in self.named_buffers(): @@ -466,40 +475,61 @@ class ZeroDDP(ColoDDP): local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} - def load(name, dest_tensor, copy_func): - key = prefix + name - if key in state_dict: - input_param = state_dict[key] + def load(param_name, dest_tensor, copy_func): + state_key = prefix + param_name + if state_key in state_dict: + input_param = state_dict[state_key] # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] if input_param.shape != dest_tensor.shape: # local shape should match the one in checkpoint error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(key, input_param.shape, + 'the shape in current model is {}.'.format(state_key, input_param.shape, dest_tensor.shape)) return try: with torch.no_grad(): - # self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, input_param) copy_func(input_param) except Exception as ex: error_msgs.append('While copying the parameter named "{}", ' 'whose dimensions in the model are {} and ' 'whose 
dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(key, dest_tensor.size(), input_param.size(), - ex.args)) + 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), + input_param.size(), ex.args)) elif strict: - missing_keys.append(key) + missing_keys.append(state_key) - def load_fp32_p(fp32_p, data): - if fp32_p.storage().size() > 0: - self.chunk_manager.copy_tensor_to_chunk_slice(fp32_p, data) + def load_fp32_parameter(chunk_slice, data): + chunk_slice.copy_(data.flatten()) + fp32_to_name = dict() for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): if p is not None: - load(name, fp32_p, partial(load_fp32_p, fp32_p)) - self.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param') + fp32_to_name[fp32_p] = name + + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + parameter_name = fp32_to_name[tensor] + parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + + if chunk.is_gathered: + chunk.chunk_total.copy_(temp_chunk) + elif chunk.cuda_shard is not None: + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + else: + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + + del temp_chunk + + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.optim_update() for name, buf in persistent_buffers.items(): if buf is not None: diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py new file mode 100644 index 000000000..587339549 --- /dev/null +++ b/colossalai/nn/parallel/utils.py @@ -0,0 +1,20 @@ +import torch +import torch.distributed as dist +from colossalai.gemini.chunk import Chunk +from colossalai.utils import get_current_device + + +def get_temp_total_chunk_on_cuda(chunk: Chunk): + if chunk.is_gathered: + return chunk.chunk_total + + if chunk.cuda_shard is not None: + shard_temp = chunk.cuda_shard + else: + shard_temp = chunk.cpu_shard.to(get_current_device()) + + total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) + gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) + dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) + + return total_temp diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py index af0187b4f..3f3472f0e 100644 --- a/colossalai/zero/utils/zero_hook_v2.py +++ b/colossalai/zero/utils/zero_hook_v2.py @@ -54,8 +54,8 @@ class ZeROHookV2(ParamOpHook): @contextmanager def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD): + old_training_phase = self._training_phase try: - old_training_phase = self._training_phase self._training_phase = training_phase yield finally: diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py index 55b4d7ee9..aee8b2799 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/zero/zero_optimizer.py @@ -2,17 +2,14 @@ import torch import torch.distributed as dist from enum import Enum from torch.optim import Optimizer +from torch.nn import Parameter from colossalai.nn.parallel.data_parallel import ZeroDDP -from typing import Dict +from typing import Dict, Tuple, Set from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from 
colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils import get_current_device, disposable -from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm -from collections import defaultdict, abc as container_abcs -from copy import deepcopy -from itertools import chain -from torch._six import inf +from colossalai.gemini.chunk import Chunk, ChunkManager class OptimState(Enum): @@ -33,8 +30,8 @@ class ZeroOptimizer(ColossalaiOptimizer): Args: optim (Optimizer): An Optimizer instance. module (ZeroDDP): A ``ZeroDDP`` instance. - gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) - which will be used when using hybrid CPU optimizer. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". Defaults to 0.0. initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. @@ -61,11 +58,19 @@ class ZeroOptimizer(ColossalaiOptimizer): assert isinstance(module, ZeroDDP) self.module = module self.gemini_manager = module.gemini_manager - self.chunk_manager = self.gemini_manager.chunk_manager + self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager self.optim_state = OptimState.UNSCALED - self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {} - for p, fp32_p in zip(module.parameters(), module.fp32_params): - self.fp16_param_to_fp32_param[p] = fp32_p + self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() + self.param_to_chunk32: Dict[Parameter, Chunk] = dict() + self.chunk16_set: Set[Chunk] = set() + + params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)] + for p, fp32_p in zip(params_list, module.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + if chunk_16 not in self.chunk16_set: + self.chunk16_set.add(chunk_16) + + self.__init__optimizer() # Grad scaler self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, @@ -75,7 +80,7 @@ class ZeroOptimizer(ColossalaiOptimizer): growth_interval=growth_interval, hysteresis=hysteresis, max_scale=max_scale) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device()) + self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) @@ -90,16 +95,26 @@ class ZeroOptimizer(ColossalaiOptimizer): self._register_states = disposable(self._register_states_) - def _update_params_ptr(self): - for group in self.optim.param_groups: - for p in group['params']: - if not self.module.chunk_manager.get_chunk(p).is_empty: - p.data = self.fp16_param_to_fp32_param[p] - else: - assert p.grad is None + def _set_grad_ptr(self): + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + begin, end = self.param_to_range[fake_param] + chunk16 = chunk32.paired_chunk + + fake_param.data = chunk16.payload[begin:end] + fake_param.grad = fake_param.data + fake_param.data = chunk32.payload[begin:end] def _update_fp16_params(self): - self.module.chunk_manager.copy_chunk_group('fp16_param', 'fp32_param') + none_tensor = torch.empty([0]) + for group in self.param_groups: + for fake_param in 
group['params']: + assert fake_param.grad is None + fake_param.data = none_tensor + + for chunk16 in self.chunk16_set: + chunk16.optim_update() def _check_overflow(self): # clear previous overflow record @@ -128,6 +143,7 @@ class ZeroOptimizer(ColossalaiOptimizer): def step(self, *args, **kwargs): self._maybe_move_fp32_params() + self._set_grad_ptr() # unscale grads if scaled if self.optim_state == OptimState.SCALED: self._unscale_grads() @@ -138,45 +154,14 @@ class ZeroOptimizer(ColossalaiOptimizer): self.zero_grad() self._update_fp16_params() return - self._update_params_ptr() ret = self.optim.step(*args, **kwargs) self._register_states() self.zero_grad() self._update_fp16_params() return ret - def compute_grad_norm(self, norm_type: float = 2.0) -> float: - norm_type = float(norm_type) - if not self.chunk_manager.enable_distributed_storage: - return compute_grad_norm(self.module.parameters(), norm_type) - - non_distributed_params = [] - distributed_params = [] - for p in self.module.parameters(): - if getattr(p, '_ddp_to_ignore', False): - non_distributed_params.append(p) - else: - distributed_params.append(p) - non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type) - distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)], - device=get_current_device()) - if norm_type == inf: - dist.all_reduce(distributed_norm_tensor, - op=dist.ReduceOp.MAX, - group=self.chunk_manager.process_group.dp_process_group()) - total_norm = max(non_distributed_norm, distributed_norm_tensor.item()) - else: - dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group()) - total_norm = non_distributed_norm + distributed_norm_tensor.item() - total_norm = total_norm**(1 / norm_type) - return total_norm - def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): - if self.optim_state == OptimState.SCALED: - self._unscale_grads() - total_norm = self.compute_grad_norm(norm_type) - _clip_grad_norm(self.module.parameters(), max_norm, total_norm) - return total_norm + raise NotImplementedError def backward(self, loss: torch.Tensor): loss = self.loss_scale * loss @@ -197,24 +182,31 @@ class ZeroOptimizer(ColossalaiOptimizer): available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param fp32_params_used_cuda_margin_mem = 0 - for fp16_param_chunk, fp32_param_chunk in zip(self.chunk_manager.chunk_groups['fp16_param'], - self.chunk_manager.chunk_groups['fp32_param']): - if fp32_param_chunk.is_empty: - continue - if fp32_params_used_cuda_margin_mem + fp32_param_chunk.mem < fp32_params_available_cuda_margin_mem: - self.chunk_manager.move_chunk(fp32_param_chunk, get_current_device()) - # stores grad now - self.chunk_manager.move_chunk(fp16_param_chunk, get_current_device()) - self.module._set_chunk_grad_device(fp16_param_chunk, get_current_device()) - fp32_params_used_cuda_margin_mem += fp32_param_chunk.mem - for p in fp16_param_chunk.get_tensors(): - state = self.optim.state[p] + + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + chunk16 = chunk32.paired_chunk + + if chunk32.device_type == 'cuda': + continue + + if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: + self.chunk_manager.move_chunk(chunk32, get_current_device()) + # stores grad now + 
self.chunk_manager.move_chunk(chunk16, get_current_device()) + self.module.set_chunk_grad_device(chunk16, get_current_device()) + fp32_params_used_cuda_margin_mem += chunk32.payload_mem + + for group in self.param_groups: + for fake_param in group['params']: + chunk32 = self.param_to_chunk32[fake_param] + if chunk32.device_type == 'cuda': + state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(get_current_device()) - self.module._setup_grads_ptr() - def _register_states_(self): for group in self.optim.param_groups: for p in group['params']: @@ -223,110 +215,27 @@ class ZeroOptimizer(ColossalaiOptimizer): if isinstance(val, torch.Tensor): self.chunk_manager.add_extern_static_tensor(val) - def state_dict(self, only_rank_0: bool = True): - r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None. - This saves memory usage. + def __init__optimizer(self): - It contains two entries: + def get_range_pair(local_chunk: Chunk, local_param: Parameter): + param_info = local_chunk.tensors_info[local_param] + begin = max(0, param_info.offset - local_chunk.shard_begin) + end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) + return begin, end - * state - a dict holding current optimization state. Its content - differs between optimizer classes. - * param_groups - a list containing all parameter groups where each - parameter group is a dict - """ - is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 - if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0: - return - optim_state_dict = super().state_dict() - scaler_state_dict = self.grad_scaler.state_dict() - optim_state_dict['scaler'] = scaler_state_dict - if not self.chunk_manager.enable_distributed_storage: - return optim_state_dict - local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0} - if not self.chunk_manager.process_group.has_cpu_groups: - self.chunk_manager.process_group.set_cpu_groups() - output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())] - if only_rank_0: - dst_rank = self.chunk_manager.process_group.dp_rank_list()[0] - dist.gather_object(local_state, - output if self.chunk_manager.process_group.dp_local_rank() == 0 else None, - dst=dst_rank, - group=self.chunk_manager.process_group.cpu_dp_process_group()) - if not is_rank_0: - return - else: - dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group()) - for state in output: - optim_state_dict['state'].update(state) - return optim_state_dict + for group in self.optim.param_groups: + fake_params_list = list() - def load_state_dict(self, state_dict): - r"""Loads the optimizer state. + for param in group['params']: + chunk16 = self.chunk_manager.get_chunk(param) + range_pair = get_range_pair(chunk16, param) + if range_pair[0] >= range_pair[1]: + continue - Args: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict`. 
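A note on how the reworked pieces fit together: ChunkManager is now constructed from a per-world-size configuration dict (usually produced by search_chunk_configuration) instead of a single chunk size, ZeroDDP gains the pin_memory argument, and ZeroOptimizer binds itself to the paired fp16/fp32 chunks at construction time. The sketch below condenses the construction path used by the tests added later in this patch; it assumes an already launched colossalai distributed context, and SomeModel plus the hyperparameter values are illustrative placeholders, not part of the patch.

    import torch
    from colossalai.utils import get_current_device
    from colossalai.utils.model.colo_init_context import ColoInitContext
    from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
    from colossalai.gemini.gemini_mgr import GeminiManager
    from colossalai.nn.parallel import ZeroDDP
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import ZeroOptimizer

    # parameters must be ColoParameters, so the model is built under ColoInitContext
    with ColoInitContext(device=get_current_device()):
        model = SomeModel()

    # search a chunk configuration for this model; the returned dict can be tweaked
    # before it is handed to ChunkManager (the tests override chunk_size / keep_gathered)
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)

    chunk_manager = ChunkManager(config_dict, init_device=torch.device('cpu'))
    gemini_manager = GeminiManager('auto', chunk_manager)      # 'cuda', 'cpu' or 'auto'
    model = ZeroDDP(model, gemini_manager, pin_memory=True)    # pin_memory is the new argument

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32)

The tests follow exactly this order, since ZeroOptimizer's constructor walks the chunks created while wrapping the model in ZeroDDP.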
- """ - if 'scaler' not in state_dict: - self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) - else: - self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler'])) + fake_param = torch.nn.Parameter(torch.empty([0])) + self.param_to_chunk32[fake_param] = chunk16.paired_chunk + self.param_to_range[fake_param] = range_pair - # Validate the state_dict - groups = self.param_groups - saved_groups = deepcopy(state_dict['param_groups']) + fake_params_list.append(fake_param) - if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") - - # Update the state - id_map = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups - )), chain.from_iterable((g['params'] for g in groups))) - } - - def cast(param, value): - r"""Make a deep copy of value, casting all tensors to device of param.""" - if isinstance(value, torch.Tensor): - # Floating-point types are a bit special here. They are the only ones - # that are assumed to always match the type of params. - if param.is_floating_point(): - value = value.to(param.dtype) - value = value.to(param.device) - return value - elif isinstance(value, dict): - return {k: cast(param, v) for k, v in value.items()} - elif isinstance(value, container_abcs.Iterable): - return type(value)(cast(param, v) for v in value) - else: - return value - - # Copy state assigned to params (and cast tensors to appropriate types). - # State that is not assigned to params is copied as is (needed for - # backward compatibility). 
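The heart of the rewritten ZeroOptimizer is a fake-parameter scheme that replaces the old fp16_param_to_fp32_param mapping. __init__optimizer (completed just below, where each group's 'params' list is swapped for fake_params_list) creates one zero-element torch.nn.Parameter for every parameter slice that overlaps the local chunk shard and records its fp32 chunk together with the (begin, end) range returned by get_range_pair; _set_grad_ptr then points each fake parameter's .grad at the fp16 payload slice and its .data at the fp32 payload slice, so the wrapped optimizer updates the chunk storage in place. A small sketch of that aliasing order, with plain CPU tensors standing in for chunk16.payload and chunk32.payload (the sizes, the range and the learning rate are illustrative):

    import torch

    # stand-ins for one paired chunk: the fp16 payload (which holds the reduced
    # gradients after backward) and its fp32 master copy
    payload16 = torch.randn(8).half()
    payload32 = payload16.float()

    begin, end = 2, 6    # this parameter's slice inside the local shard (see get_range_pair)
    fake_param = torch.nn.Parameter(torch.empty([0]))    # placeholder handed to the inner optimizer

    # the aliasing order used by _set_grad_ptr:
    fake_param.data = payload16[begin:end]    # 1) alias the fp16 slice
    fake_param.grad = fake_param.data         # 2) expose it as the grad (dtypes match at this point)
    fake_param.data = payload32[begin:end]    # 3) re-point .data at the fp32 master slice

    # an update now writes straight into payload32 through the view
    with torch.no_grad():
        fake_param -= 1e-3 * fake_param.grad.float()

After the step, _update_fp16_params detaches the fake parameters again and calls chunk16.optim_update() on every fp16 chunk, which presumably copies the updated fp32 payload back into its fp16 pair, the role previously played by copy_chunk_group('fp16_param', 'fp32_param').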
- state = defaultdict(dict) - for k, v in state_dict['state'].items(): - if k in id_map: - param = self.fp16_param_to_fp32_param[id_map[k]] - if param.storage().size() > 0: - state[param] = cast(param, deepcopy(v)) - else: - state[k] = deepcopy(v) - - # Update parameter groups, setting their 'params' value - def update_group(group, new_group): - new_group['params'] = group['params'] - return new_group - - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) - - -def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]): - return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()} + group['params'] = fake_params_list diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index 8789c18a6..d98018adf 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -6,11 +6,11 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.gemini import ChunkManager +from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration from functools import partial from colossalai.nn.parallel import ColoDDP, ZeroDDP from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Callable +from typing import Callable, Type import torch.distributed as dist import os import random @@ -32,10 +32,9 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: return ColoDDP(module, process_group=pg) -def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP: - pg = ProcessGroup() - chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None - chunk_manager = ChunkManager(chunk_size, pg) +def init_ddpv2(module: torch.nn.Module) -> ZeroDDP: + chunk_config = search_chunk_configuration(module, 4, 1024) + chunk_manager = ChunkManager(chunk_config) gemini_manager = GeminiManager('cuda', chunk_manager) return ZeroDDP(module, gemini_manager) @@ -51,7 +50,7 @@ class Net(torch.nn.Module): return self.fc2(self.fc1(x)) -def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): +def run_fwd_bwd(ddp_cls: Type[ColoDDP], init_ddp_func: Callable[[torch.nn.Module], ColoDDP]): with ColoInitContext(device=get_current_device()): model = Net().cuda() w1 = model.fc1.weight @@ -62,8 +61,14 @@ def run_fwd_bwd(ddp_cls: ColoDDP, init_ddp_func: Callable[[torch.nn.Module], Col logits = model(x) loss = torch.sum(logits) model.backward(loss) + + if ddp_cls is ZeroDDP: + w1s_grad = w1 + else: + w1s_grad = w1.grad + w1_grads = [torch.empty_like(w1) for _ in range(dist.get_world_size())] - dist.all_gather(w1_grads, w1.grad) + dist.all_gather(w1_grads, w1s_grad) assert torch.equal(w1_grads[0], w1_grads[1]) w2_grads = [torch.empty_like(w2) for _ in range(dist.get_world_size())] dist.all_gather(w2_grads, w2.grad) @@ -74,8 +79,7 @@ def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') set_seed(dist.get_rank()) run_fwd_bwd(ColoDDP, init_ddp) - run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False)) - run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True)) + run_fwd_bwd(ZeroDDP, init_ddpv2) @pytest.mark.dist diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py 
index c13f7a72c..f229364c6 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -8,14 +8,11 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.gemini import ChunkManager from functools import partial from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ZeroDDP, ColoDDP -from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.nn.parallel import ColoDDP from collections import OrderedDict from colossalai.tensor import ProcessGroup, ColoParameter -from colossalai.testing import parameterize def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): @@ -30,42 +27,11 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2) -def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True): - for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()): - assert na == nb - - if not allow_empty: - assert pa.storage().size() > 0 - assert pb.storage().size() > 0 - else: - if pa.storage().size() == 0 or pb.storage().size() == 0: - continue - - if same_dtype: - assert pa.dtype == pb.dtype - temp_pb = pb - else: - temp_pb = pb.to(pa.dtype) - - assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb) - - def init_ddp(module: torch.nn.Module) -> ColoDDP: pg = ProcessGroup() return ColoDDP(module, process_group=pg) -def init_ddpv2(module: torch.nn.Module, - use_chunk: bool = False, - use_zero: bool = False, - placement_policy: str = 'cuda') -> ZeroDDP: - pg = ProcessGroup() - chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None - chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - return ZeroDDP(module, gemini_manager) - - def run_ddp_state_dict(): get_components_func = non_distributed_component_funcs.get_callable('gpt2') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -88,44 +54,9 @@ def run_ddp_state_dict(): check_state_dict_equal(torch_state_dict, state_dict) -@parameterize('use_chunk', [False, True]) -@parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('use_zero', [False, True]) -@parameterize('only_rank_0', [False, True]) -def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0): - get_components_func = non_distributed_component_funcs.get_callable('gpt2') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - torch_model = model_builder().cuda() - org_torch_model = copy.deepcopy(torch_model) - torch_state_dict = torch_model.state_dict() - - with ColoInitContext(device=get_current_device()): - model = model_builder() - model = init_ddpv2(model, use_chunk, use_zero, placement_policy) - - for param in model.parameters(): - if isinstance(param, ColoParameter): - assert param.get_process_group() is not None - - model.load_state_dict(torch_state_dict, strict=False) - check_model_equal(model, torch_model, allow_empty=True, same_dtype=False) - - for param in model.parameters(): - if isinstance(param, ColoParameter): - assert 
param.get_process_group() is not None - - pg = ProcessGroup() - state_dict = model.state_dict(only_rank_0=only_rank_0) - if not only_rank_0 or pg.dp_local_rank() == 0: - torch_model.load_state_dict(state_dict, strict=False) - check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True) - - def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_ddp_state_dict() - run_zero_state_dict() @pytest.mark.dist diff --git a/tests/test_gemini/test_stateful_tensor_container.py b/tests/test_gemini/test_stateful_tensor_container.py deleted file mode 100644 index 60ac2a69b..000000000 --- a/tests/test_gemini/test_stateful_tensor_container.py +++ /dev/null @@ -1,74 +0,0 @@ -import pytest -import torch - -from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor -from colossalai.gemini.stateful_tensor_container import QueueSTContainer, HeapSTContainer - - -@pytest.mark.dist -def test_stateful_tensor_container(): - st1 = StatefulTensor(torch.randn(1, device='cuda')) - st2 = StatefulTensor(torch.randn(2, device='cuda')) - st3 = StatefulTensor(torch.randn(3, device='cuda')) - stateful_tensor_list = [st1, st2, st3] - step_list = [st1, st2, st3, st3, st2, st1] - - compute_step_dict = dict() - compute_step_dict[st1] = [0, 5] - compute_step_dict[st2] = [1, 4] - compute_step_dict[st3] = [2, 3] - - def run_queue_test(): - # test queue container - queue_container = QueueSTContainer(compute_step_dict, 6) - queue_container.create(stateful_tensor_list) - - res_list = [] - - for i in range(6): - stateful_tensor = step_list[i] - stateful_tensor.trans_state(TensorState.COMPUTE) - st_out = queue_container.pop() - st_out.move_to(torch.device('cpu')) - - res_list.append(st_out.payload.size(0)) - - stateful_tensor.move_to(torch.device('cuda')) - queue_container.push(stateful_tensor, i) - stateful_tensor.trans_state(TensorState.HOLD) - - assert res_list == [2, 3, 1, 2, 3, 2] - - run_queue_test() - - def run_heap_test(): - # test heap container - st1.move_to(torch.device('cuda')) - st2.move_to(torch.device('cuda')) - st3.move_to(torch.device('cuda')) - - heap_container = HeapSTContainer(compute_step_dict, 6) - heap_container.create(stateful_tensor_list) - - res_list = [] - - for i in range(6): - stateful_tensor = step_list[i] - stateful_tensor.trans_state(TensorState.COMPUTE) - st_out = heap_container.pop() - - if st_out is not None: - res_list.append(st_out.payload.size(0)) - st_out.move_to(torch.device('cpu')) - - stateful_tensor.move_to(torch.device('cuda')) - heap_container.push(stateful_tensor, i) - stateful_tensor.trans_state(TensorState.HOLD) - - assert res_list == [3, 1, 2, 3, 2] - - run_heap_test() - - -if __name__ == '__main__': - test_stateful_tensor_container() diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py index c4df217e1..fa7a9b1b5 100644 --- a/tests/test_gemini/update/test_chunk_mgrv2.py +++ b/tests/test_gemini/update/test_chunk_mgrv2.py @@ -3,7 +3,7 @@ import colossalai import pytest import torch.multiprocessing as mp from functools import partial -from colossalai.gemini.update import ChunkManagerV2 +from colossalai.gemini.chunk import ChunkManager from colossalai.testing import rerun_if_address_is_in_use, parameterize from colossalai.utils import free_port from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec @@ -19,23 +19,17 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} def 
exam_chunk_memory(keep_gathered, pin_memory): pg = ProcessGroup() - debug_print([0], "keep_gathered: {}, pin_memory: {}".format( - keep_gathered, pin_memory)) + debug_print([0], "keep_gathered: {}, pin_memory: {}".format(keep_gathered, pin_memory)) params = [ColoTensor(torch.rand(8, 8), spec=ColoTensorSpec(pg)) for _ in range(3)] - config = { - 2: dict( - chunk_size=128, - keep_gathered=keep_gathered - ) - } + config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} - chunk_manager = ChunkManagerV2(config, pin_memory=pin_memory) + chunk_manager = ChunkManager(config) assert chunk_manager.total_mem['cpu'] == 0 assert chunk_manager.total_mem['cuda'] == 0 for p in params: - chunk_manager.append_tensor(p, 'param', 2) + chunk_manager.append_tensor(p, 'param', 2, pin_memory=pin_memory) chunk_manager.close_all_groups() assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py index deea46acb..57a49314f 100644 --- a/tests/test_gemini/update/test_chunkv2.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ColoParameter from colossalai.gemini import TensorState -from colossalai.gemini.update import ChunkV2 +from colossalai.gemini.chunk import Chunk def dist_sum(x): @@ -38,14 +38,12 @@ def check_euqal(param, param_cp): def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() pg = ColoProcessGroup() - my_chunk = ChunkV2( - chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - keep_gathered=keep_gathered, - pin_memory=pin_memory - ) + my_chunk = Chunk(chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + keep_gathered=keep_gathered, + pin_memory=pin_memory) param_list = [] param_cp_list = [] diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py new file mode 100644 index 000000000..6bd25c0be --- /dev/null +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -0,0 +1,109 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext + +from functools import partial +from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal +from tests.components_to_test.registry import non_distributed_component_funcs +from torch.nn.parallel import DistributedDataParallel as DDP +from colossalai.nn.parallel import ZeroDDP +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero import ZeroOptimizer +from colossalai.testing import parameterize +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor +from tests.test_tensor.common_utils import debug_print + +from time import time +from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager + + +def check_grad(model: ZeroDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + param_list = [p 
for p in model.parameters()] + chunk_list = chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + chunk_manager.access_chunk(chunk) + + for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item()) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): + optimizer.zero_grad() + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + optimizer.backward(loss) + return logits + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +def exam_gpt_fwd_bwd(placement_policy): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + pg = ProcessGroup() + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) + + model.eval() + torch_model.eval() + + set_seed(pg.dp_local_rank()) + for i, (input_ids, attn_mask) in enumerate(train_dataloader): + if i > 0: + break + + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + model.backward(loss) + + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format( + torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits) + + check_grad(model, torch_model) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_gpt_fwd_bwd() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(1) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py new file mode 100644 index 000000000..cefda045d --- /dev/null +++ b/tests/test_gemini/update/test_optim.py @@ -0,0 +1,118 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext + +from functools import partial +from tests.test_tensor.common_utils import tensor_equal, set_seed, 
tensor_shard_equal +from tests.components_to_test.registry import non_distributed_component_funcs +from torch.nn.parallel import DistributedDataParallel as DDP +from colossalai.nn.parallel import ZeroDDP +from colossalai.nn.optimizer import HybridAdam +from colossalai.zero import ZeroOptimizer +from colossalai.testing import parameterize +from colossalai.amp import convert_to_apex_amp +from colossalai.gemini.gemini_mgr import GeminiManager +from tests.test_tensor.common_utils import debug_print + +from time import time +from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager + + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module): + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + # key is 'module.model.PARAMETER', so we truncate it + key = key[7:] + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) + assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key) + + +def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): + optimizer.zero_grad() + logits = model(input_ids, attn_mask) + logits = logits.float() + loss = criterion(logits, input_ids) + optimizer.backward(loss) + return logits + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +def exam_gpt_fwd_bwd(placement_policy): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = False + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, attn_mask) in enumerate(train_dataloader): + if i > 2: + break + + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) + # debug_print([0], zero_logits, torch_logits) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + +def run_dist(rank, world_size, 
port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_gpt_fwd_bwd() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gpt(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_gpt(1) diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py index fcc6bcf0e..6655c3e39 100644 --- a/tests/test_gemini/update/test_search.py +++ b/tests/test_gemini/update/test_search.py @@ -8,7 +8,7 @@ import torch.distributed as dist import colossalai from colossalai.testing import rerun_if_address_is_in_use -from colossalai.gemini.update import search_chunk_configuration +from colossalai.gemini.chunk import search_chunk_configuration from colossalai.utils import free_port, get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup @@ -35,12 +35,11 @@ def exam_search_chunk_size(): with ColoInitContext(device=get_current_device()): model = model_builder() init_1d_row_spec(model, pg_tp) - config_dict = search_chunk_configuration( - model, - search_range_mb=1, - search_interval_byte=16, - min_chunk_size_mb=0, - filter_exlarge_params=True) + config_dict = search_chunk_configuration(model, + search_range_mb=1, + search_interval_byte=16, + min_chunk_size_mb=0, + filter_exlarge_params=True) for key in config_dict: chunk_size = config_dict[key]['chunk_size'] diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py new file mode 100644 index 000000000..86e39097e --- /dev/null +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -0,0 +1,111 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext + +from functools import partial +from tests.test_tensor.common_utils import set_seed +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.nn.parallel import ZeroDDP +from colossalai.zero import ZeroOptimizer +from colossalai.testing import parameterize +from colossalai.gemini.gemini_mgr import GeminiManager +from tests.test_tensor.common_utils import debug_print + +from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + 
config_dict[world_size]['keep_gathered'] = keep_gathered + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + torch_dict = torch_model.state_dict() + + for key, value in torch_dict.items(): + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + + +@parameterize('placement_policy', ['cuda', 'cpu', 'auto']) +@parameterize('keep_gathered', [True, False]) +def exam_load_state_dict(placement_policy, keep_gathered): + set_seed(431) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + set_seed(451) + torch_model = model_builder() # get a different model + + world_size = torch.distributed.get_world_size() + config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + config_dict[world_size]['chunk_size'] = 5000 + config_dict[world_size]['keep_gathered'] = keep_gathered + + if placement_policy != 'cuda': + init_device = torch.device('cpu') + else: + init_device = None + chunk_manager = ChunkManager(config_dict, init_device=init_device) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + torch_dict = torch_model.state_dict() + model.load_state_dict(torch_dict, strict=False) + zero_dict = model.state_dict(only_rank_0=False) + + for key, value in torch_dict.items(): + if key == 'model.lm_head.weight': + continue + assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) + temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key) + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + exam_load_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_ddp(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_zero_ddp(1) diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py new file mode 100644 index 000000000..9361c4b67 --- /dev/null +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -0,0 +1,97 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext + +from functools import partial +from tests.test_tensor.common_utils import set_seed +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.nn.parallel import ZeroDDP +from colossalai.zero import 
ZeroOptimizer
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import parameterize
+from colossalai.gemini.gemini_mgr import GeminiManager
+from tests.test_tensor.common_utils import debug_print
+
+from colossalai.gemini.chunk import search_chunk_configuration, ChunkManager
+
+
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+@parameterize('keep_gathered', [True, False])
+def exam_zero_optim_state_dict(placement_policy, keep_gathered):
+    set_seed(431)
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+
+    set_seed(451)
+    torch_model = model_builder()    # get a different model
+
+    world_size = torch.distributed.get_world_size()
+    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[world_size]['chunk_size'] = 5000
+    config_dict[world_size]['keep_gathered'] = keep_gathered
+
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    optimizer = HybridAdam(model.parameters())
+    optim = ZeroOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32
+
+    set_seed(dist.get_rank() * 3 + 128)
+    model.train()
+    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+        if i > 0:
+            break
+        optim.zero_grad()
+        logits = model(input_ids, attn_mask)
+        logits = logits.float()
+        loss = criterion(logits, input_ids)
+        optim.backward(loss)
+        optim.step()
+
+    optim_state_dict = optim.state_dict()
+    optim.load_state_dict(optim_state_dict)
+    new_state = optim.state_dict()['state']
+    org_state = optim_state_dict['state']
+
+    for k, v in org_state.items():
+        w = new_state[k]
+        for n, m in v.items():
+            if isinstance(m, torch.Tensor):
+                o = w[n]
+                if m.device != o.device:
+                    o = o.to(m.device)
+                assert torch.equal(m, o)
+            else:
+                assert m == w[n]
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_zero_optim_state_dict()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_zero_optim(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_zero_optim(1)
diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py
deleted file mode 100644
index 1f1b6e44b..000000000
--- a/tests/test_tensor/test_chunk.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch
-import colossalai
-import pytest
-import torch.multiprocessing as mp
-from typing import List
-from functools import partial
-from colossalai.gemini import ChunkManager
-from colossalai.testing import rerun_if_address_is_in_use, parameterize
-from colossalai.utils import free_port
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-
-
-def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
-    for p, has_tensor in zip(params, has_tensors):
-        if has_tensor:
-            assert p.storage().size() > 0
-            assert p.device.type == 'cuda'
-        else:
-            assert p.storage().size() == 0
-
-
-# HAS_TENSORS[use_chunk][use_zero]
-HAS_TENSORS = {
-    True: {
-        True: [[True, True, False], [False, False, True]],
-        False: [[True, True, True], [True, True, True]]
-    },
-    False: {
-        True: [[True, False, True], [False, True, False]],
-        False: [[True, True, True], [True, True, True]]
-    }
-}
-
-TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
-
-
-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
-def run_chunk_zero(use_chunk, use_zero):
-    pg = ColoProcessGroup()
-    rank = pg.rank()
-    if rank == 0:
-        print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(8, 8) for _ in range(3)]
-    chunk_size = 128 if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
-    chunk_manager.create_group('param')
-    assert chunk_manager.total_mem['cpu'] == 0
-    assert chunk_manager.total_mem['cuda'] == 0
-    for p in params:
-        chunk_manager.append_tensor(p, 'param')
-    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
-    assert chunk_manager.total_mem['cpu'] == 0
-    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    chunks = chunk_manager.get_chunks(params)
-    for chunk in chunks:
-        chunk_manager.access_chunk(chunk)
-    check_has_params(params, [True, True, True])
-    assert chunk_manager.total_mem['cpu'] == 0
-    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for chunk in chunks:
-        chunk_manager.release_chunk(chunk)
-    check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
-    assert chunk_manager.total_mem['cpu'] == 0
-    assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for chunk in chunks:
-        chunk_manager.move_chunk(chunk, torch.device('cpu'))
-    assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    assert chunk_manager.total_mem['cuda'] == 0
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_chunk_zero()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [2])
-@rerun_if_address_is_in_use()
-def test_chunk_mapping(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_chunk_mapping(2)
diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_tp_with_zero.py
similarity index 68%
rename from tests/test_tensor/test_zero_optim.py
rename to tests/test_tensor/test_tp_with_zero.py
index b08ceed32..70cb837d8 100644
--- a/tests/test_tensor/test_zero_optim.py
+++ b/tests/test_tensor/test_tp_with_zero.py
@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.gemini import ChunkManager
+from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
 from functools import partial
 from tests.test_tensor.common_utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -21,20 +21,20 @@ from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, Compute
 from tests.test_tensor.model.test_gpt2 import init_megatron_spec
 
 
-def check_param_equal(model, torch_model, pg: ProcessGroup):
-    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
-        if p.storage().size() > 0:
-            assert p.dtype == torch.float16
-            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
-                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
+    zero_dict = model.state_dict(only_rank_0=False)
+    torch_dict = torch_model.state_dict()
 
-
-def check_grad_equal(model, torch_model, pg: ProcessGroup):
-    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
-        if p.grad is not None:
-            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
-                                      pg.tp_local_rank(), pg.tp_world_size()), \
-                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
+    for key, value in torch_dict.items():
+        # key is 'module.model.PARAMETER', so we truncate it
+        key = key[7:]
+        if key == 'model.lm_head.weight':
+            continue
+        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
+        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
+        assert tensor_shard_equal(value, temp_zero_value, pg.tp_local_rank(), pg.tp_world_size()), \
+            "parameter '{}' has problem.".format(key)
 
 
 def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
@@ -62,10 +62,8 @@ def init_1d_col_spec(model, pg: ProcessGroup):
             p.set_tensor_spec(*spec)
 
 
-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu'])
-def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
+def run_gpt(placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -89,15 +87,20 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     if tp_init_spec_func:
         tp_init_spec_func(model, pg)
 
-    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size,
-                                 pg,
-                                 enable_distributed_storage=use_zero,
-                                 init_device=GeminiManager.get_default_device(placement_policy))
+    dp_world_size = pg.dp_world_size()
+    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict[dp_world_size]['chunk_size'] = 5000
+    config_dict[dp_world_size]['keep_gathered'] = False
+    if placement_policy != 'cuda':
+        init_device = torch.device('cpu')
+    else:
+        init_device = None
+    chunk_manager = ChunkManager(config_dict, init_device=init_device)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager)
-    optim = HybridAdam(model.parameters(), lr=1e-3)
-    optim = ZeroOptimizer(optim, model, initial_scale=1)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+
+    optimizer = HybridAdam(model.parameters(), lr=1e-3)
+    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
 
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
@@ -105,7 +108,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
 
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
     print(chunk_manager)
-    check_param_equal(model, torch_model, pg)
+    check_param(model, torch_model, pg)
     model.eval()
     torch_model.eval()
 
@@ -115,13 +118,13 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
         if i > 2:
             break
         input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
         torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
-        assert tensor_equal(logits, torch_logits)
-        check_grad_equal(model, torch_model, pg)
-        optim.step()
+        assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
+
+        zero_optim.step()
         torch_optim.step()
-        check_param_equal(model, torch_model, pg)
+        check_param(model, torch_model, pg)
 
 
 def run_dist(rank, world_size, port):
diff --git a/tests/test_zero/test_zero_optim_state_dict.py b/tests/test_zero/test_zero_optim_state_dict.py
deleted file mode 100644
index cc67242c9..000000000
--- a/tests/test_zero/test_zero_optim_state_dict.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import pytest
-import colossalai
-import torch
-import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils import free_port
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.gemini import ChunkManager
-from functools import partial
-from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.zero import ZeroOptimizer
-from colossalai.testing import parameterize
-from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ProcessGroup
-
-
-def check_state(s1, s2):
-    for v1, v2 in zip(s1.values(), s2.values()):
-        if isinstance(v1, torch.Tensor):
-            v1 = v1.to(v2.device)
-            assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
-        else:
-            assert v1 == v2
-
-
-def check_load_state_dict(optim, torch_optim):
-    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
-        for p, torch_p in zip(group['params'], torch_group['params']):
-            state = optim.optim.state[p]
-            torch_state = torch_optim.state[torch_p]
-            if p.storage().size() == 0:
-                assert len(state) == 0
-            check_state(state, torch_state)
-
-
-def check_state_dict(state_dict, torch_state_dict):
-    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
-        assert k1 == k2
-        check_state(s1, s2)
-
-
-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
-@parameterize('only_rank_0', [False, True])
-def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
-    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-
-    with ColoInitContext(device=get_current_device()):
-        model = model_builder()
-    model = model.cuda()
-    torch_model = model_builder().cuda()
-
-    pg = ProcessGroup()
-
-    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size,
-                                 pg,
-                                 enable_distributed_storage=use_zero,
-                                 init_device=GeminiManager.get_default_device(placement_policy))
-    gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager)
-    optim = HybridAdam(model.parameters(), lr=1e-3)
-    optim = ZeroOptimizer(optim, model, initial_scale=1)
-
-    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
-
-    for p in torch_model.parameters():
-        p.grad = torch.rand_like(p)
-
-    torch_optim.step()
-    torch_state_dict = torch_optim.state_dict()
-    optim.load_state_dict(torch_state_dict)
-    check_load_state_dict(optim, torch_optim)
-
-    state_dict = optim.state_dict(only_rank_0)
-    if not only_rank_0 or pg.rank() == 0:
-        check_state_dict(state_dict, torch_state_dict)
-
-
-def run_dist(rank, world_size, port):
-    config = {}
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_zero_optim_state_dict()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 2])
-@rerun_if_address_is_in_use()
-def test_zero_optim_state_dict(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_optim_state_dict(2)
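For readers skimming this diff, the setup that the rewritten tests rely on can be summarized outside the test harness. The following is a minimal sketch distilled from the added test code above, not an official API reference: it assumes the distributed backend has already been initialized via colossalai.launch, and the helper name build_zero_model_and_optim as well as the model_builder/placement_policy arguments are placeholders introduced here for illustration.

import torch
import torch.distributed as dist
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer


def build_zero_model_and_optim(model_builder, placement_policy='cuda'):
    # Hypothetical helper mirroring the test setup; model_builder is any callable
    # returning an nn.Module (the tests use the GPT-2 builder from their registry).
    # Parameters are created under ColoInitContext so Gemini can manage them.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    # Search a per-world-size chunk configuration, then override it the way
    # the tests do (fixed chunk_size, scattered storage).
    world_size = dist.get_world_size()
    config_dict = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False

    # Chunks start on CPU unless the placement policy keeps everything on CUDA.
    init_device = None if placement_policy == 'cuda' else torch.device('cpu')
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters())
    # ZeroOptimizer links the fp16 and fp32 chunks, as the test comment notes.
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=32)
    return model, zero_optim

A training step then follows the same shape as the test loop: call zero_optim.zero_grad(), compute the loss from the model output, and use zero_optim.backward(loss) followed by zero_optim.step(); the optimizer state can be round-tripped with zero_optim.state_dict() and zero_optim.load_state_dict(), which is exactly what exam_zero_optim_state_dict verifies.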