diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/gemini/chunk/chunk.py index 648d48ec5..a9f0f7eae 100644 --- a/colossalai/gemini/chunk/chunk.py +++ b/colossalai/gemini/chunk/chunk.py @@ -1,552 +1,551 @@ -import torch -import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Dict, List - -from colossalai.utils import get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup - - -class TensorState(Enum): - FREE = 0 - COMPUTE = 1 - HOLD = 2 - HOLD_AFTER_BWD = 3 - READY_FOR_REDUCE = 4 - - -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), - (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), - (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) - - -@dataclass -class TensorInfo: - state: TensorState - offset: int - end: int - - -class ChunkFullError(Exception): - pass - - -def is_storage_empty(tensor: torch.Tensor) -> bool: - return tensor.storage().size() == 0 - - -def free_storage(tensor: torch.Tensor) -> None: - if not is_storage_empty(tensor): - tensor.storage().resize_(0) - - -def alloc_storage(tensor: torch.Tensor) -> None: - if is_storage_empty(tensor): - tensor.storage().resize_(tensor.numel()) - - -class Chunk: - - _total_number = 0 - - def __init__(self, - chunk_size: int, - process_group: ColoProcessGroup, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - keep_gathered: bool = False, - pin_memory: bool = False) -> None: - """ - Chunk: A container owning a piece of contiguous memory space for tensors - Here we use all-gather operation to gather the whole chunk. - Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. - It is designed to make the full use of communication and PCIE bandwidth. 
- - Args: - chunk_size (int): the number of elements in the chunk - process_group (ColoProcessGroup): the process group of this chunk - dtype (torch.dtype): the data type of the chunk - init_device (torch.device): optional, the device where the tensor is initialized - The default value is None, which is the current GPU - keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory - pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory - """ - self.count_id = Chunk._total_number - Chunk._total_number += 1 - - self.chunk_size = chunk_size - self.utilized_size = 0 - # Here, we use torch process group, - # since ColoProcessGroup might get deprecated soon - self.torch_pg = process_group.dp_process_group() - self.pg_size = dist.get_world_size(self.torch_pg) - self.pg_rank = dist.get_rank(self.torch_pg) - - # the chunk size should be able to be divied by the size of GPU - if not keep_gathered: - assert chunk_size % self.pg_size == 0 - self.shard_size = chunk_size // self.pg_size - self.shard_begin = self.shard_size * self.pg_rank - self.shard_end = self.shard_begin + self.shard_size - self.valid_end = self.shard_size - - self.dtype = dtype - device = init_device or get_current_device() - self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero - self.chunk_total = None # we force chunk_total located in CUDA - self.cuda_shard = None # using two attributes for the better interpretation - self.cpu_shard = None - self.is_gathered = True - - self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() - self.shard_mem = self.chunk_mem // self.pg_size - - # each tensor is associated with a TensorInfo to track meta info - self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} - # the total number of all tensors - self.num_tensors = 0 - # monitor the states of all tensors - self.tensors_state_monitor: Dict[TensorState, int] = dict() - for state in TensorState: - self.tensors_state_monitor[state] = 0 - - # some chunks can keep gathered all the time - # so their computation patterns are the same as that of the parameters in DDP - self.keep_gathered = keep_gathered - if self.keep_gathered: - pin_memory = False # since this chunk is gathered, it doesn't need to pin - - # if pin_memory is True, we allocate a piece of CPU pin-memory - # for it all the time - self.pin_memory = pin_memory - - # we introduce the paired chunk here - # it refers to another chunk having the same parameters - # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk - self.paired_chunk = None - # if this chunk is synchronized with the optimizer, the flag is True - self.optim_sync_flag = True - # if the cpu_shard has been visited during the training step, the flag is True - self.cpu_vis_flag = False - - @property - def memory_usage(self) -> Dict[str, int]: - cuda_memory = 0 - cpu_memory = 0 - - if self.chunk_temp is not None: - # this chunk is not closed - if self.chunk_temp.device.type == 'cuda': - cuda_memory += self.chunk_mem - else: - cpu_memory += self.chunk_mem - else: - if self.is_gathered: - cuda_memory += self.chunk_mem - if self.cuda_shard is not None: - cuda_memory += self.shard_mem - if self.cpu_shard is not None: - cpu_memory += self.shard_mem - - return dict(cuda=cuda_memory, cpu=cpu_memory) - - @property - def device_type(self) -> str: - if self.chunk_temp is not None: - return self.chunk_temp.device.type - else: - if self.is_gathered: - return 'cuda' - elif self.cuda_shard is not None: - return 'cuda' - 
else: - return 'cpu' - - @property - def payload(self) -> torch.Tensor: - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - return self.chunk_total - elif self.cuda_shard is not None: - return self.cuda_shard - else: - return self.cpu_shard - - @property - def payload_mem(self) -> int: - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - return self.chunk_mem - else: - return self.shard_mem - - @property - def can_move(self) -> bool: - return not self.is_gathered - - @property - def can_release(self) -> bool: - if self.keep_gathered: - return False - else: - return self.tensors_state_monitor[TensorState.HOLD] + \ - self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors - - @property - def can_reduce(self): - return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors - - @property - def has_inf_or_nan(self) -> bool: - """Check if the chunk has inf or nan values in CUDA. - """ - if self.is_gathered: - valid_tensor = self.chunk_total[:self.utilized_size] - else: - assert self.cuda_shard is not None # only check in CUDA - valid_tensor = self.cuda_shard[:self.valid_end] - - return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() - - def append_tensor(self, tensor: torch.Tensor): - """Add a tensor to the chunk. - - Args: - tensor (torch.Tensor): a tensor to be added to the chunk - """ - # sanity check - assert self.chunk_temp is not None - assert tensor.dtype == self.dtype - - new_utilized_size = self.utilized_size + tensor.numel() - # raise exception when the chunk size is exceeded - if new_utilized_size > self.chunk_size: - raise ChunkFullError - - self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) - assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) - - # record all the information about the tensor - self.num_tensors += 1 - tensor_state = TensorState.HOLD - self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) - self.tensors_state_monitor[tensor_state] += 1 - self.utilized_size = new_utilized_size - - def close_chunk(self, shard_dev: Optional[torch.device] = None): - """Close the chunk. Any tensor can't be appended to a closed chunk later. 
- - Args: - shard_dev: the device where the shard locates - """ - # sanity check - assert self.chunk_temp is not None - - # calculate the valid end for each shard - if self.utilized_size <= self.shard_begin: - self.valid_end = 0 - elif self.utilized_size < self.shard_end: - self.valid_end = self.utilized_size - self.shard_begin - - if self.chunk_temp.device.type == 'cpu': - self.chunk_total = self.chunk_temp.to(get_current_device()) - self.__update_tensors_ptr() - else: - self.chunk_total = self.chunk_temp - self.chunk_temp = None - - self.__scatter() - - if self.keep_gathered: - if shard_dev is None: - shard_dev = get_current_device() - else: - assert shard_dev.type == 'cuda' - elif shard_dev is None: - shard_dev = torch.device('cpu') - - if self.pin_memory or shard_dev.type == 'cpu': - self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) - self.cpu_shard.copy_(self.cuda_shard) - self.cpu_vis_flag = True # cpu_shard has been visited - - if shard_dev.type == 'cpu': - self.cuda_shard = None - - def shard_move(self, device: torch.device, force_copy: bool = False): - """Move the shard tensor in the chunk. - - Args: - device: the device to which the shard will move - force_copy: if True, copy function is called mandatorily - """ - # sanity check - assert not self.is_gathered - # when the current chunk is not synchronized with the optimizer - # just use another way for the movement - if not self.optim_sync_flag: - assert device.type == 'cuda', "each chunk should first be moved to CUDA" - self.__paired_shard_move() - self.optim_sync_flag = True - return - - if device.type == 'cuda': - assert device == get_current_device(), "can't move chunk to another device" - - if self.cuda_shard: - return - - self.cuda_shard = self.cpu_shard.to(get_current_device()) - - if not self.pin_memory: - self.cpu_shard = None - elif device.type == 'cpu': - if self.cuda_shard is None: - return - - if self.pin_memory: - if force_copy or not self.cpu_vis_flag: - self.cpu_shard.copy_(self.cuda_shard) - # if cpu_shard has been visited - # copy operation is not need - else: - self.cpu_shard = self.cuda_shard.cpu() - self.cpu_vis_flag = True - self.cuda_shard = None - else: - raise NotImplementedError - - def access_chunk(self): - """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. - """ - # sanity check - assert self.chunk_temp is None - - if not self.is_gathered: - self.__gather() - self.__update_tensors_ptr() - - def release_chunk(self): - """Release the usable chunk. It's an operation done in CUDA. - """ - # sanity check - assert self.chunk_temp is None - - if self.is_gathered: - self.__scatter() - - def reduce(self): - """Reduce scatter all the gradients. It's an operation done in CUDA. 
- """ - # sanity check - assert self.is_gathered - - if self.pg_size == 1: - # tricky code here - # just move chunk_total to cuda_shard - # the communication is not necessary - self.__scatter() - elif self.keep_gathered: - # we use all-reduce here - dist.all_reduce(self.chunk_total, group=self.torch_pg) - else: - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) - - input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0)) - dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) - - free_storage(self.chunk_total) - self.is_gathered = False - self.__update_tensors_state(TensorState.HOLD) - - def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: - """ - Make a transition of the tensor into the next state. - - Args: - tensor (torch.Tensor): a torch Tensor object. - tensor_state (TensorState): the target state for transition. - """ - - # As the gradient hook can be triggered either before or after post-backward - # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce - # or compute -> ready_for_reduce -> hold_after_bwd - # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd - # this function only apply valid state transformation - # invalid calls will be ignored and nothing changes - if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: - return - self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: - """ - Copy data slice to the memory space indexed by the input tensor in the chunk. - - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data_slice (torch.Tensor): the tensor to be copied to the chunk - """ - # sanity check - assert self.is_gathered - - tensor_info = self.tensors_info[tensor] - self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) - - def get_valid_length(self) -> int: - """Get the valid length of the chunk's payload. - """ - if self.keep_gathered: - return self.utilized_size - else: - return self.valid_end - - def init_pair(self, friend_chunk: 'Chunk') -> None: - """Initialize the paired chunk. - """ - if self.paired_chunk is None and friend_chunk.paired_chunk is None: - self.paired_chunk = friend_chunk - friend_chunk.paired_chunk = self - else: - assert self.paired_chunk is friend_chunk - assert friend_chunk.paired_chunk is self - - def optim_update(self) -> None: - """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. 
- """ - # sanity check - assert self.paired_chunk is not None - - friend_chunk = self.paired_chunk - if self.is_gathered is True: - assert friend_chunk.is_gathered is True - self.chunk_total.copy_(friend_chunk.chunk_total) - self.optim_sync_flag = True - elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': - self.cuda_shard.copy_(friend_chunk.cuda_shard) - self.optim_sync_flag = True - self.cpu_vis_flag = False - else: - # optim_sync_flag is set to False - # see shard_move function for more details - assert friend_chunk.device_type == 'cpu' - assert self.device_type == 'cpu' - self.optim_sync_flag = False - self.cpu_vis_flag = False - - def get_tensors(self) -> List[torch.Tensor]: - return list(self.tensors_info.keys()) - - def __gather(self): - if not self.is_gathered: - # sanity check - assert self.cuda_shard is not None - - alloc_storage(self.chunk_total) - gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) - dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) - - self.cuda_shard = None - self.is_gathered = True - - def __scatter(self): - if self.keep_gathered: - return - - if self.is_gathered: - # sanity check - assert self.cuda_shard is None - - self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device) - - self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end]) - - free_storage(self.chunk_total) - self.is_gathered = False - - def __paired_shard_move(self): - assert self.paired_chunk is not None, "chunks should be paired before training" - optim_chunk = self.paired_chunk - assert self.chunk_size == optim_chunk.chunk_size - - # only be called when optimizer state is in CPU memory - # the grad and param should be in the same device - assert self.cuda_shard is None - temp = optim_chunk.cpu_shard.to(get_current_device()) - # avoid to transform FP32 in CPU - self.cuda_shard = temp.to(self.dtype) - - if not self.pin_memory: - self.cpu_shard = None - - def __update_tensors_ptr(self) -> None: - # sanity check - assert self.is_gathered - assert type(self.chunk_total) == torch.Tensor - - for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) - - def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): - self.tensors_state_monitor[tensor_info.state] -= 1 - tensor_info.state = next_state - self.tensors_state_monitor[tensor_info.state] += 1 - - def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): - for tensor_info in self.tensors_info.values(): - if prev_state is None or tensor_info.state == prev_state: - self.__update_one_tensor_info(tensor_info, next_state) - - def __hash__(self) -> int: - return hash(id(self)) - - def __eq__(self, __o: object) -> bool: - return self is __o - - def __repr__(self, detailed: bool = True): - output = [ - "Chunk Information:\n", - "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, - self.pg_size), - "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( - self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) - ] - - def print_tensor(tensor, prefix=''): - output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, - tensor.device)) - - if self.chunk_temp is not None: - output.append("\tchunk temp:\n") - print_tensor(tensor=self.chunk_temp, prefix='\t\t') - - if 
self.chunk_total is not None and self.chunk_total.storage().size() > 0: - output.append("\tchunk total:\n") - print_tensor(tensor=self.chunk_total, prefix='\t\t') - - if self.cuda_shard is not None: - output.append("\tcuda shard:\n") - print_tensor(tensor=self.cuda_shard, prefix='\t\t') - - if self.cpu_shard is not None: - output.append("\tcpu shard:\n") - print_tensor(tensor=self.cpu_shard, prefix='\t\t') - - memory_info = self.memory_usage - output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) - - if detailed: - output.append("\ttensor state monitor:\n") - for st in TensorState: - output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st])) - - return ''.join(output) +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist + +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.utils import get_current_device + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, + TensorState.HOLD)) + + +@dataclass +class TensorInfo: + state: TensorState + offset: int + end: int + + +class ChunkFullError(Exception): + pass + + +def is_storage_empty(tensor: torch.Tensor) -> bool: + return tensor.storage().size() == 0 + + +def free_storage(tensor: torch.Tensor) -> None: + if not is_storage_empty(tensor): + tensor.storage().resize_(0) + + +def alloc_storage(tensor: torch.Tensor) -> None: + if is_storage_empty(tensor): + tensor.storage().resize_(tensor.numel()) + + +class Chunk: + + _total_number = 0 + + def __init__(self, + chunk_size: int, + process_group: ColoProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False) -> None: + """ + Chunk: A container owning a piece of contiguous memory space for tensors + Here we use all-gather operation to gather the whole chunk. + Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters. + It is designed to make the full use of communication and PCIE bandwidth. 
+ + Args: + chunk_size (int): the number of elements in the chunk + process_group (ColoProcessGroup): the process group of this chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, the device where the tensor is initialized + The default value is None, which is the current GPU + keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory + pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory + """ + self.count_id = Chunk._total_number + Chunk._total_number += 1 + + self.chunk_size = chunk_size + self.utilized_size = 0 + # Here, we use torch process group, + # since ColoProcessGroup might get deprecated soon + self.torch_pg = process_group.dp_process_group() + self.pg_size = dist.get_world_size(self.torch_pg) + self.pg_rank = dist.get_rank(self.torch_pg) + + # the chunk size should be able to be divied by the size of GPU + if not keep_gathered: + assert chunk_size % self.pg_size == 0 + self.shard_size = chunk_size // self.pg_size + self.shard_begin = self.shard_size * self.pg_rank + self.shard_end = self.shard_begin + self.shard_size + self.valid_end = self.shard_size + + self.dtype = dtype + device = init_device or get_current_device() + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + self.chunk_total = None # we force chunk_total located in CUDA + self.cuda_shard = None # using two attributes for the better interpretation + self.cpu_shard = None + self.is_gathered = True + + # configure the init deivce of the shard + # no-offload default: fp16, fp32 -> CUDA + # offload default: fp16, fp32 -> CPU + self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device() + + self.chunk_mem = self.chunk_size * self.chunk_temp.element_size() + self.shard_mem = self.chunk_mem // self.pg_size + + # each tensor is associated with a TensorInfo to track meta info + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + # the total number of all tensors + self.num_tensors = 0 + # monitor the states of all tensors + self.tensors_state_monitor: Dict[TensorState, int] = dict() + for state in TensorState: + self.tensors_state_monitor[state] = 0 + + # some chunks can keep gathered all the time + # so their computation patterns are the same as that of the parameters in DDP + self.keep_gathered = keep_gathered + if self.keep_gathered: + pin_memory = False # since this chunk is gathered, it doesn't need to pin + + # if pin_memory is True, we allocate a piece of CPU pin-memory + # for it all the time + self.pin_memory = pin_memory + + # we introduce the paired chunk here + # it refers to another chunk having the same parameters + # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk + self.paired_chunk = None + # if this chunk is synchronized with the optimizer, the flag is True + self.optim_sync_flag = True + # if the cpu_shard has been visited during the training step, the flag is True + self.cpu_vis_flag = False + + @property + def memory_usage(self) -> Dict[str, int]: + cuda_memory = 0 + cpu_memory = 0 + + if self.chunk_temp is not None: + # this chunk is not closed + if self.chunk_temp.device.type == 'cuda': + cuda_memory += self.chunk_mem + else: + cpu_memory += self.chunk_mem + else: + if self.is_gathered: + cuda_memory += self.chunk_mem + if self.cuda_shard is not None: + cuda_memory += self.shard_mem + if self.cpu_shard is not None: + cpu_memory += self.shard_mem + + return dict(cuda=cuda_memory, cpu=cpu_memory) + + 
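# --- editor's illustrative sketch (not part of the patch) ------------------------
# The hunk above introduces `cpu_shard_init`: it selects the device the chunk's
# shard lands on once the chunk is closed (CUDA by default, CPU when Gemini
# offloads).  A minimal usage sketch modeled on
# tests/test_gemini/update/test_chunkv2.py further down in this patch; it assumes
# a distributed ColossalAI context is already launched and that the data-parallel
# world size divides the chunk size.
import torch

from colossalai.gemini.chunk import Chunk
from colossalai.tensor import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup

pg = ColoProcessGroup()
chunk = Chunk(chunk_size=1024,
              process_group=pg,
              dtype=torch.float32,
              cpu_shard_init=True,    # close_chunk() will leave the shard on CPU
              pin_memory=True)

param = ColoParameter(torch.randn(8, 8, 8, device='cuda'))
chunk.append_tensor(param)    # the parameter payload now lives inside the chunk
chunk.close_chunk()           # note: close_chunk() no longer takes a shard_dev argument
print(chunk.memory_usage)     # expect the shard to be accounted under 'cpu'
# ----------------------------------------------------------------------------------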
@property + def device_type(self) -> str: + if self.chunk_temp is not None: + return self.chunk_temp.device.type + else: + if self.is_gathered: + return 'cuda' + elif self.cuda_shard is not None: + return 'cuda' + else: + return 'cpu' + + @property + def payload(self) -> torch.Tensor: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_total + elif self.cuda_shard is not None: + return self.cuda_shard + else: + return self.cpu_shard + + @property + def payload_mem(self) -> int: + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + return self.chunk_mem + else: + return self.shard_mem + + @property + def can_move(self) -> bool: + return not self.is_gathered + + @property + def can_release(self) -> bool: + if self.keep_gathered: + return False + else: + return self.tensors_state_monitor[TensorState.HOLD] + \ + self.tensors_state_monitor[TensorState.HOLD_AFTER_BWD] == self.num_tensors + + @property + def can_reduce(self): + return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors + + @property + def has_inf_or_nan(self) -> bool: + """Check if the chunk has inf or nan values in CUDA. + """ + if self.is_gathered: + valid_tensor = self.chunk_total[:self.utilized_size] + else: + assert self.cuda_shard is not None # only check in CUDA + valid_tensor = self.cuda_shard[:self.valid_end] + + return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + + def append_tensor(self, tensor: torch.Tensor): + """Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ + # sanity check + assert self.chunk_temp is not None + assert tensor.dtype == self.dtype + + new_utilized_size = self.utilized_size + tensor.numel() + # raise exception when the chunk size is exceeded + if new_utilized_size > self.chunk_size: + raise ChunkFullError + + self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" + tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + + # record all the information about the tensor + self.num_tensors += 1 + tensor_state = TensorState.HOLD + self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size) + self.tensors_state_monitor[tensor_state] += 1 + self.utilized_size = new_utilized_size + + def close_chunk(self): + """Close the chunk. Any tensor can't be appended to a closed chunk later. 
+ """ + # sanity check + assert self.chunk_temp is not None + + # calculate the valid end for each shard + if self.utilized_size <= self.shard_begin: + self.valid_end = 0 + elif self.utilized_size < self.shard_end: + self.valid_end = self.utilized_size - self.shard_begin + + if self.chunk_temp.device.type == 'cpu': + self.chunk_total = self.chunk_temp.to(get_current_device()) + self.__update_tensors_ptr() + else: + self.chunk_total = self.chunk_temp + self.chunk_temp = None + + self.__scatter() + # always gathered chunk does not have shard + if self.keep_gathered: + return + + if self.pin_memory or self.shard_device.type == 'cpu': + self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) + self.cpu_shard.copy_(self.cuda_shard) + self.cpu_vis_flag = True # cpu_shard has been visited + + if self.shard_device.type == 'cpu': + self.cuda_shard = None + + def shard_move(self, device: torch.device, force_copy: bool = False): + """Move the shard tensor in the chunk. + + Args: + device: the device to which the shard will move + force_copy: if True, copy function is called mandatorily + """ + # sanity check + assert not self.is_gathered + # when the current chunk is not synchronized with the optimizer + # just use another way for the movement + if not self.optim_sync_flag: + assert device.type == 'cuda', "each chunk should first be moved to CUDA" + self.__paired_shard_move() + self.optim_sync_flag = True + return + + if device.type == 'cuda': + assert device == get_current_device(), "can't move chunk to another device" + + if self.cuda_shard: + return + + self.cuda_shard = self.cpu_shard.to(get_current_device()) + + if not self.pin_memory: + self.cpu_shard = None + elif device.type == 'cpu': + if self.cuda_shard is None: + return + + if self.pin_memory: + if force_copy or not self.cpu_vis_flag: + self.cpu_shard.copy_(self.cuda_shard) + # if cpu_shard has been visited + # copy operation is not need + else: + self.cpu_shard = self.cuda_shard.cpu() + self.cpu_vis_flag = True + self.cuda_shard = None + else: + raise NotImplementedError + + def access_chunk(self): + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if not self.is_gathered: + self.__gather() + self.__update_tensors_ptr() + + def release_chunk(self): + """Release the usable chunk. It's an operation done in CUDA. + """ + # sanity check + assert self.chunk_temp is None + + if self.is_gathered: + self.__scatter() + + def reduce(self): + """Reduce scatter all the gradients. It's an operation done in CUDA. + """ + # sanity check + assert self.is_gathered + + if self.pg_size == 1: + # tricky code here + # just move chunk_total to cuda_shard + # the communication is not necessary + self.__scatter() + elif self.keep_gathered: + # we use all-reduce here + dist.all_reduce(self.chunk_total, group=self.torch_pg) + else: + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device()) + + input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0)) + dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg) + + free_storage(self.chunk_total) + self.is_gathered = False + self.__update_tensors_state(TensorState.HOLD) + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. 
+ tensor_state (TensorState): the target state for transition. + """ + + # As the gradient hook can be triggered either before or after post-backward + # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce + # or compute -> ready_for_reduce -> hold_after_bwd + # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd + # this function only apply valid state transformation + # invalid calls will be ignored and nothing changes + if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: + return + self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def get_valid_length(self) -> int: + """Get the valid length of the chunk's payload. + """ + if self.keep_gathered: + return self.utilized_size + else: + return self.valid_end + + def init_pair(self, friend_chunk: 'Chunk') -> None: + """Initialize the paired chunk. + """ + if self.paired_chunk is None and friend_chunk.paired_chunk is None: + self.paired_chunk = friend_chunk + friend_chunk.paired_chunk = self + else: + assert self.paired_chunk is friend_chunk + assert friend_chunk.paired_chunk is self + + def optim_update(self) -> None: + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. 
+ """ + # sanity check + assert self.paired_chunk is not None + + friend_chunk = self.paired_chunk + if self.is_gathered is True: + assert friend_chunk.is_gathered is True + self.chunk_total.copy_(friend_chunk.chunk_total) + self.optim_sync_flag = True + elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + self.cuda_shard.copy_(friend_chunk.cuda_shard) + self.optim_sync_flag = True + self.cpu_vis_flag = False + else: + # optim_sync_flag is set to False + # see shard_move function for more details + assert friend_chunk.device_type == 'cpu' + assert self.device_type == 'cpu' + self.optim_sync_flag = False + self.cpu_vis_flag = False + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + def __gather(self): + if not self.is_gathered: + # sanity check + assert self.cuda_shard is not None + + alloc_storage(self.chunk_total) + gather_list = list(torch.chunk(input=self.chunk_total, chunks=self.pg_size, dim=0)) + dist.all_gather(gather_list, self.cuda_shard, self.torch_pg) + + self.cuda_shard = None + self.is_gathered = True + + def __scatter(self): + if self.keep_gathered: + return + + if self.is_gathered: + # sanity check + assert self.cuda_shard is None + + self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.chunk_total.device) + + self.cuda_shard.copy_(self.chunk_total[self.shard_begin:self.shard_end]) + + free_storage(self.chunk_total) + self.is_gathered = False + + def __paired_shard_move(self): + assert self.paired_chunk is not None, "chunks should be paired before training" + optim_chunk = self.paired_chunk + assert self.chunk_size == optim_chunk.chunk_size + + # only be called when optimizer state is in CPU memory + # the grad and param should be in the same device + assert self.cuda_shard is None + temp = optim_chunk.cpu_shard.to(get_current_device()) + # avoid to transform FP32 in CPU + self.cuda_shard = temp.to(self.dtype) + + if not self.pin_memory: + self.cpu_shard = None + + def __update_tensors_ptr(self) -> None: + # sanity check + assert self.is_gathered + assert type(self.chunk_total) == torch.Tensor + + for tensor, tensor_info in self.tensors_info.items(): + tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape) + + def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): + self.tensors_state_monitor[tensor_info.state] -= 1 + tensor_info.state = next_state + self.tensors_state_monitor[tensor_info.state] += 1 + + def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + self.__update_one_tensor_info(tensor_info, next_state) + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, __o: object) -> bool: + return self is __o + + def __repr__(self, detailed: bool = True): + output = [ + "Chunk Information:\n", + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, + self.pg_size), + "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + ] + + def print_tensor(tensor, prefix=''): + output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, + tensor.device)) + + if self.chunk_temp is not None: + output.append("\tchunk temp:\n") + print_tensor(tensor=self.chunk_temp, prefix='\t\t') + + if 
self.chunk_total is not None and self.chunk_total.storage().size() > 0: + output.append("\tchunk total:\n") + print_tensor(tensor=self.chunk_total, prefix='\t\t') + + if self.cuda_shard is not None: + output.append("\tcuda shard:\n") + print_tensor(tensor=self.cuda_shard, prefix='\t\t') + + if self.cpu_shard is not None: + output.append("\tcpu shard:\n") + print_tensor(tensor=self.cpu_shard, prefix='\t\t') + + memory_info = self.memory_usage + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + + if detailed: + output.append("\ttensor state monitor:\n") + for st in TensorState: + output.append("\t\t# of {}: {}\n".format(st, self.tensors_state_monitor[st])) + + return ''.join(output) diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index 4a2474a63..ac73105a0 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -1,230 +1,237 @@ -import torch -from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable -from collections import deque - -from colossalai.utils import get_current_device -from colossalai.tensor import ColoTensor -from colossalai.gemini.chunk import ChunkFullError, TensorState, Chunk - - -class ChunkManager: - """ - A manager class to manipulate the tensors in chunks. - - Args: - chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. - init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. - """ - - def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None: - - self.device = init_device or get_current_device() - self.size_config: Dict[int, int] = dict() - self.kwargs_config = chunk_configuration - for k, v in self.kwargs_config.items(): - self.size_config[k] = v.pop('chunk_size') - v['init_device'] = self.device - - self.chunk_groups: Dict[str, Deque] = dict() - self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() - self.accessed_chunks: Set[Chunk] = set() - self.accessed_mem: int = 0 - self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - - def append_tensor(self, tensor: ColoTensor, group_type: str, config_key: int, pin_memory: bool = False) -> None: - """Append a tensor to a chunk. 
- - Args: - tensor: the tensor appended to the chunk - group_type: the data type of the group - config_key: the key of the group's name, usually the size of the dp world - pin_memory: whether the chunk is pinned in the cpu memory - """ - assert tensor not in self.tensor_chunk_map - assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" - assert config_key in self.size_config - - chunk_size = self.size_config[config_key] - chunk_kwargs = self.kwargs_config[config_key] - group_name = "{}_{}".format(group_type, config_key) - chunk_group = self.__get_chunk_group(group_name) - - try: - # append the tensor to the last chunk - chunk_group[-1].append_tensor(tensor) - except (IndexError, ChunkFullError): - # the except statement will be triggered when there is no chunk or - # the last chunk in the chunk group is full - # this will create a new chunk and allocate this chunk to its corresponding process - if chunk_group: - # the chunk group is not empty - # close the last chunk - self.__close_one_chunk(chunk_group[-1]) - - if tensor.numel() > chunk_size: - chunk_size = tensor.numel() - chunk = Chunk( - chunk_size=chunk_size, - process_group=tensor.process_group, - dtype=tensor.dtype, - pin_memory=pin_memory, - **chunk_kwargs, - ) - - chunk_group.append(chunk) - chunk.append_tensor(tensor) - self.__add_memory_usage(chunk.memory_usage) - - self.tensor_chunk_map[tensor] = chunk_group[-1] - - def close_all_groups(self): - """Close all the chunks of all groups. - """ - for group_name in self.chunk_groups: - self.__close_one_chunk(self.chunk_groups[group_name][-1]) - - def access_chunk(self, chunk: Chunk) -> None: - """Make the chunk can be used for calculation. - """ - if chunk in self.accessed_chunks: - return - self.__sub_memroy_usage(chunk.memory_usage) - if chunk.device_type == 'cpu': - chunk.shard_move(get_current_device()) - self.__add_accessed_chunk(chunk) - self.__add_memory_usage(chunk.memory_usage) - - def release_chunk(self, chunk: Chunk) -> None: - """Scatter the chunk in CUDA. - """ - if chunk not in self.accessed_chunks: - return - if chunk.can_release: - self.__sub_memroy_usage(chunk.memory_usage) - self.__sub_accessed_chunk(chunk) - self.__add_memory_usage(chunk.memory_usage) - - def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: - """Move the shard of the chunk to the target device. - """ - if not chunk.can_move or chunk.device_type == device.type: - return - self.__sub_memroy_usage(chunk.memory_usage) - chunk.shard_move(device, force_copy) - self.__add_memory_usage(chunk.memory_usage) - - def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: - """Transit tensor state according to pre-defined state machine. - """ - chunk = self.tensor_chunk_map[tensor] - chunk.tensor_trans_state(tensor, state) - - def reduce_chunk(self, chunk: Chunk) -> bool: - """Reduce or all reduce the chunk. - """ - if not chunk.can_reduce: - return False - self.__sub_memroy_usage(chunk.memory_usage) - chunk.reduce() - self.__sub_accessed_chunk(chunk) - self.__add_memory_usage(chunk.memory_usage) - return True - - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: - """ - Copy data to the chunk. 
- - Args: - tensor (torch.Tensor): the tensor used to retrive meta information - data (torch.Tensor): the tensor to be copied to the chunk - """ - chunk = self.tensor_chunk_map[tensor] - chunk.copy_tensor_to_chunk_slice(tensor, data) - - def get_chunk(self, tensor: torch.Tensor) -> Chunk: - """ - Return the chunk owning the tensor. - - Args: - tensor (torch.Tensor): a torch tensor object - """ - return self.tensor_chunk_map[tensor] - - def get_cuda_movable_chunks(self) -> List[Chunk]: - """ - Get all chunks that can be moved. - """ - chunk_list = [] - for chunk in self.accessed_chunks: - if chunk.can_release: - chunk_list.append(chunk) - chunk_list.sort(key=lambda x: x.count_id) - return chunk_list - - def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: - """ - Get all chunks owning the input tensors. - - Args: - tensors (Iterable[torch.Tensor]): the tensors used to look for chunks - """ - chunks = [] - for tensor in tensors: - chunk = self.get_chunk(tensor) - if chunk not in chunks: - chunks.append(chunk) - return tuple(chunks) - - def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: - """Add extern static tensor to chunk manager. - Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. - They are "static", which means their shape, dtype, device never change. - Thus, their memory usage never changes. - - Args: - tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. - """ - assert tensor not in self.tensor_chunk_map - self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() - - def __repr__(self) -> str: - msg = [ - 'Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' - ] - for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') - for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) - - def __get_chunk_group(self, group_name: str) -> Deque: - """Register a chunk group. - """ - if group_name not in self.chunk_groups: - self.chunk_groups[group_name] = deque() - return self.chunk_groups[group_name] - - def __close_one_chunk(self, chunk: Chunk): - device = get_current_device() if chunk.keep_gathered else self.device # keep gathered chunk in cuda - self.__sub_memroy_usage(chunk.memory_usage) - chunk.close_chunk(device) - self.__add_memory_usage(chunk.memory_usage) - - def __sub_memroy_usage(self, usage: Dict[str, int]): - for k, v in usage.items(): - self.total_mem[k] -= v - - def __add_memory_usage(self, usage: Dict[str, int]): - for k, v in usage.items(): - self.total_mem[k] += v - - def __add_accessed_chunk(self, chunk: Chunk): - chunk.access_chunk() - self.accessed_chunks.add(chunk) - self.accessed_mem += chunk.chunk_mem - - def __sub_accessed_chunk(self, chunk: Chunk): - chunk.release_chunk() - self.accessed_chunks.remove(chunk) - self.accessed_mem -= chunk.chunk_mem +from collections import deque +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple + +import torch + +from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState +from colossalai.tensor import ColoTensor +from colossalai.utils import get_current_device + + +class ChunkManager: + """ + A manager class to manipulate the tensors in chunks. + + Args: + chunk_configuration (Dict[int, Dict]): the configuration dictionary of this chunk manager. + init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. 
+ """ + + def __init__(self, chunk_configuration: Dict[int, Dict], init_device: Optional[torch.device] = None) -> None: + + self.device = init_device or get_current_device() + self.size_config: Dict[int, int] = dict() + self.kwargs_config = chunk_configuration + for k, v in self.kwargs_config.items(): + self.size_config[k] = v.pop('chunk_size') + v['init_device'] = self.device + + self.chunk_groups: Dict[str, Deque] = dict() + self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() + self.accessed_chunks: Set[Chunk] = set() + self.accessed_mem: int = 0 + self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} + + def append_tensor(self, + tensor: ColoTensor, + group_type: str, + config_key: int, + cpu_offload: bool = False, + pin_memory: bool = False) -> None: + """Append a tensor to a chunk. + + Args: + tensor: the tensor appended to the chunk + group_type: the data type of the group + config_key: the key of the group's name, usually the size of the dp world + cpu_offload: if True, the chunk will be closed on CPU + pin_memory: whether the chunk is pinned in the cpu memory + """ + assert tensor not in self.tensor_chunk_map + assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager" + assert config_key in self.size_config + + chunk_size = self.size_config[config_key] + chunk_kwargs = self.kwargs_config[config_key] + group_name = "{}_{}".format(group_type, config_key) + chunk_group = self.__get_chunk_group(group_name) + + try: + # append the tensor to the last chunk + chunk_group[-1].append_tensor(tensor) + except (IndexError, ChunkFullError): + # the except statement will be triggered when there is no chunk or + # the last chunk in the chunk group is full + # this will create a new chunk and allocate this chunk to its corresponding process + if chunk_group: + # the chunk group is not empty + # close the last chunk + self.__close_one_chunk(chunk_group[-1]) + + if tensor.numel() > chunk_size: + chunk_size = tensor.numel() + chunk = Chunk( + chunk_size=chunk_size, + process_group=tensor.process_group, + dtype=tensor.dtype, + cpu_shard_init=cpu_offload, + pin_memory=pin_memory, + **chunk_kwargs, + ) + + chunk_group.append(chunk) + chunk.append_tensor(tensor) + self.__add_memory_usage(chunk.memory_usage) + + self.tensor_chunk_map[tensor] = chunk_group[-1] + + def close_all_groups(self): + """Close all the chunks of all groups. + """ + for group_name in self.chunk_groups: + self.__close_one_chunk(self.chunk_groups[group_name][-1]) + + def access_chunk(self, chunk: Chunk) -> None: + """Make the chunk can be used for calculation. + """ + if chunk in self.accessed_chunks: + return + self.__sub_memroy_usage(chunk.memory_usage) + if chunk.device_type == 'cpu': + chunk.shard_move(get_current_device()) + self.__add_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def release_chunk(self, chunk: Chunk) -> None: + """Scatter the chunk in CUDA. + """ + if chunk not in self.accessed_chunks: + return + if chunk.can_release: + self.__sub_memroy_usage(chunk.memory_usage) + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + + def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: + """Move the shard of the chunk to the target device. 
+ """ + if not chunk.can_move or chunk.device_type == device.type: + return + self.__sub_memroy_usage(chunk.memory_usage) + chunk.shard_move(device, force_copy) + self.__add_memory_usage(chunk.memory_usage) + + def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: + """Transit tensor state according to pre-defined state machine. + """ + chunk = self.tensor_chunk_map[tensor] + chunk.tensor_trans_state(tensor, state) + + def reduce_chunk(self, chunk: Chunk) -> bool: + """Reduce or all reduce the chunk. + """ + if not chunk.can_reduce: + return False + self.__sub_memroy_usage(chunk.memory_usage) + chunk.reduce() + self.__sub_accessed_chunk(chunk) + self.__add_memory_usage(chunk.memory_usage) + return True + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: + """ + Copy data to the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data (torch.Tensor): the tensor to be copied to the chunk + """ + chunk = self.tensor_chunk_map[tensor] + chunk.copy_tensor_to_chunk_slice(tensor, data) + + def get_chunk(self, tensor: torch.Tensor) -> Chunk: + """ + Return the chunk owning the tensor. + + Args: + tensor (torch.Tensor): a torch tensor object + """ + return self.tensor_chunk_map[tensor] + + def get_cuda_movable_chunks(self) -> List[Chunk]: + """ + Get all chunks that can be moved. + """ + chunk_list = [] + for chunk in self.accessed_chunks: + if chunk.can_release: + chunk_list.append(chunk) + chunk_list.sort(key=lambda x: x.count_id) + return chunk_list + + def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: + """ + Get all chunks owning the input tensors. + + Args: + tensors (Iterable[torch.Tensor]): the tensors used to look for chunks + """ + chunks = [] + for tensor in tensors: + chunk = self.get_chunk(tensor) + if chunk not in chunks: + chunks.append(chunk) + return tuple(chunks) + + def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: + """Add extern static tensor to chunk manager. + Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them. + They are "static", which means their shape, dtype, device never change. + Thus, their memory usage never changes. + + Args: + tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. + """ + assert tensor not in self.tensor_chunk_map + self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + + def __repr__(self) -> str: + msg = [ + 'Chunk Manager Information:\n', + 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + ] + for group_name, group in self.chunk_groups.items(): + msg.append(f'Group {group_name}:\n') + for i, chunk in enumerate(group): + msg.append(f'[{i}] {chunk}\n') + return ''.join(msg) + + def __get_chunk_group(self, group_name: str) -> Deque: + """Register a chunk group. 
+ """ + if group_name not in self.chunk_groups: + self.chunk_groups[group_name] = deque() + return self.chunk_groups[group_name] + + def __close_one_chunk(self, chunk: Chunk): + self.__sub_memroy_usage(chunk.memory_usage) + chunk.close_chunk() + self.__add_memory_usage(chunk.memory_usage) + + def __sub_memroy_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] -= v + + def __add_memory_usage(self, usage: Dict[str, int]): + for k, v in usage.items(): + self.total_mem[k] += v + + def __add_accessed_chunk(self, chunk: Chunk): + chunk.access_chunk() + self.accessed_chunks.add(chunk) + self.accessed_mem += chunk.chunk_mem + + def __sub_accessed_chunk(self, chunk: Chunk): + chunk.release_chunk() + self.accessed_chunks.remove(chunk) + self.accessed_mem -= chunk.chunk_mem diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py index 6d6b7425c..b001a2aee 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/gemini/gemini_mgr.py @@ -1,9 +1,12 @@ -import torch import functools -from .memory_tracer.memstats_collector import MemStatsCollectorV2 -from typing import List, Optional, Tuple from time import time +from typing import List, Optional, Tuple + +import torch + from colossalai.gemini.chunk import Chunk, ChunkManager + +from .memory_tracer.memstats_collector import MemStatsCollectorV2 from .placement_policy import PlacementPolicyFactory @@ -25,6 +28,7 @@ class GeminiManager: def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None: assert placement_policy in PlacementPolicyFactory.get_polocy_names() + self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) self._chunk_manager = chunk_manager self._mem_stats_collector = MemStatsCollectorV2(chunk_manager) if policy_cls.need_mem_stats else None diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 5bce81708..d58a746b6 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -1,19 +1,22 @@ -import torch import itertools -import torch.distributed as dist -from functools import partial -from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 -from colossalai.tensor.param_op_hook import ParamOpHookManager -from colossalai.gemini.gemini_mgr import GeminiManager -from typing import Dict, Iterable, List, Optional, Set -from colossalai.logging import get_dist_logger from collections import OrderedDict -from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec -from colossalai.tensor import ProcessGroup as ColoProcessGroup -from .reducer import Reducer +from functools import partial +from typing import Dict, Iterable, List, Optional, Set -from colossalai.gemini.chunk import TensorState, Chunk, ChunkManager +import torch +import torch.distributed as dist + +from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState +from colossalai.gemini.gemini_mgr import GeminiManager +from colossalai.logging import get_dist_logger from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.param_op_hook import ParamOpHookManager +from colossalai.utils import get_current_device +from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2 + +from .reducer import Reducer try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, 
_IncompatibleKeys @@ -221,6 +224,7 @@ class ZeroDDP(ColoDDP): self.overflow_counter = 0 self.grads_device: Dict[torch.Tensor, torch.device] = {} + cpu_offload = self.gemini_manager.policy_name != 'cuda' # TODO: get param order and filter unused params for p in module.parameters(): assert isinstance(p, ColoParameter) @@ -232,10 +236,17 @@ class ZeroDDP(ColoDDP): fp32_data = p.data.float() fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) p.data = p.data.half() - dp_world_size = p.process_group.dp_world_size() - self.chunk_manager.append_tensor(p, 'fp16_param', dp_world_size, pin_memory) - self.chunk_manager.append_tensor(fp32_p, 'fp32_param', dp_world_size, pin_memory) + self.chunk_manager.append_tensor(tensor=p, + group_type='fp16_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.chunk_manager.append_tensor(tensor=fp32_p, + group_type='fp32_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) self.fp32_params.append(fp32_p) self.grads_device[p] = self.gemini_manager.default_device self.chunk_manager.close_all_groups() @@ -247,6 +258,10 @@ class ZeroDDP(ColoDDP): chunk_32 = self.chunk_manager.get_chunk(fp32_p) chunk_32.init_pair(chunk_16) + # keep gathered chunks are in CUDA + if chunk_16.keep_gathered: + self.grads_device[p] = get_current_device() + self._logger = get_dist_logger() def forward(self, *args, **kwargs): diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index ce6d20c0e..2dd0de560 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -1,14 +1,15 @@ -from .op_wrapper import _COLOSSAL_OPS -from .const import TensorType from copy import copy -import torch from functools import lru_cache +from typing import Callable, Optional, Set -from colossalai.tensor import ColoTensorSpec -from colossalai.tensor import ProcessGroup, ReplicaSpec +import torch + +from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec from colossalai.tensor.dist_spec_mgr import DistSpecManager -from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern -from typing import Optional, Set, Callable +from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec + +from .const import TensorType +from .op_wrapper import _COLOSSAL_OPS @lru_cache(None) @@ -57,25 +58,26 @@ class ColoTensor(torch.Tensor): >>> pg = ProcessGroup() >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()) >>> # The tensor passed in is a tensor after sharding but not a global tensor. - >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), - >>> dims=[0], + >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size), + >>> dims=[0], >>> num_partitions=[world_size]) >>> tensor_spec = ColoTensorSpec(pg, shard_spec) >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) - + Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()). """ + torch_minor = int(torch.__version__.split('.')[1]) def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor': """ The signature of the __new__ has to be consistent with the torch.Tensor. - + Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. spec (TensorSpec, optional): the tensor spec of initialization. - + Returns: ColoTensor: a ColoTensor wrappers the data. 
""" @@ -112,7 +114,7 @@ class ColoTensor(torch.Tensor): return self.process_group def set_process_group(self, pg: ProcessGroup): - """set_process_group + """set_process_group change the pg of the ColoTensor. Note that the valid use cases is limited. Only existing pg is DP and dist spec is REPLICaTE is valid. @@ -135,7 +137,7 @@ class ColoTensor(torch.Tensor): return self.process_group.tp_world_size() def set_dist_spec(self, dist_spec: _DistSpec): - """set_dist_spec + """set_dist_spec set dist spec and change the payloads. Args: @@ -166,6 +168,16 @@ class ColoTensor(torch.Tensor): if func in _COLOSSAL_OPS: func = _COLOSSAL_OPS[func] + if cls.torch_minor >= 12: + # in order to trigger pre-op hook in the forward of checkpoint module + # we have to capture the `backward` function + # and make sure that it does not in `torch._C.DisableTorchFunction()` context + if func is torch.Tensor.backward: + assert len(args) == 1 # only has 1 paramter + backward_tensor = torch.Tensor(args[0]) + tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} + return backward_tensor.backward(**tensor_kwargs) + with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in _get_my_nowrap_functions(): @@ -178,7 +190,7 @@ class ColoTensor(torch.Tensor): return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}' def _redistribute(self, dist_spec: _DistSpec) -> None: - """_redistribute + """_redistribute Note the function will not handle the logic of backward propagation! It is used during model tensor initializations as an internal function. @@ -191,12 +203,12 @@ class ColoTensor(torch.Tensor): self.dist_spec = dist_spec def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor': - """redistribute + """redistribute Redistribute the tensor among processes. The rule is like this: - + 1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the DP process group not changed. - + 2. If the pg is not not None and not equal to the current process group. First, convert the tensor as replicated among the TP process group. Second, reset the process group to the new pg. 
@@ -220,7 +232,7 @@ class ColoTensor(torch.Tensor): return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec)) def to_replicate_(self): - """to_replicate_ + """to_replicate_ an inline member function, converting dist spec of the tensor to REPLICATE """ diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py index aee8b2799..9a3101e38 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/zero/zero_optimizer.py @@ -1,15 +1,17 @@ +from enum import Enum +from typing import Dict, Set, Tuple + import torch import torch.distributed as dist -from enum import Enum -from torch.optim import Optimizer from torch.nn import Parameter -from colossalai.nn.parallel.data_parallel import ZeroDDP -from typing import Dict, Tuple, Set +from torch.optim import Optimizer + from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.utils import get_current_device, disposable -from colossalai.gemini.chunk import Chunk, ChunkManager +from colossalai.nn.parallel.data_parallel import ZeroDDP +from colossalai.utils import disposable, get_current_device class OptimState(Enum): @@ -219,6 +221,8 @@ class ZeroOptimizer(ColossalaiOptimizer): def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_info = local_chunk.tensors_info[local_param] + if local_chunk.keep_gathered: + return param_info.offset, param_info.end begin = max(0, param_info.offset - local_chunk.shard_begin) end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) return begin, end diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py index 57a49314f..3268b00a2 100644 --- a/tests/test_gemini/update/test_chunkv2.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -1,121 +1,124 @@ -import torch -import colossalai -import pytest -import torch.multiprocessing as mp -import torch.distributed as dist -from functools import partial -from colossalai.testing import rerun_if_address_is_in_use, parameterize -from colossalai.utils import free_port, get_current_device -from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.tensor import ColoParameter -from colossalai.gemini import TensorState -from colossalai.gemini.chunk import Chunk - - -def dist_sum(x): - temp = torch.tensor([x], device=get_current_device()) - dist.all_reduce(temp) - return temp.item() - - -def add_param(param_list, param_cp_list, *args, **kwargs): - param = ColoParameter(torch.randn(*args, **kwargs)) - param_list.append(param) - param_cp_list.append(param.clone()) - - -def check_euqal(param, param_cp): - if param.device != param_cp.device: - temp = param.data.to(param_cp.device) - else: - temp = param.data - return torch.equal(temp, param_cp.data) - - -@parameterize('init_device', [None, torch.device('cpu')]) -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) -def exam_chunk_basic(init_device, keep_gathered, pin_memory): - world_size = torch.distributed.get_world_size() - pg = ColoProcessGroup() - my_chunk = Chunk(chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - keep_gathered=keep_gathered, - pin_memory=pin_memory) - - param_list = [] - param_cp_list = [] - - add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') - add_param(param_list, param_cp_list, 4, 4) - 
add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') - add_param(param_list, param_cp_list, 1, 1, 5) - - for param in param_list: - my_chunk.append_tensor(param) - assert my_chunk.utilized_size == 597 - for param, param_cp in zip(param_list, param_cp_list): - check_euqal(param, param_cp) - my_chunk.close_chunk() - - if keep_gathered is False: - assert my_chunk.cpu_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cpu' - assert my_chunk.can_move - my_chunk.shard_move(get_current_device()) - else: - assert my_chunk.chunk_total.size(0) == 1024 - assert my_chunk.device_type == 'cuda' - assert not my_chunk.can_move - - assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size - flag = my_chunk.has_inf_or_nan - assert not flag, "has_inf_or_nan is {}".format(flag) - - my_chunk.access_chunk() - assert my_chunk.device_type == 'cuda' - for param, param_cp in zip(param_list, param_cp_list): - check_euqal(param, param_cp) - - assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 - my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) - assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3 - assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1 - assert not my_chunk.can_release - - for param in param_list: - my_chunk.tensor_trans_state(param, TensorState.COMPUTE) - my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) - - assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4 - assert my_chunk.can_reduce - my_chunk.reduce() - assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 - - if keep_gathered is False: - assert my_chunk.cuda_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cuda' - assert my_chunk.can_move - else: - assert my_chunk.chunk_total.size(0) == 1024 - assert my_chunk.device_type == 'cuda' - assert not my_chunk.can_move - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_chunk_basic() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2, 4]) -@rerun_if_address_is_in_use() -def test_chunk_function(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_chunk_function(4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini import TensorState +from colossalai.gemini.chunk import Chunk +from colossalai.tensor import ColoParameter +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device + + +def dist_sum(x): + temp = torch.tensor([x], device=get_current_device()) + dist.all_reduce(temp) + return temp.item() + + +def add_param(param_list, param_cp_list, *args, **kwargs): + param = ColoParameter(torch.randn(*args, **kwargs)) + param_list.append(param) + param_cp_list.append(param.clone()) + + +def check_euqal(param, param_cp): + if param.device != param_cp.device: + temp = param.data.to(param_cp.device) + else: + temp = param.data + return torch.equal(temp, param_cp.data) + + +@parameterize('init_device', [None, torch.device('cpu')]) +@parameterize('keep_gathered', [True, False]) +@parameterize('pin_memory', [True, False]) +def exam_chunk_basic(init_device, keep_gathered, pin_memory): + 
world_size = torch.distributed.get_world_size() + pg = ColoProcessGroup() + my_chunk = Chunk(chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory) + + param_list = [] + param_cp_list = [] + + add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 4, 4) + add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 1, 1, 5) + + for param in param_list: + my_chunk.append_tensor(param) + assert my_chunk.utilized_size == 597 + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + my_chunk.close_chunk() + + if keep_gathered is False: + assert my_chunk.cpu_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cpu' + assert my_chunk.can_move + my_chunk.shard_move(get_current_device()) + else: + assert my_chunk.chunk_total.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size + flag = my_chunk.has_inf_or_nan + assert not flag, "has_inf_or_nan is {}".format(flag) + + my_chunk.access_chunk() + assert my_chunk.device_type == 'cuda' + for param, param_cp in zip(param_list, param_cp_list): + check_euqal(param, param_cp) + + assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 + my_chunk.tensor_trans_state(param_list[0], TensorState.COMPUTE) + assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 3 + assert my_chunk.tensors_state_monitor[TensorState.COMPUTE] == 1 + assert not my_chunk.can_release + + for param in param_list: + my_chunk.tensor_trans_state(param, TensorState.COMPUTE) + my_chunk.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) + + assert my_chunk.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == 4 + assert my_chunk.can_reduce + my_chunk.reduce() + assert my_chunk.tensors_state_monitor[TensorState.HOLD] == 4 + + if keep_gathered is False: + assert my_chunk.cuda_shard.size(0) == 1024 // world_size + assert my_chunk.device_type == 'cuda' + assert my_chunk.can_move + else: + assert my_chunk.chunk_total.size(0) == 1024 + assert my_chunk.device_type == 'cuda' + assert not my_chunk.can_move + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_chunk_basic() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@rerun_if_address_is_in_use() +def test_chunk_function(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_function(4) diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index eb433f2c3..0a2db2a17 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -40,7 +40,8 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) -def exam_gpt_fwd_bwd(placement_policy): +@parameterize('keep_gather', [False, True]) +def exam_gpt_fwd_bwd(placement_policy, keep_gather): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable('gpt2') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -55,7 +56,7 @@ def exam_gpt_fwd_bwd(placement_policy): world_size = 
torch.distributed.get_world_size() config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False + config_dict[world_size]['keep_gathered'] = keep_gather chunk_manager = ChunkManager(config_dict) gemini_manager = GeminiManager(placement_policy, chunk_manager) model = ZeroDDP(model, gemini_manager, pin_memory=True) @@ -101,4 +102,4 @@ def test_gpt(world_size): if __name__ == '__main__': - test_gpt(1) + test_gpt(4) diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index 62822f133..a7c2fc2b2 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam from colossalai.nn.parallel import ZeroDDP @@ -98,10 +98,55 @@ def exam_gpt_fwd_bwd(placement_policy): check_param(model, torch_model) +@parameterize('placement_policy', ['cuda', 'cpu']) +def exam_tiny_example(placement_policy): + set_seed(42) + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + torch_model = model_builder().cuda() + for torch_p, p in zip(torch_model.parameters(), model.parameters()): + torch_p.data.copy_(p.data) + + chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager, pin_memory=True) + + optimizer = HybridAdam(model.parameters(), lr=1e-3) + zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) + + amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) + torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) + + model.eval() + torch_model.eval() + + set_seed(dist.get_rank() * 3 + 128) + for i, (input_ids, attn_mask) in enumerate(train_dataloader): + if i > 2: + break + + zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask) + torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) + assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2) + # debug_print([0], zero_logits, torch_logits) + + zero_optim.step() + torch_optim.step() + + check_param(model, torch_model) + + def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') exam_gpt_fwd_bwd() + exam_tiny_example() @pytest.mark.dist @@ -113,4 +158,4 @@ def test_gpt(world_size): if __name__ == '__main__': - test_gpt(1) + test_gpt(2)
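
One behavioral change above that is easy to miss: `get_range_pair` in zero_optimizer.py now returns the parameter's full `[offset, end)` range when its chunk is kept gathered, instead of clipping it to the local shard window. The sketch below is a stand-alone illustration of that logic with made-up sizes; it is not the library code (the real function reads these fields from `Chunk` and `TensorInfo`).

def get_range_pair(offset, end, shard_begin, shard_size, keep_gathered):
    # keep_gathered chunk: every rank holds the whole chunk, so the full range is local
    if keep_gathered:
        return offset, end
    # sharded chunk: clip the parameter's range to this rank's shard window
    begin = max(0, offset - shard_begin)
    stop = min(shard_size, end - shard_begin)
    return begin, stop

# e.g. a parameter occupying [300, 800) of a 1024-element chunk,
# on the rank whose shard covers [512, 1024):
print(get_range_pair(300, 800, 512, 512, keep_gathered=False))    # (0, 288)
print(get_range_pair(300, 800, 512, 512, keep_gathered=True))     # (300, 800)

Returning the full range for kept-gathered chunks is consistent with the new handling in ZeroDDP above, where the gradients of such chunks are pinned to the current CUDA device rather than the Gemini default device.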