diff --git a/colossalai/elixir/README.md b/colossalai/elixir/README.md
new file mode 100644
index 000000000..8adce38dc
--- /dev/null
+++ b/colossalai/elixir/README.md
@@ -0,0 +1,71 @@
+# Elixir (Gemini2.0)
+Elixir, also known as Gemini, is a technology designed to facilitate the training of large models on a small GPU cluster.
+Its goal is to eliminate data redundancy and leverage CPU memory to accommodate very large models.
+In addition, Elixir automatically profiles each training step prior to execution and selects the optimal redundancy ratio and placement device for each parameter.
+This repository is used to benchmark the performance of Elixir.
+Elixir will be integrated into ColossalAI for ease of use.
+
+## Environment
+
+This version is a beta release, so its running environment is somewhat restrictive.
+We list only the environment we used, since compatibility with other configurations has not been tested yet.
+We used CUDA version `11.6` and PyTorch version `1.13.1+cu11.6`.
+
+## Examples
+
+Here is a simple example of wrapping your model and optimizer for [fine-tuning](https://github.com/hpcaitech/Elixir/tree/main/example/fine-tune).
+
+```python
+from elixir.search import minimum_waste_search
+from elixir.wrapper import ElixirModule, ElixirOptimizer
+
+model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-8)
+
+sr = minimum_waste_search(model, world_size)
+model = ElixirModule(model, sr, world_group)
+optimizer = ElixirOptimizer(model, optimizer)
+```
+
+Here is an advanced example tuned for performance, which is used in our [benchmark](https://github.com/hpcaitech/Elixir/blob/main/example/common/elx.py).
+
+```python
+import torch
+import torch.distributed as dist
+from colossalai.nn.optimizer import HybridAdam
+from elixir.wrapper import ElixirModule, ElixirOptimizer
+
+# get the world communication group
+global_group = dist.GroupMember.WORLD
+# get the communication world size
+global_size = dist.get_world_size()
+
+# initialize the model on CPU
+model = get_model(model_name)
+# HybridAdam allows some parameters to be updated on the CPU and the rest on the GPU
+optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+sr = optimal_search(
+    model,
+    global_size,
+    unified_dtype=torch.float16,    # enable for FP16 training
+    overlap=True,    # enable for overlapping communications
+    verbose=True,    # print detailed processing information
+    inp=data,    # provide an example input in dictionary format
+    step_fn=train_step    # provide an example step function
+)
+model = ElixirModule(
+    model,
+    sr,
+    global_group,
+    prefetch=True,    # prefetch chunks to overlap communications
+    dtype=torch.float16,    # use AMP
+    use_fused_kernels=True    # enable fused kernels in Apex
+)
+optimizer = ElixirOptimizer(
+    model,
+    optimizer,
+    initial_scale=64,    # loss scale used in AMP
+    init_step=True    # enable for training stability
+)
+```
diff --git a/colossalai/elixir/__init__.py b/colossalai/elixir/__init__.py
new file mode 100644
index 000000000..b7fd76a5d
--- /dev/null
+++ b/colossalai/elixir/__init__.py
@@ -0,0 +1 @@
+from .wrapper import ElixirModule, ElixirOptimizer
diff --git a/colossalai/elixir/chunk/__init__.py b/colossalai/elixir/chunk/__init__.py
new file mode 100644
index 000000000..72d17dbc1
--- /dev/null
+++ b/colossalai/elixir/chunk/__init__.py
@@ -0,0 +1,2 @@
+from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
+from
.fetcher import ChunkFetcher diff --git a/colossalai/elixir/chunk/core/__init__.py b/colossalai/elixir/chunk/core/__init__.py new file mode 100644 index 000000000..468d5428e --- /dev/null +++ b/colossalai/elixir/chunk/core/__init__.py @@ -0,0 +1,4 @@ +from .chunk import Chunk +from .group import ChunkGroup +from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock +from .states import TensorState diff --git a/colossalai/elixir/chunk/core/chunk.py b/colossalai/elixir/chunk/core/chunk.py new file mode 100644 index 000000000..cbf2fcd5e --- /dev/null +++ b/colossalai/elixir/chunk/core/chunk.py @@ -0,0 +1,567 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.tensor import FakeTensor + +from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock +from .states import TensorState, ts_update_sanity_check + + +class ChunkFullError(Exception): + pass + + +@dataclass +class TensorInfo: + state: TensorState + fake_data: FakeTensor + offset: int + end: int + + +class Chunk: + """Chunk is a type of data structure to store tensors. + It allows us to store a sequence of tensors into one continuous memory block. + Moreover, Chunk manages the storage of tensors in a distributed way. + Normally, a chunk is scattered across its process group. + When a tensor in this chunk should be used later, the chunk can be gathered by access_chunk. + When the training is done, the chunk can be scattered by reduce_chunk. + + args: + rcache: the memory pool to store replicated chunks + chunk_size: the size of the chunk + chunk_dtype: the dtype of the chunk + process_group: the torch communication group of the chunk + temp_device: the device to store the temporary chunk when initializing + shard_device: the device to store the shard of the scattered chunk + rcache_fused: whether this chunk is fused in rcache without eviction + cpu_pin_memory: whether this chunk use cpu pin memory for its shard + """ + total_count = 0 + + def __init__( + self, + rcache: MemoryPool, + chunk_size: int, + chunk_dtype: torch.dtype, + process_group: ProcessGroup, + temp_device: Optional[torch.device] = None, + shard_device: Optional[torch.device] = None, + rcache_fused: bool = False, # whether this chunk is used in ZeRO2 + cpu_pin_memory: bool = False # whether this chunk has a permanent copy in cpu + ) -> None: + + self.chunk_id: int = Chunk.total_count + Chunk.total_count += 1 + # set replicated cache pool + self.rcache: MemoryPool = rcache + + self.chunk_size: int = chunk_size + self.chunk_dtype: torch.dtype = chunk_dtype + self.utilized_size: int = 0 + + self.torch_pg: ProcessGroup = process_group + self.pg_size: int = dist.get_world_size(self.torch_pg) + self.pg_rank: int = dist.get_rank(self.torch_pg) + + # the chunk size should be divisible by the dp degree + assert chunk_size % self.pg_size == 0 + + self.shard_size: int = chunk_size // self.pg_size + self.shard_begin: int = self.shard_size * self.pg_rank + self.shard_end: int = self.shard_begin + self.shard_size + self.valid_end: int = self.shard_size + 1 # set to an illegal number + + # notice: release blocks reserved by Pytorch + torch.cuda.empty_cache() + # rcache block, the global replicated chunk in R cache + self.rcb: Optional[TensorBlock] = None + self.rcache_fused: bool = rcache_fused + self._my_block = None + self.is_replica: bool = True + # 
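To make the sharding arithmetic in the constructor concrete, here is a small standalone illustration (not part of the Elixir sources) of how a 1024-element chunk is laid out across four ranks:

```python
# Standalone illustration of the shard geometry computed in Chunk.__init__.
chunk_size, pg_size = 1024, 4
assert chunk_size % pg_size == 0      # mirrors the divisibility assert in the constructor
shard_size = chunk_size // pg_size    # 256 elements owned by each rank

for pg_rank in range(pg_size):
    shard_begin = shard_size * pg_rank
    shard_end = shard_begin + shard_size
    print(f'rank {pg_rank} keeps replica[{shard_begin}:{shard_end}] as its shard')
```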
allocate a private block for fused chunks + if self.rcache_fused: + self._my_block = rcache.get_private_block(chunk_size, chunk_dtype) + + temp_device: torch.device = temp_device or gpu_device() + # chunk_temp is a global chunk, which only exists during building the chunks. + # keep all elements to zero + self.chunk_temp: Optional[torch.Tensor] = None + if rcache_fused: + self.chunk_temp = self._my_block.payload + torch.zero_(self.chunk_temp) + else: + self.chunk_temp = torch.zeros(chunk_size, dtype=chunk_dtype, device=temp_device) + + # configure the init device of the shard + # no-offload default: fp16, fp32 -> CUDA + # offload default: fp16, fp32 -> CPU + shard_device: torch.device = shard_device or torch.device('cpu') + pin_flag: bool = cpu_pin_memory and shard_device.type == 'cpu' + # chunk.shard is a local chunk + # it is desinged to exist permanently + self.shard: torch.Tensor = torch.empty(self.shard_size, + dtype=chunk_dtype, + device=shard_device, + pin_memory=pin_flag) + + # calculate the memory occupation of the chunk and the shard + self.chunk_memo: int = self.chunk_size * self.chunk_temp.element_size() + self.shard_memo: int = self.chunk_memo // self.pg_size + + # each tensor is associated with a TensorInfo to track its meta info + # (state, shape, offset, end) + self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} + # the total number of tensors in the chunk + self.num_tensors: int = 0 + + # Record the number of tensors in different states + self.tensor_state_cnter: Dict[TensorState, int] = dict() + for state in TensorState: + self.tensor_state_cnter[state] = 0 + + # we introduce the paired chunk here + # it refers to another chunk having the same parameters + # but with different dtype(such as fp16_chunk.paired_chunk -> fp32_chunk + self.paired_chunk = None + # if this chunk is synchronized with the optimizer, the flag is True + self.optim_sync_flag = True + + # whether to record l2 norm for the gradient clipping calculation + self.l2_norm_flag = False + self.l2_norm = None + # whether it overflows after the reduction + self.overflow = False + + @property + def prepared_block(self): + return self._my_block + + @property + def is_init(self): + return self.chunk_temp is not None + + @property + def in_rcache(self): + return self.rcb is not None + + @property + def shard_device(self): + return self.shard.device + + @property + def memory_usage(self) -> Dict[str, int]: + cuda_memory = 0 + cpu_memory = 0 + + # this chunk is not closed + if self.is_init: + if self.chunk_temp.device.type == 'cuda': + cuda_memory += self.chunk_memo + else: + cpu_memory += self.chunk_memo + + # this chunk is on the rcache + if self.in_rcache: + cuda_memory += self.rcb.memo_occ + + # calculate the occupation of the chunk shard + if self.shard_device.type == 'cuda': + cuda_memory += self.shard_memo + elif self.shard_device.type == 'cpu': + cpu_memory += self.shard_memo + else: + raise NotImplementedError + + return dict(cuda=cuda_memory, cpu=cpu_memory) + + @property + def payload(self) -> torch.Tensor: + if self.is_init: + return self.chunk_temp + + if self.in_rcache: + return self.rcb.payload + else: + return self.shard + + @property + def shard_move_check(self) -> bool: + return not self.in_rcache + + def _not_compute_number(self): + total = 0 + state_list = [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE] + for state in state_list: + total += self.tensor_state_cnter[state] + return total + + @property + def scatter_check(self) -> bool: + if self.rcache_fused: + return 
False + return self._not_compute_number() == self.num_tensors + + @property + def reduce_check(self): + return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors + + def set_overflow_flag(self, valid_tensor: torch.Tensor) -> None: + assert not self.overflow + self.overflow = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() + + def set_l2_norm(self, valid_tensor: torch.Tensor) -> None: + assert self.l2_norm is None, 'you are calculating the l2 norm twice' + chunk_l2_norm = valid_tensor.data.float().norm(2) + self.l2_norm = chunk_l2_norm.item()**2 + + def append_tensor(self, tensor: torch.Tensor): + # sanity check + assert self.is_init + assert tensor.dtype == self.chunk_dtype + + new_utilized_size = self.utilized_size + tensor.numel() + # raise exception when the chunk size is exceeded + if new_utilized_size > self.chunk_size: + raise ChunkFullError + + self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + fake_data = FakeTensor(tensor.data) + + # record all the information about the tensor + self.num_tensors += 1 + tensor_state = TensorState.HOLD + self.tensor_state_cnter[tensor_state] += 1 + + self.tensors_info[tensor] = TensorInfo(state=tensor_state, + fake_data=fake_data, + offset=self.utilized_size, + end=new_utilized_size) + + self.utilized_size = new_utilized_size + + def close_chunk(self): + # sanity check + assert self.is_init + + # calculate the valid end for each shard + if self.utilized_size <= self.shard_begin: + self.valid_end = 0 + elif self.utilized_size < self.shard_end: + self.valid_end = self.utilized_size - self.shard_begin + + self.__remove_tensors_ptr() + self.__update_shard(self.chunk_temp, self.shard) + self.is_replica = False + self.chunk_temp = None + + def replicate(self): + assert not self.is_replica + + self.is_replica = True + this_shard = self.shard if self.optim_sync_flag else self.__paired_shard() + self.__update_replica(self.rcb.payload, this_shard) + self.__update_tensors_ptr() + + def scatter(self): + assert not self.rcache_fused + assert self.is_replica + + self.__remove_tensors_ptr() + if not self.optim_sync_flag: + self.__update_shard(self.rcb.payload, self.shard) + self.optim_sync_flag = True + self.is_replica = False + + def reduce(self, always_fp32: bool = False): + assert self.is_replica + + self.__remove_tensors_ptr() + + if self.pg_size > 1: + cast_to_fp32 = False + if always_fp32 and self.chunk_dtype != torch.float: + cast_to_fp32 = True + # cast the payload to fp32 + reduce_buffer = self.rcb.payload.to(dtype=torch.float) + else: + # otherwise, use the same payload + reduce_buffer = self.rcb.payload + + # divide the reduce buffer by the size of the process group + reduce_buffer /= self.pg_size + # try to use inplace reduce scatter + # notice: pytorch does not allow true inplace reduce scatter + # because pytorch will allocate a continuous memory space for collective communications + shard_buffer = reduce_buffer[self.shard_begin:self.shard_end] + dist.reduce_scatter_tensor(shard_buffer, reduce_buffer, group=self.torch_pg) + + # the result should be moved to payload for norm calculating + if cast_to_fp32: + calc_buffer = self.rcb.payload[self.shard_begin:self.shard_end] + calc_buffer.copy_(shard_buffer) + else: + # if process group size equals to 1, do not communicate + reduce_buffer = self.rcb.payload + + self.__update_shard(reduce_buffer, self.shard) + + self.is_replica = 
False + + def access_chunk(self, block: Optional[TensorBlock] = None): + # sanity check + assert not self.is_init + assert not self.is_replica + + if self.rcache_fused: + assert block is None + self.rcb = self._my_block + else: + assert block in self.rcache.public_used_blocks + assert self.rcb is None + self.rcb = block + + self.replicate() + + def release_chunk(self) -> TensorBlock: + # sanity check + assert not self.is_init + assert self.is_replica + + if self.rcache_fused: + raise RuntimeError + + self.scatter() + block = self.rcb + self.rcb = None + return block + + def update_extra_reduce_info(self, block: Optional[TensorBlock]): + if self.rcache_fused: + assert block is None + block = self._my_block + else: + assert block is not None + + buffer = block.payload[self.shard_begin:self.shard_end] + valid_tensor = buffer[:self.valid_end] + self.set_overflow_flag(valid_tensor) + if self.l2_norm_flag: + self.set_l2_norm(valid_tensor) + + def reduce_chunk(self, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]: + """Reduce scatter all the gradients. It's an operation done in CUDA. + """ + # sanity check + assert not self.is_init + assert self.is_replica + + self.reduce(always_fp32=always_fp32) + self.__update_tensors_state(TensorState.HOLD) + + # reset the rcb pointer + block = self.rcb + self.rcb = None + if self.rcache_fused: + block = None + + if sync: + self.update_extra_reduce_info(block) + + return block + + def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + prev_state = self.tensors_info[tensor].state + if prev_state == tensor_state: + return + if ts_update_sanity_check(prev_state, tensor_state): + self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) + + def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + # sanity check + assert self.is_replica + + info = self.tensors_info[tensor] + payload = self.rcb.payload + payload[info.offset:info.end].copy_(data_slice.data.flatten()) + tensor.data = payload[info.offset:info.end].view(tensor.shape) + + def init_pair(self, friend_chunk: 'Chunk') -> None: + if self.paired_chunk is None and friend_chunk.paired_chunk is None: + self.paired_chunk = friend_chunk + friend_chunk.paired_chunk = self + else: + assert self.paired_chunk is friend_chunk + assert friend_chunk.paired_chunk is self + + def optim_update(self) -> None: + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. 
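A sketch of how this pairing is intended to be used; it assumes `param_chunk` (fp16) and `optim_chunk` (fp32) were already built over the same parameters, so it is illustrative pseudo-usage rather than a self-contained snippet:

```python
# Illustrative pseudo-usage (assumes both chunks already exist and are scattered).
param_chunk.init_pair(optim_chunk)    # link the fp16 chunk with its fp32 twin

# ... the optimizer updates optim_chunk.shard in place ...

param_chunk.optim_update()            # copy the fresh fp32 shard into the fp16 shard
# On CUDA shards the copy happens immediately and optim_sync_flag stays True;
# on CPU shards it is deferred and picked up later through __paired_shard()
# the next time the chunk is replicated.
```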
+ """ + # sanity check + assert self.paired_chunk is not None + friend_chunk: Chunk = self.paired_chunk + assert not friend_chunk.is_replica + # gradient and optimizer should be on the same device + assert self.shard_device.type == friend_chunk.shard_device.type + + if self.shard_device.type == 'cuda': + self.shard.copy_(friend_chunk.shard) + self.optim_sync_flag = True + elif self.shard_device.type == 'cpu': + # optim_sync_flag is set to False + # see shard_move function for more details + self.optim_sync_flag = False + else: + raise NotImplementedError + + def get_tensors(self) -> List[torch.Tensor]: + return list(self.tensors_info.keys()) + + def get_cpu_copy(self, only_rank_0: bool = False) -> List[torch.Tensor]: + assert not self.is_init + + if self.is_replica: + # use the payload directly when being replica + temp_buffer = self.rcb.payload + else: + # otherwise, create a temporary buffer + temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device()) + # cheat the assertion in __update_replica + self.is_replica = True + self.__update_replica(temp_buffer, self.shard) + self.is_replica = False + + cpu_copys = [None] * self.num_tensors + if not only_rank_0 or self.pg_rank == 0: + for i, (t, info) in enumerate(self.tensors_info.items()): + t_copy = temp_buffer[info.offset:info.end].view(t.shape).cpu() + cpu_copys[i] = t_copy + # synchronize + dist.barrier() + return cpu_copys + + def load_tensors(self, tensor_list: List[Optional[torch.Tensor]], only_rank_0: bool = False) -> bool: + assert not self.is_replica + assert not self.is_init + temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device()) + # cheat the assertion in __update_replica + self.is_replica = True + self.__update_replica(temp_buffer, self.shard) + self.is_replica = False + + if not only_rank_0 or self.pg_rank == 0: + for (_, c_info), load_tensor in zip(self.tensors_info.items(), tensor_list): + if load_tensor is None: + continue + temp_buffer[c_info.offset:c_info.end].copy_(load_tensor.data.flatten()) + + # synchronize + dist.barrier() + + if only_rank_0: + dist.broadcast(temp_buffer, src=0, group=self.torch_pg) + + # cheat the assertion in __update_shard + self.is_replica = True + self.__update_shard(temp_buffer, self.shard) + self.is_replica = False + + def __update_replica(self, replica: torch.Tensor, shard: torch.Tensor): + assert self.is_replica + assert replica.numel() == self.chunk_size + assert shard.numel() == self.shard_size + + buffer = replica[self.shard_begin:self.shard_end] + buffer.copy_(shard) + dist.all_gather_into_tensor(replica, buffer, group=self.torch_pg) + + def __update_shard(self, replica: torch.Tensor, shard: torch.Tensor): + assert self.is_replica + assert replica.numel() == self.chunk_size + assert shard.numel() == self.shard_size + + shard.copy_(replica[self.shard_begin:self.shard_end]) + + def __paired_shard(self): + assert self.paired_chunk is not None, 'chunks should be paired before training' + optim_chunk: Chunk = self.paired_chunk + assert self.chunk_size == optim_chunk.chunk_size + + # only be called when optimizer state is in CPU memory + # the grad and param should be in the same device + assert self.shard_device.type == 'cpu' + return optim_chunk.shard.to(gpu_device()) + + def __remove_tensors_ptr(self) -> None: + # sanity check + # each tensor should point to its fake data before scatter + assert self.is_replica + for tensor, info in self.tensors_info.items(): + tensor.data = info.fake_data + + def __update_tensors_ptr(self) -> 
None: + # sanity check + # the chunk should be replicated to get the correct pointer + assert self.is_replica + payload = self.rcb.payload + for tensor, info in self.tensors_info.items(): + tensor.data = payload[info.offset:info.end].view(tensor.shape) + + def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): + self.tensor_state_cnter[tensor_info.state] -= 1 + tensor_info.state = next_state + self.tensor_state_cnter[tensor_info.state] += 1 + + def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None): + for tensor_info in self.tensors_info.values(): + if prev_state is None or tensor_info.state == prev_state: + self.__update_one_tensor_info(tensor_info, next_state) + + def __hash__(self) -> int: + return self.chunk_id + + def __lt__(self, other: object) -> bool: + return self.chunk_id < other.chunk_id + + def __eq__(self, other: object) -> bool: + return self.chunk_id == other.chunk_id + + def __repr__(self, detailed: bool = True): + if self.is_init: + state = 'initialization' + elif self.in_rcache: + state = 'replicated' + else: + state = 'scattered' + + output = [ + f'Chunk {self.chunk_id} details: state -> {state}\n', + f' length: {self.chunk_size}, dtype: {self.chunk_dtype}, group_size: {self.pg_size}, tensors: {self.num_tensors}\n' + f' utilized size: {self.utilized_size}, utilized percentage: {100 * (self.utilized_size / self.chunk_size):.0f}%\n' + ] + + memory_info = self.memory_usage + output.append(' memory usage: (cuda -> {}, cpu -> {})\n'.format(memory_info['cuda'], memory_info['cpu'])) + + def print_tensor(name, tensor, prefix=''): + output.append(f'{prefix}{name}: (shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device})\n') + + if self.is_init: + print_tensor(name='temp', tensor=self.chunk_temp, prefix=' ') + if self.in_rcache: + print_tensor(name='block', tensor=self.rcb.payload, prefix=' ') + if self.shard is not None: + print_tensor(name='shard', tensor=self.shard, prefix=' ') + + if detailed: + output.append(' tensor state monitor:\n') + for st in TensorState: + output.append(' # of {}: {}\n'.format(st, self.tensor_state_cnter[st])) + + return ''.join(output) diff --git a/colossalai/elixir/chunk/core/group.py b/colossalai/elixir/chunk/core/group.py new file mode 100644 index 000000000..495040e51 --- /dev/null +++ b/colossalai/elixir/chunk/core/group.py @@ -0,0 +1,180 @@ +from typing import Dict, List, Optional, Set + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from .chunk import Chunk +from .memory_pool import MemoryPool, TensorBlock +from .states import TensorState + + +class ChunkGroup(object): + """ChunkGroup manages a group of chunks and their memory pool. + Commonly, one model has one chunk group. + It supports chunk allocation, chunk access, and chunk release. + ChunkGroup is responsible for the memory management before its APIs. + + args: + rcache: A memory pool to instantiate chunks. 
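A runnable sketch of the intended workflow, assuming `torch.distributed` is already initialized (for example via `colossalai.launch`) with a world size that divides the chunk size, and that a CUDA device is available:

```python
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import ChunkGroup, MemoryPool

pg = dist.GroupMember.WORLD
block_size = 1024

# rCache: two public fp16 blocks that replicated chunks can be gathered into
mp = MemoryPool(device_type='cuda')
mp.allocate(public_dtype=torch.float16,
            public_block_size=block_size,
            public_block_number=2)

group = ChunkGroup(rcache=mp)

# pack two fp16 parameters into one chunk of exactly one public-block size
params = [torch.nn.Parameter(torch.randn(512).to(torch.float16)) for _ in range(2)]
chunk = group.allocate_chunk(params, chunk_size=block_size, chunk_dtype=torch.float16,
                             process_group=pg)

group.access_chunk(chunk)     # gather the chunk into a free public block
# ... compute with the parameters here ...
group.release_chunk(chunk)    # scatter it back and free the block for other chunks
```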
+ """ + + def __init__(self, rcache: MemoryPool) -> None: + super().__init__() + self.rcache = rcache + self.fused_chunks: Set[Chunk] = set() + self.float_chunks: Set[Chunk] = set() + self.ten_to_chunk: Dict[torch.Tensor, Chunk] = dict() + + self.accessed_fused_chunks: Set[Chunk] = set() + self.accessed_float_chunks: Set[Chunk] = set() + + def __add_to_accset(self, chunk: Chunk): + if chunk.rcache_fused: + self.accessed_fused_chunks.add(chunk) + else: + self.accessed_float_chunks.add(chunk) + + def __remove_from_accset(self, chunk: Chunk): + if chunk.rcache_fused: + self.accessed_fused_chunks.remove(chunk) + else: + self.accessed_float_chunks.remove(chunk) + + def __check_new_float_chunk(self, size: int, dtype: torch.dtype): + # if the public space is 0, there is no access operations + if self.rcache.public_space == 0: + return + # otherwise, check its size and dtype + assert size == self.rcache.public_block_size + assert dtype == self.rcache.public_dtype + + def inside_check(self, chunk: Chunk) -> None: + """Check whether the chunk is in this ChunkGroup""" + if chunk.rcache_fused: + assert chunk in self.fused_chunks + else: + assert chunk in self.float_chunks + + def is_accessed(self, chunk: Chunk) -> bool: + """Chech whether the chunk is accessed.""" + # sanity check + self.inside_check(chunk) + + if chunk.rcache_fused: + return (chunk in self.accessed_fused_chunks) + else: + return (chunk in self.accessed_float_chunks) + + def open_chunk(self, + chunk_size: int, + chunk_dtype: torch.dtype, + process_group: ProcessGroup, + chunk_config: Optional[Dict] = None) -> Chunk: + """Open a chunk to store parameters.""" + if chunk_config is None: + chunk_config = {} + + chunk = Chunk(rcache=self.rcache, + chunk_size=chunk_size, + chunk_dtype=chunk_dtype, + process_group=process_group, + **chunk_config) + # sanity check + if not chunk.rcache_fused: + self.__check_new_float_chunk(chunk_size, chunk_dtype) + + return chunk + + def close_chunk(self, chunk: Chunk) -> bool: + """Close the chunk during the allocation.""" + chunk.close_chunk() + # add the new chunk to the set of allocated chunks + if chunk.rcache_fused: + self.fused_chunks.add(chunk) + else: + self.float_chunks.add(chunk) + # add the new chunk to the mapping + for t in chunk.get_tensors(): + assert t not in self.ten_to_chunk + self.ten_to_chunk[t] = chunk + return True + + def allocate_chunk(self, + tensor_list: List[torch.Tensor], + chunk_size: int, + chunk_dtype: torch.dtype, + process_group: ProcessGroup, + chunk_config: Optional[Dict] = None) -> Chunk: + """Allocate a chunk for a list of parameters.""" + chunk = self.open_chunk(chunk_size, chunk_dtype, process_group, chunk_config) + # append tensors + for t in tensor_list: + chunk.append_tensor(t) + self.close_chunk(chunk) + + return chunk + + def tensors_to_chunks(self, tensor_list: List[torch.Tensor]) -> List[Chunk]: + """Get the chunks of a gevien list of tensors.""" + chunk_list = list() + for tensor in tensor_list: + chunk = self.ten_to_chunk.get(tensor) + if chunk not in chunk_list: + chunk_list.append(chunk) + chunk_list.sort(key=lambda c: c.chunk_id) + return chunk_list + + def rcache_enough_check(self, chunk: Chunk) -> bool: + """Check whether the rcache has enough blocks to store the gathered chunk.""" + if chunk.rcache_fused: + return True + return self.rcache.public_free_cnt > 0 + + def access_chunk(self, chunk: Chunk) -> bool: + """Access a chunk into rCache.""" + self.inside_check(chunk) + # if this chunk is accessed already, return False + if self.is_accessed(chunk): + 
return False + + if chunk.rcache_fused: + block = None + else: + block = self.rcache.get_public_block() + chunk.access_chunk(block) + self.__add_to_accset(chunk) + return True + + def release_chunk(self, chunk: Chunk) -> bool: + """Release a chunk from rCache.""" + self.inside_check(chunk) + assert self.is_accessed(chunk) + assert chunk.scatter_check + + block = chunk.release_chunk() + if block: + self.rcache.free_public_block(block) + self.__remove_from_accset(chunk) + return True + + def reduce_chunk(self, chunk: Chunk, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]: + """Reduce and scatter a gradient chunk from rCache.""" + self.inside_check(chunk) + assert self.is_accessed(chunk) + assert chunk.reduce_check + + block = chunk.reduce_chunk(always_fp32=always_fp32, sync=sync) + if block and sync: + # if synchronized, free the block into rcache + self.rcache.free_public_block(block) + block = None + + self.__remove_from_accset(chunk) + + return block + + def tensor_trans_state(self, tensor: torch.Tensor, state: TensorState): + """Transform the state of a tensor.""" + chunk = self.ten_to_chunk.get(tensor) + chunk.tensor_trans_state(tensor, state) diff --git a/colossalai/elixir/chunk/core/memory_pool.py b/colossalai/elixir/chunk/core/memory_pool.py new file mode 100644 index 000000000..e73fc65a6 --- /dev/null +++ b/colossalai/elixir/chunk/core/memory_pool.py @@ -0,0 +1,172 @@ +from abc import ABC +from collections import defaultdict +from typing import Iterable, NamedTuple + +import torch +from torch.autograd.profiler_util import _format_memory + + +class BlockRequire(NamedTuple): + numel: int + dtype: torch.dtype + + +class TensorBlock(ABC): + """TensorBlock is the memory unit of memory pool. + It is a continuous memory block used to store tensors. + Each chunk needs a corresponding TensorBlock to store its data during training. + + args: + numel: the number of elements in the block + dtype: the data type of the block + device_type: the device type of the block + """ + total_count: int = 0 + + def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: + self.block_id = TensorBlock.total_count + TensorBlock.total_count += 1 + + self.device_type = device_type + self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type) + self.memo_occ: int = self.payload.numel() * self.payload.element_size() + + @property + def numel(self): + return self.payload.numel() + + @property + def dtype(self): + return self.payload.dtype + + @property + def device(self): + return self.payload.device + + def __hash__(self) -> int: + return self.block_id + + def __eq__(self, other: object) -> bool: + return self.block_id == other.block_id + + def __repr__(self) -> str: + return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})' + + +class PublicBlock(TensorBlock): + """Public blocks have the same length. + Chunks of the same length can share the same public block. + """ + + def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: + super().__init__(numel, dtype, device_type) + self.block_type = 'public' + + def __repr__(self) -> str: + return f'PublicBlock{super().__repr__()}' + + +class PrivateBlock(TensorBlock): + """Private blocks may have different lengths. + Each private chunk should use its own private block. 
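Both block types are handed out by the `MemoryPool` defined below; a runnable sketch of its bookkeeping (a CPU pool is used here so the snippet runs anywhere, although Elixir normally builds the rCache on CUDA):

```python
import torch

from colossalai.elixir.chunk import BlockRequire, MemoryPool

mp = MemoryPool(device_type='cpu')
mp.allocate(public_dtype=torch.float16,
            public_block_size=1024,
            public_block_number=2,
            private_block_list=[BlockRequire(numel=300, dtype=torch.float32)])

blk = mp.get_public_block()        # borrow a shared block for a gathered chunk
print(mp.public_free_cnt)          # 1
mp.free_public_block(blk)          # hand it back once the chunk is scattered
print(mp.public_free_cnt)          # 2

priv = mp.get_private_block(300, torch.float32)    # block reserved for a fused chunk
print(priv)                        # PrivateBlock(id=..., numel=300, ...)
```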
+ """ + + def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: + super().__init__(numel, dtype, device_type) + self.block_type = 'private' + + def __repr__(self) -> str: + return f'PrivateBlock{super().__repr__()}' + + +class MemoryPool(object): + """A memory pool consists of public blocks and private blocks. + rCache uses memory pool to manage memory bolcks. + Users should allocate memory blocks before using it. + + args: + device_type: the device type of the memory pool + """ + + def __init__(self, device_type: str) -> None: + self.device_type: str = device_type + + self.public_space: int = 0 + self.public_block_size: int = 0 + self.public_dtype: torch.dtype = None + + self.public_free_blocks: list = None + self.public_used_blocks: set = None + + self.public_free_cnt: int = 0 + self.public_used_cnt: int = 0 + + self.private_space: int = 0 + self.private_blocks: list = None + self.private_lookup_dict: dict[BlockRequire, list] = None + + self.__allocate_flag = False + + def allocate(self, + public_dtype: torch.dtype = torch.float, + public_block_size: int = 1024, + public_block_number: int = 0, + private_block_list: Iterable[BlockRequire] = ()): + assert self.__allocate_flag is False + assert public_block_number >= 0 + + self.public_free_blocks = list() + self.public_used_blocks = set() + for _ in range(public_block_number): + block = PublicBlock(public_block_size, public_dtype, self.device_type) + self.public_free_blocks.append(block) + + if public_block_number <= 0: + self.public_space = 0 + else: + self.public_space = self.public_free_blocks[0].memo_occ * public_block_number + self.public_block_size = public_block_size + self.public_dtype = public_dtype + + self.public_free_cnt = public_block_number + self.public_used_cnt = 0 + + self.private_space = 0 + self.private_blocks = list() + self.private_lookup_dict = defaultdict(list) + + for require in private_block_list: + block = PrivateBlock(require.numel, require.dtype, self.device_type) + self.private_space += block.memo_occ + self.private_blocks.append(block) + self.private_lookup_dict[require].append(block) + + self.__allocate_flag = True + + def __repr__(self) -> str: + return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})' + + def get_private_block(self, numel: int, dtype: torch.dtype): + block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype)) + return block_list.pop() + + def get_public_block(self): + self.public_free_cnt -= 1 + self.public_used_cnt += 1 + + block = self.public_free_blocks.pop() + self.public_used_blocks.add(block) + + return block + + def free_public_block(self, block: TensorBlock): + assert isinstance(block, PublicBlock) + assert block in self.public_used_blocks + + self.public_free_cnt += 1 + self.public_used_cnt -= 1 + + self.public_used_blocks.remove(block) + self.public_free_blocks.append(block) + + return block diff --git a/colossalai/elixir/chunk/core/states.py b/colossalai/elixir/chunk/core/states.py new file mode 100644 index 000000000..90d4c9260 --- /dev/null +++ b/colossalai/elixir/chunk/core/states.py @@ -0,0 +1,25 @@ +from enum import Enum + + +class TensorState(Enum): + FREE = 0 + COMPUTE = 1 + HOLD = 2 + HOLD_AFTER_BWD = 3 + READY_FOR_REDUCE = 4 + + +# expected: free -> hold -> compute -> hold -> +# -> compute -> hold_after_bwd -> ready_for_reduce +legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), 
(TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), + (TensorState.READY_FOR_REDUCE, TensorState.HOLD)] + + +def ts_update_sanity_check(old_state, new_state) -> bool: + if (old_state, new_state) not in legal_ts_update_list: + raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}') + return True diff --git a/colossalai/elixir/chunk/fetcher.py b/colossalai/elixir/chunk/fetcher.py new file mode 100644 index 000000000..3b7c69e99 --- /dev/null +++ b/colossalai/elixir/chunk/fetcher.py @@ -0,0 +1,210 @@ +from typing import List, Optional + +import torch + +from .core import Chunk, ChunkGroup, TensorBlock, TensorState +from .scheduler import ChunkScheduler + + +class ChunkFetcher(object): + """ChunkFetcher is responsible for fetching and reducing chunks during training. + Any operations on chunks should be done through ChunkFetcher. + + args: + scheduler: A ChunkScheduler to schedule evictable chunks. + group: A ChunkGroup to manage chunks. + overlap: Whether to overlap communications. + reduce_always_fp32: Whether to reduce gradients in FP32. + """ + + def __init__(self, + scheduler: ChunkScheduler, + group: ChunkGroup, + overlap: bool = False, + reduce_always_fp32: bool = False) -> None: + + self.scheduler: ChunkScheduler = scheduler + self.group: ChunkGroup = group + self.reduce_always_fp32 = reduce_always_fp32 + self.current_step = -1 + + self.overlap_flag = overlap + self.main_stream = torch.cuda.current_stream() + + self.predict_next_chunk: Optional[Chunk] = None + self.is_fetching: bool = False + self.prefetch_stream = torch.cuda.Stream() + + self.reduced_chunk: Optional[Chunk] = None + self.reduced_block: Optional[TensorBlock] = None + self.reduce_stream = torch.cuda.Stream() + + def reset(self): + """Reset the fetcher to the initial state. + Users should call this function before training.""" + self.scheduler.reset() + self.current_step = -1 + + def clear(self): + """Clear the fetcher. + Users should call this function after training.""" + if self.overlap_flag: + torch.cuda.synchronize() + self.predict_next_chunk = None + self.is_fetching = False + if self.reduced_chunk is not None: + self.reduce_call_back() + self.reduced_chunk = None + self.reduced_block = None + + self.scheduler.clear() + + def trans_to_compute(self, tensors: List[torch.Tensor]): + """Transform tensors to COMPUTE state. + This function should be called before the compute operators.""" + # update tensor states + for t in tensors: + self.group.tensor_trans_state(t, TensorState.COMPUTE) + # chunk operations + chunks = self.group.tensors_to_chunks(tensors) + for chunk in chunks: + self.scheduler.remove(chunk) + return chunks + + def trans_to_hold(self, tensors: List[torch.Tensor], phase: str): + """Transform tensors to HOLD state. 
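The legal transitions listed in `states.py` are enforced by `ts_update_sanity_check`; a runnable illustration:

```python
from colossalai.elixir.chunk.core.states import TensorState, ts_update_sanity_check

# Expected life cycle of a tensor inside a chunk:
# FREE -> HOLD -> COMPUTE -> HOLD -> COMPUTE -> HOLD_AFTER_BWD -> READY_FOR_REDUCE -> HOLD
ts_update_sanity_check(TensorState.HOLD, TensorState.COMPUTE)              # legal, returns True
ts_update_sanity_check(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD)    # legal (backward phase)

try:
    ts_update_sanity_check(TensorState.FREE, TensorState.READY_FOR_REDUCE)
except RuntimeError as e:
    print(e)    # illegal tensor state updating: TensorState.FREE -> TensorState.READY_FOR_REDUCE
```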
+ This function should be called after the compute operators.""" + assert phase in ('f', 'b') + next_state = TensorState.HOLD if phase == 'f' else TensorState.HOLD_AFTER_BWD + # update tensor states + for t in tensors: + self.group.tensor_trans_state(t, next_state) + # chunk operations + chunks = self.group.tensors_to_chunks(tensors) + for chunk in chunks: + if chunk.scatter_check: + self.scheduler.add(chunk) + + def get_one_chunk(self, tensor: torch.Tensor) -> Chunk: + """Get the chunk of the given tensor.""" + return self.group.ten_to_chunk.get(tensor) + + def get_chunks(self, tensors: List[torch.Tensor]) -> List[Chunk]: + """Get the chunks of the given tensors.""" + return self.group.tensors_to_chunks(tensors) + + def is_in_fused(self, tensor: torch.Tensor): + """Check whether the given tensor is in a fused chunk.""" + chunk = self.get_one_chunk(tensor) + return chunk.rcache_fused + + def filter_chunks(self, chunks: List[Chunk]): + """Filter the accessed chunks, since they are already on the rCache.""" + return list(filter(lambda c: not self.group.is_accessed(c), chunks)) + + def fetch_chunks(self, chunks: List[Chunk]): + """Fetch chunks needed for this compute operator. + The chunks should be in the COMPUTE state first.""" + # make step + 1 + self.step() + + predict_hit = False + # try to prefetch the next chunk + if self.predict_next_chunk is not None and self.predict_next_chunk in chunks: + if self.is_fetching: + # prefetch hit, wait async prefetch + self.main_stream.wait_stream(self.prefetch_stream) + # clear prefetch information + self.predict_next_chunk = None + self.is_fetching = False + + predict_hit = True + # filter accessed chunks + scattered = self.filter_chunks(chunks) + # sanity check: upload should wait for prefetch + if self.predict_next_chunk is not None: + assert len(scattered) == 0 + # all chunks are on the rcache + if len(scattered) == 0: + # prefetch if there is a hit above + if predict_hit: + self.prefetch(chunks) + return + + for chunk in scattered: + # if the rcache is not enough, just release a chunk + if not self.group.rcache_enough_check(chunk): + maybe_chunk = self.scheduler.top() + # print(f'Evicting {chunk.chunk_id} -> {maybe_chunk.chunk_id}') + if maybe_chunk is None: + raise RuntimeError('R cache is not enough. 
Try to allocate more.') + self.scheduler.remove(maybe_chunk) + self.group.release_chunk(maybe_chunk) + + # print('Accessing', chunk.chunk_id) + self.group.access_chunk(chunk) + + if self.overlap_flag: + assert self.predict_next_chunk is None + self.prefetch(chunks) + + def reduce_call_back(self): + self.reduced_chunk.update_extra_reduce_info(self.reduced_block) + if self.reduced_block is not None: + self.group.rcache.free_public_block(self.reduced_block) + + def reduce_chunk(self, chunk: Chunk): + """Reduce and scatter the given gradient chunk.""" + if not chunk.reduce_check: + return False + + self.scheduler.remove(chunk) + + if not self.overlap_flag: + # reduce the chunk if not overlapped + self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=True) + else: + # wait main stream for its computation + self.reduce_stream.wait_stream(self.main_stream) + # main stream should wait reduce stream + # if there is a block recycle + if self.reduced_chunk is not None: + self.main_stream.wait_stream(self.reduce_stream) + self.reduce_call_back() + + with torch.cuda.stream(self.reduce_stream): + self.reduced_chunk = chunk + self.reduced_block = self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=False) + + def prefetch(self, chunks: List[Chunk]): + """Prefetch the next used chunk.""" + next_chunk = self.scheduler.get_next_chunk(chunks) + self.predict_next_chunk = next_chunk + + # return if there is no next scattered chunk + if next_chunk is None or self.group.is_accessed(next_chunk): + return + + evict_chunk = None + if not self.group.rcache_enough_check(next_chunk): + maybe_chunk = self.scheduler.top() + # if there is no chunk can be evicted, just return + if maybe_chunk is None: + return + # otherwise, release this chunk + self.scheduler.remove(maybe_chunk) + evict_chunk = maybe_chunk + + with torch.cuda.stream(self.prefetch_stream): + # wait main stream + self.prefetch_stream.wait_stream(self.main_stream) + self.is_fetching = True + + if evict_chunk is not None: + self.group.release_chunk(evict_chunk) + self.group.access_chunk(next_chunk) + + def step(self): + """Update the scheduler.""" + self.scheduler.step() + self.current_step += 1 diff --git a/colossalai/elixir/chunk/scheduler/__init__.py b/colossalai/elixir/chunk/scheduler/__init__.py new file mode 100644 index 000000000..ce2cc27c9 --- /dev/null +++ b/colossalai/elixir/chunk/scheduler/__init__.py @@ -0,0 +1,3 @@ +from .base import ChunkScheduler +from .fifo import FIFOScheduler +from .prefetch import PrefetchScheduler diff --git a/colossalai/elixir/chunk/scheduler/base.py b/colossalai/elixir/chunk/scheduler/base.py new file mode 100644 index 000000000..bb2122f2f --- /dev/null +++ b/colossalai/elixir/chunk/scheduler/base.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from colossalai.elixir.chunk.core import Chunk + + +class ChunkScheduler(ABC): + """The base class of all chunk schedulers. + A chunk scherduler stores all releasable chunks. + It provides APIs to add, remove, display releasable chunks. 
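The concrete policies defined below (FIFO and prefetch-aware) share this interface; a runnable sketch of the FIFO policy, using plain strings in place of `Chunk` objects purely to show the ordering:

```python
from colossalai.elixir.chunk.scheduler import FIFOScheduler

sched = FIFOScheduler()
sched.reset()             # must be called before use

sched.add('chunk-A')      # becomes releasable first
sched.add('chunk-B')
sched.add('chunk-C')

print(sched.top())        # 'chunk-A' -- the oldest releasable chunk is evicted first
sched.remove('chunk-A')   # e.g. the fetcher just evicted it from the rCache
print(sched.top())        # 'chunk-B'

sched.remove('chunk-B')
sched.remove('chunk-C')
sched.clear()             # asserts that nothing releasable is left
```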
+ """ + + def __init__(self) -> None: + super().__init__() + self.releasable_set: Optional[set] = None + self.current_step = -1 + + @abstractmethod + def reset(self) -> None: + self.releasable_set = set() + self.current_step = -1 + + @abstractmethod + def clear(self) -> None: + # asure the set is empty now + assert not bool(self.releasable_set) + + @abstractmethod + def top(self) -> Optional[Chunk]: + # return None if the releasable set is empty + if not self.releasable_set: + return False + return True + + @abstractmethod + def add(self, chunk: Chunk) -> bool: + if chunk in self.releasable_set: + return False + self.releasable_set.add(chunk) + return True + + @abstractmethod + def remove(self, chunk: Chunk) -> bool: + if chunk not in self.releasable_set: + return False + self.releasable_set.remove(chunk) + return True + + def step(self, *args, **kwags): + self.current_step += 1 diff --git a/colossalai/elixir/chunk/scheduler/fifo.py b/colossalai/elixir/chunk/scheduler/fifo.py new file mode 100644 index 000000000..af7f54d49 --- /dev/null +++ b/colossalai/elixir/chunk/scheduler/fifo.py @@ -0,0 +1,40 @@ +from typing import Optional + +from .base import Chunk, ChunkScheduler + + +class FIFOScheduler(ChunkScheduler): + """The FIFO chunk scheduler. + It stores all releasable chunks in a FIFO queue. + """ + + def __init__(self) -> None: + super().__init__() + self.fifo_dict: Optional[dict] = None + + def reset(self) -> None: + super().reset() + self.fifo_dict = dict() + + def clear(self) -> None: + super().clear() + self.fifo_dict = None + + def top(self) -> Optional[Chunk]: + if not super().top(): + return None + dict_iter = iter(self.fifo_dict) + ret = next(dict_iter) + return ret + + def add(self, chunk: Chunk) -> bool: + if not super().add(chunk): + return False + self.fifo_dict[chunk] = True + return True + + def remove(self, chunk: Chunk) -> bool: + if not super().remove(chunk): + return False + self.fifo_dict.pop(chunk) + return True diff --git a/colossalai/elixir/chunk/scheduler/prefetch.py b/colossalai/elixir/chunk/scheduler/prefetch.py new file mode 100644 index 000000000..e1f6ce8d7 --- /dev/null +++ b/colossalai/elixir/chunk/scheduler/prefetch.py @@ -0,0 +1,85 @@ +from collections import defaultdict +from typing import Iterable, List, Optional + +import torch +from sortedcontainers import SortedSet + +from .base import Chunk, ChunkScheduler + + +class PrefetchScheduler(ChunkScheduler): + """The prefetch chunk scheduler. + Its top functions gives the furthest used chunk. 
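A runnable sketch of this look-ahead policy; real callers pass `Chunk` objects collected by a profiling pass, while plain strings are used here only to show the ordering:

```python
from colossalai.elixir.chunk.scheduler import PrefetchScheduler

steps = [['A'], ['B'], ['A'], ['C']]    # which chunks each training step touches
sched = PrefetchScheduler(chunk_called_per_step=steps)
sched.reset()

sched.step()                  # now at step 0, where 'A' is in use
sched.add('B')                # next needed at step 1
sched.add('C')                # next needed at step 3
print(sched.top())            # 'C' -- evict the chunk whose next use is furthest away

print(sched.get_next_chunk(chunks=['A']))    # 'B' -- the best candidate to prefetch
```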
+ """ + + def __init__(self, chunk_called_per_step: List[Iterable[Chunk]]) -> None: + super().__init__() + self.chunk_mapping = None + self.evict_set = None + self.search_step = -1 + + self.chunks_per_step = chunk_called_per_step + self.total_steps = len(chunk_called_per_step) + self.next_step_dict = defaultdict(list) + # initialize the next_step dictionary + for i, c_list in enumerate(chunk_called_per_step): + for c in c_list: + self.next_step_dict[c].append(i) + + def _get_next_step(self, chunk: Chunk): + step_list = self.next_step_dict[chunk] + for i in step_list: + if i > self.current_step: + return i + return self.total_steps + + def reset(self) -> None: + super().reset() + self.chunk_mapping = dict() + self.evict_set = SortedSet() + self.search_step = -1 + + def clear(self) -> None: + super().clear() + if torch.is_grad_enabled(): + assert self.current_step == self.total_steps - 1 + self.chunk_mapping = None + self.evict_set = None + self.search_step = -1 + + def top(self) -> Optional[Chunk]: + if not super().top(): + return None + next_step, chunk = self.evict_set[-1] + return chunk + + def add(self, chunk: Chunk) -> bool: + if not super().add(chunk): + return False + value = (self._get_next_step(chunk), chunk) + self.chunk_mapping[chunk] = value + self.evict_set.add(value) + return True + + def remove(self, chunk: Chunk) -> bool: + if not super().remove(chunk): + return False + value = self.chunk_mapping[chunk] + self.evict_set.remove(value) + self.chunk_mapping.pop(chunk) + return True + + def step(self, *args, **kwags): + super().step(*args, **kwags) + if self.current_step >= self.total_steps: + raise RuntimeError('exceed simulated steps, please modify your profiling `step_fn`') + + def get_next_chunk(self, chunks: List[Chunk]): + self.search_step = max(self.search_step, self.current_step + 1) + while self.search_step < self.total_steps: + c_list = self.chunks_per_step[self.search_step] + for c in c_list: + if c not in chunks: + return c + self.search_step += 1 + return None diff --git a/colossalai/elixir/ctx/__init__.py b/colossalai/elixir/ctx/__init__.py new file mode 100644 index 000000000..6c56ea0c5 --- /dev/null +++ b/colossalai/elixir/ctx/__init__.py @@ -0,0 +1,32 @@ +import torch + +tensor_creation_methods = dict(tensor=torch.tensor, + sparse_coo_tensor=torch.sparse_coo_tensor, + asarray=torch.asarray, + as_tensor=torch.as_tensor, + as_strided=torch.as_strided, + from_numpy=torch.from_numpy, + from_dlpack=torch.from_dlpack, + frombuffer=torch.frombuffer, + zeros=torch.zeros, + zeros_like=torch.zeros_like, + ones=torch.ones, + ones_like=torch.ones_like, + arange=torch.arange, + range=torch.range, + linspace=torch.linspace, + logspace=torch.logspace, + eye=torch.eye, + empty=torch.empty, + empty_like=torch.empty_like, + empty_strided=torch.empty_strided, + full=torch.full, + full_like=torch.full_like, + quantize_per_tensor=torch.quantize_per_tensor, + quantize_per_channel=torch.quantize_per_channel, + dequantize=torch.dequantize, + complex=torch.complex, + polar=torch.polar, + heaviside=torch.heaviside) + +from .meta_ctx import MetaContext diff --git a/colossalai/elixir/ctx/meta_ctx.py b/colossalai/elixir/ctx/meta_ctx.py new file mode 100644 index 000000000..7710a5971 --- /dev/null +++ b/colossalai/elixir/ctx/meta_ctx.py @@ -0,0 +1,34 @@ +import torch + +from colossalai.elixir.ctx import tensor_creation_methods + + +class MetaContext(object): + """A context manager that wraps all tensor creation methods in torch. + By default, all tensors will be created in meta. 
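Every tensor-creation call issued inside the context is redirected to the meta device, so a model skeleton can be described without allocating real storage; a runnable sketch:

```python
import torch

from colossalai.elixir.ctx import MetaContext

with MetaContext():
    weight = torch.empty(4096, 4096)    # no real memory is allocated
print(weight.device)                     # meta
print(torch.empty(2).device)             # creation methods are restored on exit
```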
+ + args: + device_type: The device type of the tensors to be created. + """ + + def __init__(self, device_type: str = 'meta') -> None: + super().__init__() + self.device_type = device_type + return None + + def __enter__(self): + + def meta_wrap(func): + + def wrapped_func(*args, **kwargs): + kwargs['device'] = self.device_type + return func(*args, **kwargs) + + return wrapped_func + + for name, method in tensor_creation_methods.items(): + setattr(torch, name, meta_wrap(method)) + + def __exit__(self, exc_type, exc_val, exc_tb): + for name, method in tensor_creation_methods.items(): + setattr(torch, name, method) diff --git a/colossalai/elixir/cuda.py b/colossalai/elixir/cuda.py new file mode 100644 index 000000000..33f63aa52 --- /dev/null +++ b/colossalai/elixir/cuda.py @@ -0,0 +1,28 @@ +import functools + +import torch +from torch.cuda._utils import _get_device_index + +elixir_cuda_fraction = dict() + + +@functools.lru_cache() +def gpu_device(): + return torch.device(torch.cuda.current_device()) + + +def set_memory_fraction(fraction, device=None): + torch.cuda.set_per_process_memory_fraction(fraction, device) + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + elixir_cuda_fraction[device] = fraction + + +def get_allowed_memory(device=None): + total_memory = torch.cuda.get_device_properties(device).total_memory + if device is None: + device = torch.cuda.current_device() + device = _get_device_index(device) + fraction = elixir_cuda_fraction.get(device, 1.0) + return int(fraction * total_memory) diff --git a/colossalai/elixir/hook/__init__.py b/colossalai/elixir/hook/__init__.py new file mode 100644 index 000000000..4e88b76e5 --- /dev/null +++ b/colossalai/elixir/hook/__init__.py @@ -0,0 +1,2 @@ +from .parameter import HookParam +from .storage import BufferStore diff --git a/colossalai/elixir/hook/functions.py b/colossalai/elixir/hook/functions.py new file mode 100644 index 000000000..c90527de1 --- /dev/null +++ b/colossalai/elixir/hook/functions.py @@ -0,0 +1,73 @@ +import torch + +from colossalai.elixir.chunk import ChunkFetcher + +from .storage import BufferStore + + +def prefwd_postbwd_function(fetcher: ChunkFetcher, store: BufferStore): + + class PreFwdPostBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, *args): + with torch._C.DisableTorchFunction(): + ctx.params = params + chunks = fetcher.trans_to_compute(params) + fetcher.fetch_chunks(chunks) + + offset = 0 + for p in ctx.params: + if not fetcher.is_in_fused(p): + # we should add parameters to buffer + # because their blocks may be changed + offset = store.insert(p, offset) + + return args + + @staticmethod + def backward(ctx, *grads): + with torch._C.DisableTorchFunction(): + fetcher.trans_to_hold(ctx.params, phase='b') + + for p in ctx.params: + if not fetcher.is_in_fused(p): + store.erase(p) + + return (None, *grads) + + return PreFwdPostBwd.apply + + +def postfwd_prebwd_function(fetcher: ChunkFetcher, store: BufferStore): + + class PostFwdPreBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, *args): + with torch._C.DisableTorchFunction(): + ctx.params = params + + fetcher.trans_to_hold(ctx.params, phase='f') + for p in ctx.params: + if not fetcher.is_in_fused(p): + store.erase(p) + + return args + + @staticmethod + def backward(ctx, *grads): + with torch._C.DisableTorchFunction(): + chunks = fetcher.trans_to_compute(ctx.params) + fetcher.fetch_chunks(chunks) + + offset = 0 + for p in ctx.params: + if not fetcher.is_in_fused(p): + # 
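As a side note on the `set_memory_fraction` / `get_allowed_memory` helpers defined in `cuda.py` above, a small usage sketch (it assumes a CUDA device is available):

```python
import torch

from colossalai.elixir.cuda import get_allowed_memory, set_memory_fraction

if torch.cuda.is_available():
    set_memory_fraction(0.8)        # cap this process at 80% of device memory
    print(get_allowed_memory())     # ~0.8 * total device memory, in bytes
```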
we should add parameters to buffer + # because their blocks may be changed + offset = store.insert(p, offset) + + return (None, *grads) + + return PostFwdPreBwd.apply diff --git a/colossalai/elixir/hook/parameter.py b/colossalai/elixir/hook/parameter.py new file mode 100644 index 000000000..17cb62855 --- /dev/null +++ b/colossalai/elixir/hook/parameter.py @@ -0,0 +1,102 @@ +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.utils._pytree import tree_map + +from colossalai.elixir.chunk import ChunkFetcher +from colossalai.elixir.kernels import fused_torch_functions +from colossalai.elixir.tensor import OutplaceTensor, is_no_hook_op, to_outplace_tensor + +from .functions import postfwd_prebwd_function, prefwd_postbwd_function +from .storage import BufferStore + + +class HookParam(OutplaceTensor, nn.Parameter): + """HookParam is a special type of tensor that is used to triggered hooks on parameters. + HookParam adds chunk fetching before torch functions. + """ + pre_fwd_func = None + post_fwd_func = None + use_fused_kernel = False + + @staticmethod + def attach_fetcher(fetcher: ChunkFetcher, store: BufferStore): + HookParam.pre_fwd_func = prefwd_postbwd_function(fetcher, store) + HookParam.post_fwd_func = postfwd_prebwd_function(fetcher, store) + + @staticmethod + def release_fetcher(): + HookParam.pre_fwd_func = None + HookParam.post_fwd_func = None + + @staticmethod + def enable_fused_kernel(): + HookParam.use_fused_kernel = True + + @staticmethod + def disable_fused_kernel(): + HookParam.use_fused_kernel = False + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if is_no_hook_op(func): + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + return ret + + params_to_index = OrderedDict() + params_index = 0 + + def append_param(x): + nonlocal params_index + if isinstance(x, HookParam): + params_to_index[x] = params_index + params_index += 1 + + tree_map(append_param, args) + tree_map(append_param, kwargs) + + params = tuple(params_to_index.keys()) + new_params = HookParam.pre_fwd_func(params, *params) + + def replace_param(x): + if isinstance(x, HookParam): + return new_params[params_to_index[x]] + return x + + with torch._C.DisableTorchFunction(): + if HookParam.use_fused_kernel and func in fused_torch_functions: + func = fused_torch_functions.get(func) + ret = func(*tree_map(replace_param, args), **tree_map(replace_param, kwargs)) + if not isinstance(ret, tuple): + ret = (ret,) + + ptr_set = set() + for p in new_params: + ptr_set.add(p.data_ptr()) + + def clone_inplace_tensor(x): + if isinstance(x, torch.Tensor): + start_point = x.data_ptr() - x.element_size() * x.storage_offset() + if start_point in ptr_set: + return x.clone() + return x + + ret = tree_map(clone_inplace_tensor, ret) + ret = HookParam.post_fwd_func(params, *ret) + + def convert(t): + if isinstance(t, torch.Tensor): + t = to_outplace_tensor(t) + return t + + ret = tree_map(convert, ret) + + if len(ret) == 1: + return ret[0] + else: + return ret diff --git a/colossalai/elixir/hook/storage.py b/colossalai/elixir/hook/storage.py new file mode 100644 index 000000000..5d0ded0fd --- /dev/null +++ b/colossalai/elixir/hook/storage.py @@ -0,0 +1,56 @@ +import torch +from torch.autograd.profiler_util import _format_memory + +from colossalai.elixir.cuda import gpu_device + + +class BufferStore(object): + """A place to store parameters temporarily when computing. 
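A runnable sketch of the staging round-trip (a CPU buffer is used so it runs anywhere):

```python
import torch

from colossalai.elixir.hook import BufferStore

store = BufferStore(buffer_size=1024, buffer_dtype=torch.float32, device_str='cpu')

p = torch.nn.Parameter(torch.randn(4, 8))
offset = store.insert(p, 0)    # p.data now lives inside the shared buffer
print(offset)                  # 32 -- the next free offset, in elements
store.erase(p)                 # p.data points back to its original storage
```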
+ Parameters should be stored into the buffer before computaions. + + args: + buffer_size: The size of the buffer. + buffer_dtype: The dtype of the buffer. + device_str: The device to store the buffer. + """ + + def __init__(self, buffer_size: torch.Tensor, buffer_dtype: torch.dtype, device_str: str = 'cuda') -> None: + super().__init__() + self.buffer_size = buffer_size + self.buffer_dtype = buffer_dtype + self.buffer: torch.Tensor = torch.empty(buffer_size, dtype=buffer_dtype, device=device_str) + self.buffer_occ = buffer_size * self.buffer.element_size() + self.record_dict = dict() + + def zeros(self): + torch.zero_(self.buffer) + + def insert(self, t: torch.Tensor, offset: int) -> int: + assert t not in self.record_dict + end = offset + t.numel() + assert end <= self.buffer_size, f'buffer size is {self.buffer_size} but needs {end}' + + new_data = self.buffer[offset:end].view(t.shape) + new_data.copy_(t.data) + + self.record_dict[t] = t.data + t.data = new_data + + return end + + def erase(self, t: torch.Tensor): + assert t in self.record_dict + + new_data = self.record_dict.pop(t) + t.data = new_data + + return + + def empty_like(self, t: torch.Tensor): + return self.buffer[:t.numel()].view(t.shape) + + def empty_1d(self, size: int): + return self.buffer[:size] + + def __repr__(self) -> str: + return f'Buffer(size={self.buffer_size}, dtype={self.buffer_dtype}, memo_occ={_format_memory(self.buffer_occ)})' diff --git a/colossalai/elixir/kernels/__init__.py b/colossalai/elixir/kernels/__init__.py new file mode 100644 index 000000000..3ca4a2614 --- /dev/null +++ b/colossalai/elixir/kernels/__init__.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +fused_torch_functions = {F.layer_norm: F.layer_norm} + + +def register_fused_layer_norm(): + try: + from .layernorm import ln_func + fused_torch_functions[F.layer_norm] = ln_func + print('Register fused layer norm successfully from apex.') + except: + print('Cannot import fused layer norm, please install apex from source.') + pass + + +register_fused_layer_norm() diff --git a/colossalai/elixir/kernels/attention.py b/colossalai/elixir/kernels/attention.py new file mode 100644 index 000000000..ca977c5a5 --- /dev/null +++ b/colossalai/elixir/kernels/attention.py @@ -0,0 +1,43 @@ +import torch +import xformers.ops as xops +from torch.utils._pytree import tree_map + +from colossalai.elixir.tracer.memory_tracer.memory_tensor import MTensor + + +def lower_triangular_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, p: float = 0.0): + + args = (query, key, value) + meta_flag = False + + for x in args: + if x.device.type == 'meta': + meta_flag = True + break + + if meta_flag: + atten = query @ key.transpose(-2, -1) + output = atten @ value + return output + + profile_flag = False + + def to_torch_tensor(x): + if isinstance(x, MTensor): + nonlocal profile_flag + profile_flag = True + return x.elem + return x + + args = tree_map(to_torch_tensor, args) + query, key, value = args + output = xops.memory_efficient_attention(query=query, + key=key, + value=value, + p=p, + attn_bias=xops.LowerTriangularMask()) + + if profile_flag: + output = MTensor(output) + + return output diff --git a/colossalai/elixir/kernels/attn_wrapper.py b/colossalai/elixir/kernels/attn_wrapper.py new file mode 100644 index 000000000..58388bee6 --- /dev/null +++ b/colossalai/elixir/kernels/attn_wrapper.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model +from 
transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder + +from .gpt_attention import XGPT2Attention, XGPT2Model +from .opt_attention import XOPTAttention, XOPTDecoder + + +def wrap_attention(model: nn.Module): + for name, module in model.named_modules(): + if isinstance(module, GPT2Model): + module.__class__ = XGPT2Model + elif isinstance(module, GPT2Attention): + module.__class__ = XGPT2Attention + elif isinstance(module, OPTAttention): + module.__class__ = XOPTAttention + elif isinstance(module, OPTDecoder): + module.__class__ = XOPTDecoder + return model diff --git a/colossalai/elixir/kernels/gpt_attention.py b/colossalai/elixir/kernels/gpt_attention.py new file mode 100644 index 000000000..8204e14f5 --- /dev/null +++ b/colossalai/elixir/kernels/gpt_attention.py @@ -0,0 +1,45 @@ +import einops +import torch +from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model + +from .attention import lower_triangular_attention + + +class XGPT2Attention(GPT2Attention): + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + assert self.scale_attn_weights + assert not self.is_cross_attention + assert not self.scale_attn_by_inverse_layer_idx + assert not self.reorder_and_upcast_attn + + b_size, h_size, m_size, k_size = query.size() + + assert self.bias.size(-1) == m_size + query = einops.rearrange(query, 'b h m k -> b m h k') + key = einops.rearrange(key, 'b h m k -> b m h k') + value = einops.rearrange(value, 'b h m k -> b m h k') + + drop_rate = self.attn_dropout.p + output = lower_triangular_attention(query, key, value, p=drop_rate) + + ret = einops.rearrange(output, 'b m h k -> b h m k') + + return ret, None + + +class XGPT2Model(GPT2Model): + + def forward(self, *args, **kwargs): + assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg' + attn_mask = kwargs.get('attention_mask') + # assert torch.all(attn_mask == 1), 'only accept no padding mask' + + head_mask = kwargs.get('head_mask', None) + assert head_mask is None, 'head mask should be None' + + output_attn = kwargs.get('output_attentions', False) + if output_attn: + Warning('output_attentions is not supported for XGPT2Model') + + return super().forward(*args, **kwargs) diff --git a/colossalai/elixir/kernels/layernorm.py b/colossalai/elixir/kernels/layernorm.py new file mode 100644 index 000000000..67837d4c1 --- /dev/null +++ b/colossalai/elixir/kernels/layernorm.py @@ -0,0 +1,10 @@ +from apex.normalization.fused_layer_norm import fused_layer_norm, fused_layer_norm_affine + + +def ln_func(input, normalized_shape, weight=None, bias=None, eps=1e-05): + if weight is None: + assert bias is None + return fused_layer_norm(input, normalized_shape, eps) + else: + assert weight is not None and bias is not None + return fused_layer_norm_affine(input, weight, bias, normalized_shape, eps) diff --git a/colossalai/elixir/kernels/opt_attention.py b/colossalai/elixir/kernels/opt_attention.py new file mode 100644 index 000000000..6049807df --- /dev/null +++ b/colossalai/elixir/kernels/opt_attention.py @@ -0,0 +1,86 @@ +from typing import Optional, Tuple + +import einops +import torch +import torch.nn as nn +from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder + +from .attention import lower_triangular_attention + + +class XOPTAttention(OPTAttention): + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + 
key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + assert is_cross_attention is False + assert past_key_value is None + assert layer_head_mask is None + # assert output_attentions is False + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) + # get key, value proj + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + src_len = key_states.size(1) + assert tgt_len == src_len + + attn_output = lower_triangular_attention(query=query_states, key=key_states, value=value_states, p=self.dropout) + + if attn_output.size() != (bsz, tgt_len, self.num_heads, self.head_dim): + raise ValueError(f'`attn_output` should be of size {(bsz, tgt_len, self.num_heads, self.head_dim)}, but is' + f' {attn_output.size()}') + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
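+        # Here `attn_output` is laid out as (bsz, tgt_len, num_heads, head_dim), so merging the
+        # last two dimensions recovers (bsz, tgt_len, embed_dim) expected by the output projection.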
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +class XOPTDecoder(OPTDecoder): + + def forward(self, *args, **kwargs): + assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg' + attn_mask = kwargs.get('attention_mask') + # assert torch.all(attn_mask == 1), 'only accept no padding mask' + + head_mask = kwargs.get('head_mask', None) + assert head_mask is None, 'head mask should be None' + + output_attn = kwargs.get('output_attentions', False) + if output_attn: + Warning('output_attentions is not supported for XOPTDecoder') + + return super().forward(*args, **kwargs) diff --git a/colossalai/elixir/search/__init__.py b/colossalai/elixir/search/__init__.py new file mode 100644 index 000000000..40ba043f1 --- /dev/null +++ b/colossalai/elixir/search/__init__.py @@ -0,0 +1,4 @@ +from .mini_waste import minimum_waste_search +from .optimal import optimal_search +from .result import ChunkPlan, SearchResult +from .simple import simple_search diff --git a/colossalai/elixir/search/base.py b/colossalai/elixir/search/base.py new file mode 100644 index 000000000..50baa27fa --- /dev/null +++ b/colossalai/elixir/search/base.py @@ -0,0 +1,135 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import List, Tuple + +import torch +import torch.nn as nn + +from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool +from colossalai.elixir.tracer.param_tracer import generate_tf_order +from colossalai.elixir.tracer.utils import meta_copy +from colossalai.elixir.utils import print_rank_0 + +from .result import ChunkPlan +from .utils import to_meta_tensor + + +class SearchBase(ABC): + """A basic class for search algorithms. 
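+    Subclasses implement `search`, which splits parameters into a private group and several
+    public groups; `generate_chunk_plans` and `allocate_chunk_group` then turn that split into
+    ChunkPlan configurations and a ChunkGroup backed by a MemoryPool.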
+ + args: + module: the model to be searched + dtype: the unified dtype of all parameters + prefetch: whether to prefetch chunks during training + verbose: whether to print search details + inp: a dictionary, the example input of the model + step_fn: the example step function of the model + """ + + def __init__(self, + module: nn.Module, + dtype: torch.dtype = torch.float, + prefetch: bool = False, + verbose: bool = False, + inp=None, + step_fn=None) -> None: + + self.unified_dtype = dtype + self.meta_module = meta_copy(module, partial(to_meta_tensor, dtype=self.unified_dtype)) + self.prefetch_flag = prefetch + self.verbose = verbose + self.param_to_name = {param: name for name, param in self.meta_module.named_parameters()} + + self.public_block_size = 1024 + self.public_block_number = 0 + + self.param_per_step = None + self.max_checkpoint_size = 0 + if self.prefetch_flag: + assert inp is not None and step_fn is not None + tf_running_info = generate_tf_order(self.meta_module, inp, step_fn, dtype) + self.param_per_step = tf_running_info.get('params_per_step') + if self.verbose: + print_rank_0('Prefetch enabled: the called order of parameters') + for i, step in enumerate(self.param_per_step): + print_rank_0(f'step {i}: {step}') + + name_to_param = {name: param for name, param in self.meta_module.named_parameters()} + for checkpoint in tf_running_info.get('checkpoint_info'): + sum_numel = 0 + for i in range(*checkpoint): + for name in self.param_per_step[i]: + param = name_to_param[name] + sum_numel += param.numel() + self.max_checkpoint_size = max(self.max_checkpoint_size, sum_numel) + if self.verbose: + print_rank_0(f'checkpoint infomation: from-to -> {checkpoint}, numel -> {sum_numel}') + + @abstractmethod + def private_truncate(self, param: nn.Parameter) -> int: + """A function used to truncate the length of a private chunk, + which only contains one parameter. + """ + pass + + @abstractmethod + def public_trucate(self, length: int) -> int: + """A function used to trucate the length of all publick chunks + """ + pass + + @abstractmethod + def search(self, *args, **kwargs) -> Tuple: + """The core search function. It returns a tuple of a private group and public groups. 
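+        In the provided search implementations, the private group holds parameters shared by
+        more than one module (each fused into its own chunk), while every public group is
+        packed into one shared public chunk of the searched block size.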
+ """ + pass + + def generate_chunk_plans(self, private_group, publick_groups) -> List[ChunkPlan]: + plans = list() + for param in private_group: + chunk_size = self.private_truncate(param) + chunk_dtype = param.dtype + chunk_kwargs = dict(rcache_fused=True) + chunk_plan = ChunkPlan(name_list=[self.param_to_name[param]], + chunk_size=chunk_size, + chunk_dtype=chunk_dtype, + kwargs=chunk_kwargs) + plans.append(chunk_plan) + + self.public_block_size = self.public_trucate(self.public_block_size) + public_chunk_size = self.public_block_size + public_chunk_dtype = self.unified_dtype + for group in publick_groups: + chunk_kwargs = {} + chunk_plan = ChunkPlan(name_list=[self.param_to_name[p] for p in group], + chunk_size=public_chunk_size, + chunk_dtype=public_chunk_dtype, + kwargs=chunk_kwargs) + plans.append(chunk_plan) + + if self.verbose: + print_rank_0(f'Chunk plans: total {len(plans)} chunks') + for i, plan in enumerate(plans): + print_rank_0(f'plan {i}: {plan}') + + return plans + + def allocate_chunk_group(self, chunk_plans: List[ChunkPlan]) -> ChunkGroup: + block_require_list = list() + for plan in chunk_plans: + kwargs = plan.kwargs + if kwargs.get('rcache_fused', False): + block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype)) + + mp = MemoryPool('cuda') + mp.allocate(public_dtype=self.unified_dtype, + public_block_size=self.public_block_size, + public_block_number=self.public_block_number, + private_block_list=block_require_list) + + if self.verbose: + print_rank_0( + f'Memory pool (rcache): {mp}\n\tblock size -> {mp.public_block_size}, block number -> {mp.public_free_cnt}' + ) + + return ChunkGroup(mp) diff --git a/colossalai/elixir/search/mini_waste.py b/colossalai/elixir/search/mini_waste.py new file mode 100644 index 000000000..1cd1af7be --- /dev/null +++ b/colossalai/elixir/search/mini_waste.py @@ -0,0 +1,167 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.utils import print_rank_0 + +from .base import SearchBase +from .result import SearchResult +from .utils import find_minimum_waste_size, find_search_range, get_multi_used_params, to_divide + +dtype_to_es = {torch.float16: 2, torch.float32: 4, torch.float64: 8} + + +class SearchMiniWaste(SearchBase): + """Search the best chunk size to minimize the waste of memory. 
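+    The candidate chunk size is swept from the minimum to the maximum of the search range at
+    a fixed interval, and the size with the least padding waste over the public parameters is
+    chosen (see `find_minimum_waste_size`).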
+ + args: + module: the module to be searched + default_group_size: the default group size of communications + dtype: the data type of the parameters + prefetch: whether to prefetch the parameters + verbose: whether to print the search details + inp: a dictionary, the example input of the model + step_fn: the example step function of training + """ + + def __init__(self, + module: nn.Module, + default_group_size: int, + dtype: torch.dtype = torch.float, + prefetch: bool = False, + verbose: bool = False, + inp=None, + step_fn=None) -> None: + + super().__init__(module, dtype, prefetch, verbose, inp, step_fn) + self.default_group_size = default_group_size + + def private_truncate(self, param: nn.Parameter) -> int: + return to_divide(param.numel(), self.default_group_size) + + def public_trucate(self, length: int) -> int: + return to_divide(length, self.default_group_size) + + def search(self) -> Tuple: + min_chunk_size, max_chunk_size, search_interval = find_search_range(self.meta_module) + # get multi-used parameters + private_params = get_multi_used_params(self.meta_module) + # get parameters used only one time + public_params = [p for p in self.meta_module.parameters() if p not in private_params] + # collect the number of elements of each parameter + public_numels = [p.numel() for p in public_params] + # calculate the sumary of all parameters + total_size = sum(public_numels) + + if total_size <= min_chunk_size: + public_block_size = total_size + waste_size = 0 + else: + public_block_size, waste_size = find_minimum_waste_size( + # pre-commit: do not rearrange + numel_group_list=[public_numels], + min_range=min_chunk_size, + max_range=max_chunk_size, + interval=search_interval) + + if self.verbose: + if total_size == 0: + waste_percentage = 0 + else: + waste_percentage = 100 * waste_size / total_size + print_rank_0( + f'Minimum waste search result: chunk size = {public_block_size}, waste percentage = {waste_percentage: .1f} %' + ) + + # initialize the mapping from parameters to chunks + param_to_chunk_id = dict() + chunk_id = 0 + # deal with private parameters + for p in private_params: + param_to_chunk_id[p] = chunk_id + chunk_id += 1 + # record the upper bound + private_id_upperbound = chunk_id + # deal with public parameters + last_left = 0 + for p in public_params: + p_size = p.numel() + + if last_left < p_size: + last_left = public_block_size + chunk_id += 1 + + assert last_left >= p_size + + last_left -= p_size + param_to_chunk_id[p] = chunk_id + + # initailize public groups + public_number_chunks = chunk_id - private_id_upperbound + public_groups = [[] for _ in range(public_number_chunks)] + for p in public_params: + public_chunk_id = param_to_chunk_id[p] - private_id_upperbound - 1 + public_groups[public_chunk_id].append(p) + + # calculate the number of minimum chunks allocated in R cache + max_lived_chunks = 0 + for module in self.meta_module.modules(): + param_set = set() + for param in module.parameters(recurse=False): + param_set.add(param_to_chunk_id[param]) + max_lived_chunks = max(max_lived_chunks, len(param_set)) + # allocate more chunks for prefetch + if self.prefetch_flag: + max_lived_chunks = min(max_lived_chunks + 4, public_number_chunks) + + if total_size == 0: + max_lived_chunks = 0 + + self.public_block_size = public_block_size + self.public_block_number = max_lived_chunks + + return (private_params, public_groups) + + +def minimum_waste_search(m: nn.Module, + group_size: int, + unified_dtype: torch.dtype = torch.float, + cpu_offload: bool = False, + prefetch: bool = 
False, + verbose: bool = False, + pin_memory: bool = True, + inp=None, + step_fn=None) -> SearchResult: + + search_class = SearchMiniWaste( + # pre-commit: do not rearrange + module=m, + default_group_size=group_size, + dtype=unified_dtype, + prefetch=prefetch, + verbose=verbose, + inp=inp, + step_fn=step_fn) + + private_group, public_groups = search_class.search() + chunk_plans = search_class.generate_chunk_plans(private_group, public_groups) + + # assign shard device + if cpu_offload: + shard_device = torch.device('cpu') + else: + shard_device = gpu_device() + + for plan in chunk_plans: + plan.kwargs['shard_device'] = shard_device + if cpu_offload: + plan.kwargs['cpu_pin_memory'] = pin_memory + + chunk_group = search_class.allocate_chunk_group(chunk_plans) + + return SearchResult(chunk_group=chunk_group, + chunk_plans=chunk_plans, + param_called_per_step=search_class.param_per_step) diff --git a/colossalai/elixir/search/optimal.py b/colossalai/elixir/search/optimal.py new file mode 100644 index 000000000..31cc7e84d --- /dev/null +++ b/colossalai/elixir/search/optimal.py @@ -0,0 +1,284 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn +from torch.autograd.profiler_util import _format_memory + +from colossalai.elixir.cuda import get_allowed_memory, gpu_device +from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling +from colossalai.elixir.utils import calc_buffer_size, print_rank_0 + +from .base import SearchBase +from .result import SearchResult +from .simulator import find_optimal_chunk_size, rcache_prioirity_check +from .utils import find_search_range, get_multi_used_params, to_divide + +dtype_to_es = {torch.float16: 2, torch.float32: 4, torch.float64: 8} + + +class SearchOptimal(SearchBase): + """Search the best chunk size to maximize the training throughput. + Users should provide the example input data and step function of training. 
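+    The model is profiled first to bound the CUDA memory available for the rCache, and the
+    chunk size is then chosen to minimize the predicted volume of chunk movement
+    (see `find_optimal_chunk_size`).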
+ + args: + module: the module to be searched + default_group_size: the default group size of communications + activation_fragment_factor: the factor to estimate the total activation memory usage + allocation_fragment_factor: the factor to estimate the effective ratio of the memory usage can be used by Elixir + driver_usage: the memory usage of the cuda driver + dtype: the data type of the parameters + verbose: whether to print the search details + overlap: whether to overlap the communication and computation + inp: a dictionary, the example input of the model + step_fn: the example step function of training + """ + + def __init__(self, + module: nn.Module, + default_group_size: int, + activation_fragment_factor: float = 1.25, + allocation_fragment_factor: float = 0.95, + driver_usage: float = 2 * 1024**3, + dtype: torch.dtype = torch.float, + verbose: bool = False, + overlap: bool = False, + inp=None, + step_fn=None) -> None: + # as for optimal search, we must profile the model first + super().__init__(module, dtype, True, verbose, inp, step_fn) + # profile cuda memory usage + memo_usage = cuda_memory_profiling(model=self.meta_module, inp=inp, step_fn=step_fn, dtype=dtype) + torch.cuda.empty_cache() + buffer_occ = memo_usage['buffer_occ'] + # get the maximum memory usage of activation + predict_activation = memo_usage['activation_occ'] + # calculate the total capacity of the current device + gpu_memory = get_allowed_memory() + # allowed capacity = allocation_fragment_factor * (total capacity - activation_fragment_factor * activation) + self.cuda_capacity = int( + allocation_fragment_factor * + (gpu_memory - driver_usage - buffer_occ - activation_fragment_factor * predict_activation)) + hook_buffer_store_size = calc_buffer_size(m=self.meta_module, test_dtype=self.unified_dtype) + self.cuda_elements = self.cuda_capacity // dtype_to_es.get(dtype) - hook_buffer_store_size + + if self.cuda_elements < 0: + raise RuntimeError('optimal search: activation is too large, please reduce batch size') + + if self.verbose: + print_rank_0('Predict memory usage:') + for k, v in memo_usage.items(): + print_rank_0(f'{k}: {_format_memory(v)}') + print_rank_0(f'allowed allocation space: {_format_memory(self.cuda_capacity)}') + print_rank_0(f'hook buffer store size: {hook_buffer_store_size}') + print_rank_0(f'allowed {dtype} elements: {self.cuda_elements}') + + self.default_group_size = default_group_size + self.comm_overlap = overlap + + def private_truncate(self, param: nn.Parameter) -> int: + return to_divide(param.numel(), self.default_group_size) + + def public_trucate(self, length: int) -> int: + return to_divide(length, self.default_group_size) + + def search(self) -> Tuple: + min_chunk_size, max_chunk_size, search_interval = find_search_range(self.meta_module) + # get multi-used parameters + private_params = get_multi_used_params(self.meta_module) + # subtract the footprint of fused parameters + for param in private_params: + self.cuda_elements -= param.numel() + if self.cuda_elements < 0: + raise RuntimeError('optimal search: no enough space for fused parameters') + + # initialize public params in the called order + public_params = list() + public_param_set = set() + name_to_param = {name: param for name, param in self.meta_module.named_parameters()} + for name_set in self.param_per_step: + for name in name_set: + param = name_to_param.get(name) + if param in private_params or param in public_param_set: + continue + public_params.append(param) + public_param_set.add(param) + del name_to_param + del 
public_param_set + + # collect the number of elements of each parameter + public_numels = [p.numel() for p in public_params] + # calculate the sumary of all parameters + total_size = sum(public_numels) + # collect the name for each public parameters + public_param_names = [self.param_to_name[p] for p in public_params] + + if total_size <= min_chunk_size: + public_block_size = total_size + n_blocks = 1 + waste_size = 0 + else: + public_block_size, n_blocks, waste_size = find_optimal_chunk_size( + # pre-commit: do not rearrange + param_per_step=self.param_per_step, + param_names=public_param_names, + param_numels=public_numels, + cuda_elements=self.cuda_elements, + overlap=self.comm_overlap, + min_range=min_chunk_size, + max_range=max_chunk_size, + interval=search_interval) + # truncate the size of public blocks + public_block_size = self.public_trucate(public_block_size) + if self.cuda_elements < n_blocks * public_block_size: + raise RuntimeError('no enough space for unfused parameters') + + if self.verbose: + if total_size == 0: + waste_percentage = 0 + else: + waste_percentage = 100 * waste_size / total_size + print_rank_0( + f'Optimal search result: chunk size = {public_block_size}, waste percentage = {waste_percentage: .1f} %' + ) + + # initialize the mapping from parameters to chunks + param_to_chunk_id = dict() + chunk_id = 0 + # deal with private parameters + for p in private_params: + param_to_chunk_id[p] = chunk_id + chunk_id += 1 + # record the upper bound + private_id_upperbound = chunk_id + # deal with public parameters + last_left = 0 + for p in public_params: + p_size = p.numel() + + if last_left < p_size: + last_left = public_block_size + chunk_id += 1 + + assert last_left >= p_size + + last_left -= p_size + param_to_chunk_id[p] = chunk_id + + # initailize public groups + public_number_chunks = chunk_id - private_id_upperbound + public_groups = [[] for _ in range(public_number_chunks)] + for p in public_params: + public_chunk_id = param_to_chunk_id[p] - private_id_upperbound - 1 + public_groups[public_chunk_id].append(p) + + if total_size == 0: + n_blocks = 0 + + self.public_block_size = public_block_size + self.public_block_number = n_blocks + + return (private_params, public_groups) + + def configure_rcache_size(self, chunk_plans: list, os_factor: int): + element_os = 4 + if self.unified_dtype == torch.float16: + element_pa = 2 + elif self.unified_dtype == torch.float: + element_pa = 4 + else: + raise NotImplementedError + + priority = rcache_prioirity_check(n=self.default_group_size, r_os=os_factor, e_p=element_pa, e_o=element_os) + if self.verbose: + print_rank_0(f'rCache Priority Check: {priority}') + + if not priority: + n_cache_blocks = max(4, math.ceil(self.max_checkpoint_size / self.public_block_size) + 1) + if self.comm_overlap: + n_cache_blocks += 2 + n_cache_blocks = min(n_cache_blocks, self.public_block_number) + self.cuda_elements -= n_cache_blocks * self.public_block_size + if self.verbose: + print_rank_0(f'n_cache_block is set to {n_cache_blocks}') + else: + self.cuda_elements -= self.public_block_number * self.public_block_size + + def try_move_chunk_to_cuda(fused: bool): + for (i, plan) in enumerate(chunk_plans): + rcache_fused = plan.kwargs.get('rcache_fused', False) + if not fused and rcache_fused: + continue + elif fused and not rcache_fused: + break + param_os_size = os_factor * plan.chunk_size // self.default_group_size + if self.cuda_elements >= param_os_size: + plan.kwargs['shard_device'] = gpu_device() + self.cuda_elements -= param_os_size + else: + 
plan.kwargs['shard_device'] = torch.device('cpu') + plan.kwargs['cpu_pin_memory'] = True + if self.verbose: + print_rank_0(f"chunk {i}: shard device -> {plan.kwargs['shard_device']}") + + # check chunks that are not fused on rCache + try_move_chunk_to_cuda(False) + # check chunks that are fused on rCache + try_move_chunk_to_cuda(True) + + if not priority: + extra_blocks = math.floor(self.cuda_elements / self.public_block_size) + extra_blocks = min(extra_blocks, self.public_block_number - n_cache_blocks) + self.cuda_elements -= extra_blocks * self.public_block_size + self.public_block_number = n_cache_blocks + extra_blocks + if self.verbose: + print_rank_0(f'n_extra_blocks is set to {extra_blocks}') + + return chunk_plans + + +def optimal_search( + # pre-commit: do not rearrange + m: nn.Module, + group_size: int, + unified_dtype: torch.dtype = torch.float, + optimizer_type: str = 'Adam', + overlap: bool = False, + verbose: bool = False, + inp=None, + step_fn=None) -> SearchResult: + + search_class = SearchOptimal( + # pre-commit: do not rearrange + module=m, + default_group_size=group_size, + dtype=unified_dtype, + verbose=verbose, + overlap=overlap, + inp=inp, + step_fn=step_fn) + + private_group, public_groups = search_class.search() + chunk_plans = search_class.generate_chunk_plans(private_group, public_groups) + + if unified_dtype == torch.float16: + master_weight_factor = 2 + elif unified_dtype == torch.float: + master_weight_factor = 1 + else: + raise NotImplementedError + + if optimizer_type == 'SGD': + extra_sotre_factor = 1 + elif optimizer_type == 'Adam': + extra_sotre_factor = 2 + else: + raise NotImplementedError + + os_factor = 1 + (1 + extra_sotre_factor) * master_weight_factor + chunk_plans = search_class.configure_rcache_size(chunk_plans, os_factor) + chunk_group = search_class.allocate_chunk_group(chunk_plans) + + return SearchResult(chunk_group=chunk_group, + chunk_plans=chunk_plans, + param_called_per_step=search_class.param_per_step) diff --git a/colossalai/elixir/search/result.py b/colossalai/elixir/search/result.py new file mode 100644 index 000000000..793e7ac71 --- /dev/null +++ b/colossalai/elixir/search/result.py @@ -0,0 +1,32 @@ +from typing import Dict, List, NamedTuple + +import torch + +from colossalai.elixir.chunk import ChunkGroup + + +class ChunkPlan(NamedTuple): + """ChunkPlan is a type of configuration used to instruct the initialization of a chunk. + + args: + name_list: contains the names of parameters that should be pushed into this chunk + chunk_size: the size of this chunk + chunk_dtype: the dtype of this chunk + kwargs: a dictionary used in __init__ function of Chunk + """ + name_list: List[str] + chunk_size: int + chunk_dtype: torch.dtype + kwargs: Dict + + +class SearchResult(object): + + def __init__(self, + chunk_group: ChunkGroup, + chunk_plans: List[ChunkPlan], + param_called_per_step: List[List[str]] = None) -> None: + super().__init__() + self.chunk_group = chunk_group + self.param_chunk_plans = chunk_plans + self.param_called_per_step = param_called_per_step diff --git a/colossalai/elixir/search/simple.py b/colossalai/elixir/search/simple.py new file mode 100644 index 000000000..41b40f109 --- /dev/null +++ b/colossalai/elixir/search/simple.py @@ -0,0 +1,120 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn + +from .base import SearchBase +from .result import SearchResult +from .utils import get_multi_used_params, to_divide + + +class SearchSimple(SearchBase): + """The simple search algorithm used for unit tests. 
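+    Public parameters are split evenly by count into `split_number` groups instead of
+    searching for a chunk size.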
+ Developers can specify the number of chunks used. + + args: + module: the module to be searched + default_group_size: the default group size of communications + dtype: the data type of the parameters + prefetch: whether to prefetch the chunks + verbose: whether to print the search process + inp: the example input of the model + step_fn: the example step function of training + """ + + def __init__(self, + module: nn.Module, + default_group_size: int, + dtype: torch.dtype = torch.float, + prefetch: bool = False, + verbose: bool = False, + inp=None, + step_fn=None) -> None: + + super().__init__(module, dtype, prefetch, verbose, inp, step_fn) + self.default_group_size = default_group_size + + def private_truncate(self, param: nn.Parameter) -> int: + return to_divide(param.numel(), self.default_group_size) + + def public_trucate(self, length: int) -> int: + return to_divide(length, self.default_group_size) + + def search(self, split_number: int, allocate_factor: float) -> Tuple: + # get multi-used parameters + private_params = get_multi_used_params(self.meta_module) + # get parameters used only one time + public_params = [p for p in self.meta_module.parameters() if p not in private_params] + + # calculate the size of each group + len_public = len(public_params) + split_number = min(len_public, split_number) + # allocate a list for groups + public_groups = list() + if split_number > 0: + average_size = len_public // split_number + left_size = len_public % split_number + + # set the size of each segment + pack_size_list = [average_size] * split_number + for i in range(split_number): + if left_size > 0: + pack_size_list[i] += 1 + left_size -= 1 + + # split public parameters + for i in range(split_number): + p_list = list() + for _ in range(pack_size_list[i]): + p = public_params.pop(0) + p_list.append(p) + public_groups.append(p_list) + assert len(public_params) == 0 + + # calculate the maximum summarized size + max_sum_size = 0 + for p_list in public_groups: + sum_size = sum([p.numel() for p in p_list]) + max_sum_size = max(max_sum_size, sum_size) + else: + max_sum_size = 0 + + self.public_block_size = max_sum_size + self.public_block_number = math.ceil(split_number * allocate_factor) + + return (private_params, public_groups) + + +def simple_search(m: nn.Module, + group_size: int, + split_number: int = 10, + allocate_factor: float = 0.6, + unified_dtype: torch.dtype = torch.float, + shard_device: torch.device = torch.device('cpu'), + prefetch: bool = False, + verbose: bool = False, + inp=None, + step_fn=None) -> SearchResult: + + search_class = SearchSimple( + # pre-commit: do not rearrange + module=m, + default_group_size=group_size, + dtype=unified_dtype, + prefetch=prefetch, + verbose=verbose, + inp=inp, + step_fn=step_fn) + + private_group, public_groups = search_class.search(split_number, allocate_factor) + chunk_plans = search_class.generate_chunk_plans(private_group, public_groups) + # assign shard device + for plan in chunk_plans: + plan.kwargs['shard_device'] = shard_device + + chunk_group = search_class.allocate_chunk_group(chunk_plans) + + return SearchResult(chunk_group=chunk_group, + chunk_plans=chunk_plans, + param_called_per_step=search_class.param_per_step) diff --git a/colossalai/elixir/search/simulator.py b/colossalai/elixir/search/simulator.py new file mode 100644 index 000000000..21cc1b96f --- /dev/null +++ b/colossalai/elixir/search/simulator.py @@ -0,0 +1,112 @@ +import math + +from .utils import to_divide + + +def calc_move_times(param_per_step: list, param_to_chunk: 
dict, n_blocks: int): + from colossalai.elixir.simulator import move_count + chunk_per_step = list() + + for param_set in param_per_step: + id_set = set() + for name in param_set: + # continue if the parameter is ignored + if name not in param_to_chunk: + continue + id_set.add(param_to_chunk[name]) + if len(id_set) > 0: + chunk_per_step.append(list(id_set)) + + return move_count(chunk_per_step, n_blocks) + + +def find_optimal_chunk_size( + # pre-commit: do not rearrange + param_per_step: list, + param_names: list, + param_numels: list, + cuda_elements: int, + overlap: bool, + min_range: int, + max_range: int, + interval: int): + + max_numel = 0 + for numel in param_numels: + max_numel = max(max_numel, numel) + test_size = to_divide(max(max_numel, min_range), interval) + # floor rounding + cuda_elements = to_divide(cuda_elements - interval + 1, interval) + max_range = min(max_range, cuda_elements) + + min_move_elements = float('+inf') + best_size = test_size + best_number_blocks = 0 + best_waste = 0 + + def dispatch_chunks(param_to_chunk: dict, block_size: int) -> int: + chunk_id = 0 + acc = 0 + left = 0 + for (name, numel) in zip(param_names, param_numels): + if numel > left: + acc += left + chunk_id += 1 + left = block_size + left -= numel + param_to_chunk[name] = chunk_id + return (chunk_id, left + acc) + + assert test_size <= max_range, 'max_numel or min_range is larger than max_range or cuda capacity' + while test_size <= max_range: + # calculate the number of blocks + number_blocks = int(cuda_elements // test_size) + # if prefetch is enabled, we pretend that two chunks are reserved + if overlap: + number_blocks -= 2 + if number_blocks <= 0: + continue + # initialize the chunk id for each parameter + param_to_chunk = dict() + number_chunks, current_waste = dispatch_chunks(param_to_chunk, test_size) + number_blocks = min(number_blocks, number_chunks) + # calculate the minimum number of movements + move_times = calc_move_times(param_per_step, param_to_chunk, number_blocks) + + current_move_elements = move_times * test_size + # print("test", test_size, current_move_elements) + if current_move_elements < min_move_elements: + min_move_elements = current_move_elements + best_size = test_size + best_number_blocks = number_blocks + best_waste = current_waste + + test_size += interval + + if min_move_elements == float('inf'): + raise RuntimeError('optimal search: can not find a valid solution') + + return best_size, best_number_blocks, best_waste + + +def bandwidth_c2g(n: int): + return 16.3 * n + 8.7 + + +def bandwidth_g2c(n: int): + return 15.8 * n + 2.3 + + +def velocity_gpu(n: int): + return 50 * n + + +def velocity_cpu(n: int): + return 1.66 * math.log(n) + 5.15 + + +def rcache_prioirity_check(n: int, r_os: int, e_p: int, e_o: int): + In = e_p / bandwidth_c2g(n) + e_p / bandwidth_g2c(n) + Jn = (n / r_os) * (e_o / bandwidth_c2g(n) + In + e_p / bandwidth_g2c(n) + 1.0 / velocity_cpu(n) - + 1.0 / velocity_gpu(n)) + return In > Jn diff --git a/colossalai/elixir/search/utils.py b/colossalai/elixir/search/utils.py new file mode 100644 index 000000000..1e2bc65b5 --- /dev/null +++ b/colossalai/elixir/search/utils.py @@ -0,0 +1,108 @@ +from typing import List, Set + +import torch +import torch.nn as nn + + +def to_divide(a: int, b: int): + return a + (-a % b) + + +def to_meta_tensor(t: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor: + # only float tensors need dtype change + if t.is_floating_point() and dtype is not None: + meta_dtype = dtype + else: + meta_dtype = t.dtype + # we shall not 
use t.data.to here, since t might be a fake tensor + meta_t = torch.empty(t.size(), dtype=meta_dtype, device='meta') + # pack it if t is a parameter + # we should filter parameters with no grad + if isinstance(t, nn.Parameter) and t.requires_grad: + meta_t = nn.Parameter(meta_t) + return meta_t + + +def get_multi_used_params(m: nn.Module) -> Set[torch.Tensor]: + multi_used_set = set() + visit = dict() + for module in m.modules(): + for param in module.parameters(recurse=False): + if param not in visit: + visit[param] = True + else: + multi_used_set.add(param) + return multi_used_set + + +def find_minimum_waste_size(numel_group_list: List[List[int]], min_range: int, max_range: int, interval: int): + + max_per_group = list() + for n_list in numel_group_list: + max_per_group.append(max(n_list)) + max_numel = max(max_per_group) + + test_size = to_divide(max(max_numel, min_range), interval) + best_size = test_size + min_waste = float('+inf') + + def calc_waste(numel_list: List[int], block_size: int): + acc = 0 + left = 0 + for s in numel_list: + if s > left: + acc += left + left = block_size + left -= s + return left + acc + + assert test_size <= max_range, 'max_numel or min_range is larger than max_range' + while test_size <= max_range: + current_waste = 0 + for n_list in numel_group_list: + current_waste += calc_waste(n_list, test_size) + if current_waste < min_waste: + best_size = test_size + min_waste = current_waste + test_size += interval + + return best_size, min_waste + + +def find_search_range(m: nn.Module): + + ele_size = 0 + for param in m.parameters(): + if ele_size == 0: + ele_size = param.element_size() + else: + assert param.element_size() == ele_size + + def next_2_pow(x: int): + y = 1 + while y < x: + y <<= 1 + return y + + private_params = get_multi_used_params(m) + params = [p for p in m.parameters() if p not in private_params] + memo_list = [p.numel() * p.element_size() for p in params] + max_memo = max(memo_list) + # minimum chunk memory is 32 MiB + default_min = 32 * 1024**2 + while default_min < max_memo: + default_min <<= 1 + default_max = int(3 * default_min) + # * 2 for forward and backward + length = 2 * next_2_pow(len(params)) + default_iter_times = 16 * 1024**2 + default_search_times = default_iter_times // length + + gap = default_max - default_min + # minimum search interval is 1024 + if default_search_times > (gap // 1024): + interval = 1024 + else: + interval = gap // default_search_times + + return (default_min // ele_size, default_max // ele_size, interval // ele_size) diff --git a/colossalai/elixir/simulator.cpp b/colossalai/elixir/simulator.cpp new file mode 100644 index 000000000..973cbf6bb --- /dev/null +++ b/colossalai/elixir/simulator.cpp @@ -0,0 +1,61 @@ +#include +#include +#include + +int move_count_impl(std::vector> &steps, int n_blocks) { + int n_steps = steps.size(); + std::unordered_map my_map; + std::map, int> next_map; + + for (auto i = n_steps - 1; ~i; --i) { + auto ids = steps.at(i); + for (auto c_id : ids) { + auto iter = my_map.find(c_id); + auto nxt = n_steps; + if (iter != my_map.end()) nxt = iter->second; + next_map.emplace(std::make_pair(i, c_id), nxt); + my_map[c_id] = i; + } + } + // reuse this map + for (auto iter : my_map) my_map[iter.first] = 0; + + int cache_size = 0, count = 0; + std::priority_queue> cache; + for (auto i = 0; i < n_steps; ++i) { + auto ids = steps.at(i); + assert(n_blocks >= ids.size()); + + int not_in = 0; + for (auto c_id : ids) + if (my_map[c_id] == 0) ++not_in; + + while (cache_size + not_in > n_blocks) { + 
std::pair q_top = cache.top(); + cache.pop(); + assert(q_top.first > i); + assert(my_map[q_top.second] == 1); + my_map[q_top.second] = 0; + --cache_size; + ++count; + } + + for (auto c_id : ids) { + auto iter = next_map.find(std::make_pair(i, c_id)); + cache.push(std::make_pair(iter->second, c_id)); + if (my_map[c_id] == 0) { + my_map[c_id] = 1; + ++cache_size; + } + } + } + return (count + cache_size) << 1; +} + +int move_count(std::vector> &steps, int n_blocks) { + return move_count_impl(steps, n_blocks); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("move_count", &move_count, "Count the number of moves."); +} diff --git a/colossalai/elixir/tensor.py b/colossalai/elixir/tensor.py new file mode 100644 index 000000000..40bac0307 --- /dev/null +++ b/colossalai/elixir/tensor.py @@ -0,0 +1,112 @@ +import torch +from torch.utils._pytree import tree_map + +debug_flag = False + +white_list = {torch.Tensor.__getitem__} + +fake_allowed = { + # pre-commit: don't move + torch.Tensor.numel, + torch.Tensor.size, + torch.Tensor.stride, + torch.Tensor.storage_offset, + torch.Tensor.is_floating_point +} + +inpalce_mapping = { + torch.Tensor.add_: torch.Tensor.add, + torch.Tensor.sub_: torch.Tensor.sub, + torch.Tensor.mul_: torch.Tensor.mul, + torch.Tensor.div_: torch.Tensor.div +} + + +def is_no_hook_op(func) -> bool: + if func.__name__.startswith('__') and func not in white_list: + return True + if func in fake_allowed: + return True + return False + + +class FakeTensor(torch.Tensor): + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_wrapper_subclass(cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad) + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise NotImplementedError + + +def to_outplace_tensor(t): + if isinstance(t, OutplaceTensor): + return t + assert type(t) is torch.Tensor, f'type: {type(t)}' + t.__class__ = OutplaceTensor + return t + + +class OutplaceTensor(torch.Tensor): + # TODO: rename this class + def __new__(cls, tensor): + rt = tensor.as_subclass(cls) + return rt + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + + if kwargs is None: + kwargs = {} + # in order to trigger pre-op hook in the forward of checkpoint module + # we have to capture the `backward` function + # and make sure that it does not in `torch._C.DisableTorchFunction()` context + if func is torch.Tensor.backward: + assert len(args) == 1 # only has 1 paramter + backward_tensor = torch.Tensor(args[0]) + tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} + return backward_tensor.backward(**tensor_kwargs) + # return a tensor if the output needs to be a torch.Tensor (such as Tensor.data.__get__) + if is_no_hook_op(func): + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + return ret + + # debug inplace operations + if debug_flag: + if func.__name__.endswith('_'): + print(f'found inplace operation {func.__name__}') + + # replace the in-place function + if func in inpalce_mapping: + func = inpalce_mapping[func] + # set the 'inplace' kwargs to False + if 'inplace' in kwargs: + kwargs['inplace'] = False + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + if not isinstance(ret, tuple): + ret = (ret,) + + def convert(t): + if isinstance(t, torch.Tensor): + t = to_outplace_tensor(t) + return 
t + + ret = tree_map(convert, ret) + + if len(ret) == 1: + ret = ret[0] + + return ret diff --git a/colossalai/elixir/tracer/__init__.py b/colossalai/elixir/tracer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/elixir/tracer/memory_tracer/__init__.py b/colossalai/elixir/tracer/memory_tracer/__init__.py new file mode 100644 index 000000000..8ab6be844 --- /dev/null +++ b/colossalai/elixir/tracer/memory_tracer/__init__.py @@ -0,0 +1,2 @@ +from .cuda_profiler import cuda_memory_profiling +from .memory_tensor import MTensor diff --git a/colossalai/elixir/tracer/memory_tracer/cuda_profiler.py b/colossalai/elixir/tracer/memory_tracer/cuda_profiler.py new file mode 100644 index 000000000..4a58374a4 --- /dev/null +++ b/colossalai/elixir/tracer/memory_tracer/cuda_profiler.py @@ -0,0 +1,77 @@ +import gc +from typing import Callable, Dict, Tuple, Union + +import torch +import torch.nn as nn +from torch.utils._pytree import tree_map + +from colossalai.elixir.tracer.utils import get_cuda_allocated, meta_copy, model_memory_figure +from colossalai.elixir.utils import print_rank_0 + +from .memory_tensor import MTensor + + +def grad_cleaner(grad): + empty_grad = torch.empty_like(grad.elem) + grad.elem = None + empty_grad.storage().resize_(0) + return empty_grad + + +def cuda_memory_profiling(model: nn.Module, inp: Dict, step_fn: Callable, dtype=torch.float): + assert isinstance(inp, dict), 'the example input should be a dictionary' + print_rank_0(f'You are profiling cuda memory with dtype `{dtype}`') + + def tensor_trans(t: torch.Tensor): + # set dtype for tensors + meta_dtype = dtype if t.is_floating_point() else t.dtype + meta_t = torch.empty_like(t.data, device='meta', dtype=meta_dtype) + # pack parameters + if isinstance(t, nn.Parameter): + meta_t = nn.Parameter(meta_t) + return meta_t + + # first, transform the model into one dtype + model = meta_copy(model, tensor_trans) + # get the memory firgure of the model + memo_dict = model_memory_figure(model) + # initialize a empty pool for parameters + pool = torch.zeros(memo_dict['param_max_numel'], device='cuda', dtype=dtype) + + def tensor_to_cuda(t): + if isinstance(t, nn.Parameter): + fake_data = pool[:t.numel()].view(t.shape) + return nn.Parameter(fake_data) + else: + fake_data = torch.zeros(t.shape, device='cuda', dtype=t.dtype) + return fake_data + + # make all parameters in CUDA and point to a same address + model = meta_copy(model, tensor_to_cuda) + # add hooks to clean gradients + for param in model.parameters(): + param.register_hook(grad_cleaner) + + def input_trans(t): + if isinstance(t, torch.Tensor): + cuda_dtype = dtype if t.is_floating_point() else t.dtype + cuda_t = t.data.clone() + cuda_t = cuda_t.to(dtype=cuda_dtype, device='cuda') + cuda_t.requires_grad = t.requires_grad + return MTensor(cuda_t) + return t + + inp = tree_map(input_trans, inp) + # reset all collected peak memory states + MTensor.reset_peak_memory() + before_cuda_alc = get_cuda_allocated() + + step_fn(model, inp) + + after_cuda_alc = MTensor.current_peak_memory() + activation_occ = after_cuda_alc - before_cuda_alc + + return dict(param_occ=memo_dict['param_occ'], + buffer_occ=memo_dict['buffer_occ'], + grad_occ=memo_dict['param_occ'], + activation_occ=activation_occ) diff --git a/colossalai/elixir/tracer/memory_tracer/memory_tensor.py b/colossalai/elixir/tracer/memory_tracer/memory_tensor.py new file mode 100644 index 000000000..6c9d4657f --- /dev/null +++ b/colossalai/elixir/tracer/memory_tracer/memory_tensor.py @@ -0,0 +1,99 @@ 
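+# MTensor wraps a real CUDA tensor and tracks the peak allocated CUDA memory observed while
+# torch ops run on it, so `cuda_memory_profiling` can attribute activation memory to a step.
+# A minimal usage sketch (editor's illustration, not part of the public API; it assumes a
+# CUDA device is available):
+#
+#     MTensor.reset_peak_memory()
+#     x = MTensor(torch.randn(64, 64, device='cuda'))
+#     y = x @ x                                # mm ops are routed through wrapped_mm_ops
+#     peak_bytes = MTensor.current_peak_memory()
+#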
+import contextlib +from typing import Iterator + +import torch +from torch.utils._pytree import tree_map + +from colossalai.elixir.tracer.utils import get_cuda_max_allocated + +from .op_cache import wrapped_mm_ops + +aten = torch.ops.aten + +mm_ops_list = [aten.mm.default, aten.addmm.default, aten.bmm.default, aten.addbmm.default, aten.baddbmm.default] + + +@contextlib.contextmanager +def no_dispatch() -> Iterator[None]: + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +class MTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + peak_memory_allocated: int = 0 + + @staticmethod + def reset_peak_memory(): + torch.cuda.reset_peak_memory_stats() + MTensor.peak_memory_allocated = 0 + + @staticmethod + def update_peak_memory(new_peak): + MTensor.peak_memory_allocated = max(MTensor.peak_memory_allocated, new_peak) + + @staticmethod + def current_peak_memory(): + cur_peak = get_cuda_max_allocated() + return max(MTensor.peak_memory_allocated, cur_peak) + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + # TODO: clone strides and storage aliasing + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad) + r.elem = elem + return r + + def __repr__(self): + return f'MTensor({self.elem})' + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + + def print_tensor(x): + if isinstance(x, torch.Tensor): + print(x.shape) + + # tree_map(print_tensor, args) + # tree_map(print_tensor, kwargs) + + def unwrap(x): + return x.elem if isinstance(x, MTensor) else x + + def wrap(x): + return MTensor(x) if isinstance(x, torch.Tensor) else x + + if func in mm_ops_list: + res, pre_max = wrapped_mm_ops(func, *tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + MTensor.update_peak_memory(pre_max) + else: + with no_dispatch(): + res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + + outs = normalize_tuple(res) + res = tree_map(wrap, outs) + + if len(res) == 1: + return res[0] + else: + return res diff --git a/colossalai/elixir/tracer/memory_tracer/op_cache.py b/colossalai/elixir/tracer/memory_tracer/op_cache.py new file mode 100644 index 000000000..33b791743 --- /dev/null +++ b/colossalai/elixir/tracer/memory_tracer/op_cache.py @@ -0,0 +1,135 @@ +import contextlib +from typing import Dict, Iterator, Tuple + +import torch + +from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated + +from .output_shape import addmm_output, bmm_output, check_cuda_mm, mm_output + + +@contextlib.contextmanager +def no_dispatch() -> Iterator[None]: + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def tensor_info(x: torch.Tensor): + # returns the meta information used for CUDA kernels + return (x.shape, x.stride(), x.layout, x.dtype) + + +def get_args_info(*args): + # returns a tuple contains the meta information of all inputs + # every argument is expected to be a tensor + info_list = [] + for x in args: + if isinstance(x, torch.Tensor): + info_list.append(tensor_info(x)) + return tuple(info_list) + + +class OpCache(object): + + def __init__(self, name: str) -> None: + super().__init__() + self.name = name + self.temp_memory: Dict[Tuple, int] = dict() + + def reset(self): + self.temp_memory.clear() + + def get(self, 
info): + if info in self.temp_memory: + return True, self.temp_memory[info] + else: + return False, None + + def add(self, info, memo): + self.temp_memory[info] = memo + + def print(self): + print(f'OpCache {self.name} information:') + for k, v in self.temp_memory.items(): + print(f'key: {k}\ntemp_memo:{v}') + + +aten = torch.ops.aten +addmm_cache = OpCache('aten.addmm.default') +bmm_cache = OpCache('aten.bmm.default') +mm_cache = OpCache('aten.mm.default') + +op_mapping = { + aten.mm.default: { + 'cache': mm_cache, + 'output': mm_output + }, + aten.addmm.default: { + 'cache': addmm_cache, + 'output': addmm_output + }, + aten.bmm.default: { + 'cache': bmm_cache, + 'output': bmm_output + } +} + + +def reset_caches(): + addmm_cache.reset() + bmm_cache.reset() + mm_cache.reset() + + +def fake_cuda_output(temp_memo, output_shape, dtype): + ret = torch.empty(output_shape, dtype=dtype, device='cuda') + sub = temp_memo - ret.numel() * ret.element_size() + + if sub > 0: + # allocate a temp empty tensor block to simulate the computation in kernels + temp = torch.empty(sub, dtype=torch.int8, device='cuda') + # release this tensor block + del temp + + return ret + + +def real_cuda_output(func, *args, **kwargs): + cur_alc = get_cuda_allocated() + # save the peak memory usage + pre_max_alc = get_cuda_max_allocated() + # the peak memory history is cleared here + torch.cuda.reset_peak_memory_stats() + + with no_dispatch(): + ret = func(*args, **kwargs) + + max_alc = get_cuda_max_allocated() + # calculate the temporary memory allocation + temp_memo = max_alc - cur_alc + + return ret, temp_memo, pre_max_alc + + +def wrapped_mm_ops(func, *args, **kwargs): + check_cuda_mm(*args) + + if func not in op_mapping: + raise RuntimeError(f'Unsupported mm operation {func}') + + args_info = get_args_info(*args) + cache = op_mapping[func]['cache'] + cached_flag, temp_memo = cache.get(args_info) + + if cached_flag: + output_fn = op_mapping[func]['output'] + out_shape = output_fn(*args) + ret = fake_cuda_output(temp_memo=temp_memo, output_shape=out_shape, dtype=args[0].dtype) + return ret, 0 + else: + ret, temp_memo, pre_max_alc = real_cuda_output(func, *args, **kwargs) + cache.add(args_info, temp_memo) + return ret, pre_max_alc diff --git a/colossalai/elixir/tracer/memory_tracer/output_shape.py b/colossalai/elixir/tracer/memory_tracer/output_shape.py new file mode 100644 index 000000000..7d32dccd4 --- /dev/null +++ b/colossalai/elixir/tracer/memory_tracer/output_shape.py @@ -0,0 +1,48 @@ +import torch + + +# Output functions come from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py +def check_cuda_mm(*args): + for x in args: + assert isinstance(x, torch.Tensor) + assert x.device.type == 'cuda' + + +def mm_output(a, b): + assert a.dim() == 2, 'a must be 2D' + assert b.dim() == 2, 'b must be 2D' + N, M1 = a.shape + M2, P = b.shape + assert M1 == M2, 'a and b must have same reduction dim' + return (N, P) + + +def addmm_output(bias, x, y): + return mm_output(x, y) + + +def common_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None): + assert batch1.dim() == 3, 'batch1 must be a 3D tensor' + assert batch2.dim() == 3, 'batch2 must be a 3D tensor' + + batch1_sizes = batch1.size() + batch2_sizes = batch2.size() + + bs = batch1_sizes[0] + contraction_size = batch1_sizes[2] + res_rows = batch1_sizes[1] + res_cols = batch2_sizes[2] + output_size = (bs, res_rows, res_cols) + + assert batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size + + if not is_bmm and self_baddbmm is not None: + assert 
self_baddbmm.dim() == 3, 'self must be a 3D tensor' + assert self_baddbmm.size() == output_size, \ + f'Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}' + + return output_size + + +def bmm_output(mat1, mat2): + return common_baddbmm_bmm(mat1, mat2, True) diff --git a/colossalai/elixir/tracer/ops.py b/colossalai/elixir/tracer/ops.py new file mode 100644 index 000000000..580e9f5e1 --- /dev/null +++ b/colossalai/elixir/tracer/ops.py @@ -0,0 +1,76 @@ +import torch +import torch.distributed as dist + +aten = torch.ops.aten + +__all__ = [ + 'TorchFactoryMethod', 'TorchOverrideableFactoryMethod', 'TorchNonOverrideableFactoryMethod', 'TensorPropertyMethod', + 'DistCommMethod', 'AliasATen', 'InplaceATen', 'MaybeInplaceAten', 'SameStorageAten' +] + +TorchOverrideableFactoryMethod = [ + 'empty', + 'eye', + 'full', + 'ones', + 'rand', + 'randn', + 'zeros', +] + +TorchNonOverrideableFactoryMethod = [ + 'arange', + 'finfo', + 'linspace', + 'logspace', + 'randint', + 'randperm', + 'tensor', +] + +TorchFactoryMethod = TorchOverrideableFactoryMethod + TorchNonOverrideableFactoryMethod + +TensorPropertyMethod = ['dtype', 'shape', 'device', 'requires_grad', 'grad', 'grad_fn', 'data'] + +DistCommMethod = [ + 'all_gather', + 'all_reduce', + 'all_to_all', + 'broadcast', + 'gather', + 'reduce', + 'reduce_scatter', + 'scatter', +] + +AliasATen = [ + aten.detach.default, + aten.detach_.default, + aten.t.default, + aten.transpose.int, + aten.view.default, + aten._unsafe_view.default, + aten._reshape_alias.default, +] + +InplaceATen = [ + aten.add_.Tensor, + aten.add_.Scalar, + aten.sub_.Tensor, + aten.sub_.Scalar, + aten.mul_.Tensor, + aten.mul_.Scalar, + aten.div_.Tensor, + aten.div_.Scalar, + aten.pow_.Tensor, + aten.pow_.Scalar, +] + +MaybeInplaceAten = [ + aten.diagonal.default, + aten.select.int, + aten.slice.Tensor, + aten.as_strided.default, +] + +SameStorageAten = AliasATen + InplaceATen + MaybeInplaceAten diff --git a/colossalai/elixir/tracer/param_tracer/__init__.py b/colossalai/elixir/tracer/param_tracer/__init__.py new file mode 100644 index 000000000..ae5024fa2 --- /dev/null +++ b/colossalai/elixir/tracer/param_tracer/__init__.py @@ -0,0 +1,3 @@ +from .fx_order import generate_fx_order +from .td_order import generate_td_order +from .tf_order import generate_tf_order diff --git a/colossalai/elixir/tracer/param_tracer/fx_order.py b/colossalai/elixir/tracer/param_tracer/fx_order.py new file mode 100644 index 000000000..563297987 --- /dev/null +++ b/colossalai/elixir/tracer/param_tracer/fx_order.py @@ -0,0 +1,62 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +from torch.fx import GraphModule, Node, symbolic_trace + +from colossalai.elixir.tracer.utils import meta_copy + + +def generate_fx_order(model: nn.Module) -> List[Dict[str, nn.Parameter]]: + fxf_name_mark = '_fxf_name' + fxf_param_mark = '_fxf_param' + + def tensor_trans(t): + meta_t = t.data.to('meta') + if isinstance(t, nn.Parameter): + meta_t = nn.Parameter(meta_t) + return meta_t + + meta_model = meta_copy(model, tensor_trans) + + # attach names for parameters + for name, param in meta_model.named_parameters(): + setattr(param, fxf_name_mark, name) + + fx_forward_order: List[Dict[str, nn.Parameter]] = list() + + gm: GraphModule = symbolic_trace(meta_model) + + for node in gm.graph.nodes: + if node.op in ('output', 'placeholder'): + continue + + step_dict = None + if node.op == 'get_attr': + maybe_param = getattr(gm, node.target) + # mark this node as a parameter + if 
maybe_param is not None: + setattr(node, fxf_param_mark, maybe_param) + continue + elif node.op == 'call_module': + target_module = gm.get_submodule(node.target) + step_dict = dict() + # collect all parameters in the module + for maybe_param in target_module.parameters(): + if maybe_param is not None: + param_name = getattr(maybe_param, fxf_name_mark) + step_dict[param_name] = maybe_param + elif node.op in ('call_function', 'call_method'): + step_dict = dict() + for pre in node.args: + if hasattr(pre, fxf_param_mark): + param = getattr(pre, fxf_param_mark) + param_name = getattr(param, fxf_name_mark) + step_dict[param_name] = param + else: + raise RuntimeError(f'Unsupported node op {node.op}!') + + if step_dict is not None and len(step_dict) > 0: + fx_forward_order.append(step_dict) + + return fx_forward_order diff --git a/colossalai/elixir/tracer/param_tracer/td_order.py b/colossalai/elixir/tracer/param_tracer/td_order.py new file mode 100644 index 000000000..2eae0e2b7 --- /dev/null +++ b/colossalai/elixir/tracer/param_tracer/td_order.py @@ -0,0 +1,156 @@ +import contextlib +import uuid +from typing import Callable, Dict, Iterator, List, Tuple, Union + +import torch +import torch.nn as nn +from torch.utils._pytree import tree_map + +from colossalai.elixir.tracer.ops import SameStorageAten +from colossalai.elixir.tracer.utils import meta_copy + + +@contextlib.contextmanager +def no_dispatch() -> Iterator[None]: + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def register_storage(x): + assert isinstance(x, nn.Parameter) + assert x.data_ptr() == 0 + + data_ptr = uuid.uuid1() + x.data_ptr = lambda: data_ptr + + +class ATensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + data_ptr_dict: Dict[int, Tuple[str, nn.Parameter]] = None + order_list: List[Dict] = None + + @staticmethod + def reset(): + ATensor.data_ptr_dict = dict() + ATensor.order_list = list() + + @staticmethod + def clear(): + ATensor.data_ptr_dict = None + ATensor.order_list = None + + @staticmethod + def add_data_ptr(name: str, param: nn.Parameter): + data_ptr = param.data_ptr() + if data_ptr not in ATensor.data_ptr_dict: + ATensor.data_ptr_dict[data_ptr] = (name, param) + else: + name_in, param_in = ATensor.data_ptr_dict[data_ptr] + if name != name_in or id(param) != id(param_in): + raise RuntimeError('Got two different parameters with the same data ptr') + + @staticmethod + def get_param(data_ptr: int): + if data_ptr in ATensor.data_ptr_dict: + return ATensor.data_ptr_dict.get(data_ptr) + else: + return None, None + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + # TODO: clone strides and storage aliasing + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad) + r.elem = elem + return r + + def __repr__(self): + return f'ATensor({self.elem})' + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + step_dict = dict() + + def record_param(x): + if isinstance(x, torch.Tensor): + name, param = ATensor.get_param(x.data_ptr()) + if name is not None: + step_dict[name] = param + + def debug_tensor(x): + if isinstance(x, torch.Tensor): + print(type(x), x.shape, x.data_ptr(), id(x)) + if x.grad_fn: + print(x.grad_fn) + + tree_map(record_param, args) + if len(step_dict) > 0: + 
ATensor.order_list.append(step_dict) + del step_dict + + def unwrap(x): + return x.elem if isinstance(x, ATensor) else x + + def wrap(x): + return ATensor(x) if isinstance(x, torch.Tensor) else x + + with no_dispatch(): + res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) + outs = normalize_tuple(res) + res = tree_map(wrap, outs) + + if func in SameStorageAten: + for x in res: + if isinstance(x, torch.Tensor): + x.data_ptr = args[0].data_ptr + + if len(res) == 1: + return res[0] + else: + return res + + +def generate_td_order(model: nn.Module, inp: Union[torch.Tensor, Tuple], step_fn: Callable): + ATensor.reset() + + def tensor_trans(t): + meta_t = ATensor(t.data.to('meta')) + if isinstance(t, nn.Parameter): + meta_t = nn.Parameter(meta_t) + return meta_t + + model = meta_copy(model, tensor_trans) + for name, param in model.named_parameters(): + register_storage(param) + ATensor.add_data_ptr(name, param) + + # convert all input data to meta_tensor + if not isinstance(inp, tuple): + inp = (inp,) + inp = tree_map(lambda t: ATensor(torch.empty_like(t, device='meta', requires_grad=t.requires_grad)), inp) + + step_fn(model, inp) + + ret = ATensor.order_list + ATensor.clear() + + return ret diff --git a/colossalai/elixir/tracer/param_tracer/tf_order.py b/colossalai/elixir/tracer/param_tracer/tf_order.py new file mode 100644 index 000000000..fcabcf235 --- /dev/null +++ b/colossalai/elixir/tracer/param_tracer/tf_order.py @@ -0,0 +1,178 @@ +from typing import Callable, Dict, List + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.utils._pytree import tree_map + +from colossalai.elixir.tensor import is_no_hook_op +from colossalai.elixir.tracer.utils import meta_copy +from colossalai.elixir.utils import no_dispatch, normalize_tuple + +torch_checkpoint_function = torch.utils.checkpoint.checkpoint + + +def attach_checkpoint(): + default_in_checkpoint = False + + def inner_checkpoint_function(function, *args, use_reentrant: bool = True, **kwargs): + nonlocal default_in_checkpoint + prev_in_checkpoint = default_in_checkpoint + default_in_checkpoint = True + # record the step where going into checkpoint + if not prev_in_checkpoint: + Record.record_in_checkpoint() + # use original torch checkpoint function + global torch_checkpoint_function + ret = torch_checkpoint_function(function, *args, use_reentrant=use_reentrant, **kwargs) + # roll back + default_in_checkpoint = prev_in_checkpoint + if not default_in_checkpoint: + Record.record_out_checkpoint() + return ret + + torch.utils.checkpoint.checkpoint = inner_checkpoint_function + + +def release_checkpoint(): + global torch_checkpoint_function + torch.utils.checkpoint.checkpoint = torch_checkpoint_function + + +class PostFwdPreBwd(torch.autograd.Function): + + @staticmethod + def forward(ctx, params, *args): + ctx.params = params + return args + + @staticmethod + def backward(ctx, *grads): + Record.record_params(ctx.params) + return (None, *grads) + + +class Record(nn.Parameter): + record_steps: List = None + checkpoint_info: List = None + in_checkpoint_step: int = -1 + + def __new__(cls, elem): + assert elem.device.type == 'meta', f'device type: {elem.device.type}' + r = torch.Tensor._make_subclass(cls, elem) + return r + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if is_no_hook_op(func): + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + return ret + + params = list() + + def append_param(x): + if isinstance(x, 
nn.Parameter): + assert isinstance(x, Record) + params.append(x) + + tree_map(append_param, args) + tree_map(append_param, kwargs) + Record.record_params(params) + + with torch._C.DisableTorchFunction(): + ret = normalize_tuple(func(*args, **kwargs)) + ret = PostFwdPreBwd.apply(params, *ret) + + def clone(t): + if isinstance(t, torch.Tensor): + t = t.clone() + return t + + ret = tree_map(clone, ret) + + if len(ret) == 1: + return ret[0] + else: + return ret + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + # notice: we should disable __torch_function__ here + # otherwise, unexpected operations are called inside meta kernels + with torch._C.DisableTorchFunction(): + with no_dispatch(): + return func(*args, **kwargs) + + @staticmethod + def reset(): + Record.record_steps = list() + Record.checkpoint_info = list() + Record.in_checkpoint_step = -1 + + @staticmethod + def record_in_checkpoint(): + assert Record.in_checkpoint_step == -1 + Record.in_checkpoint_step = len(Record.record_steps) + + @staticmethod + def record_out_checkpoint(): + assert Record.in_checkpoint_step != -1 + value_pair = (Record.in_checkpoint_step, len(Record.record_steps)) + Record.checkpoint_info.append(value_pair) + Record.in_checkpoint_step = -1 + + @staticmethod + def steps(): + ret = dict(params_per_step=Record.record_steps, checkpoint_info=Record.checkpoint_info) + Record.record_steps = None + Record.checkpoint_info = None + return ret + + @staticmethod + def record_params(params): + record_dict = {p.param_name for p in params} + Record.record_steps.append(record_dict) + + +def generate_tf_order(model: nn.Module, inp: Dict, step_fn: Callable, dtype: torch.dtype = torch.float): + assert isinstance(inp, dict), 'The example input should be a dictionary' + + Record.reset() + + def mtensor_trans(t: torch.Tensor): + if t.is_floating_point(): + meta_dtype = dtype + else: + meta_dtype = t.dtype + + meta_t = torch.empty_like(t, dtype=meta_dtype, device='meta') + if isinstance(t, nn.Parameter): + meta_t = Record(meta_t) + meta_t.requires_grad = t.requires_grad + return meta_t + + model = meta_copy(model, mtensor_trans) + for name, param in model.named_parameters(): + param.param_name = name + + def input_trans(t): + if isinstance(t, torch.Tensor): + if t.is_floating_point(): + meta_dtype = dtype + else: + meta_dtype = t.dtype + + meta_t = torch.empty_like(t, dtype=meta_dtype, device='meta', requires_grad=t.requires_grad) + return meta_t + return t + + inp = tree_map(input_trans, inp) + attach_checkpoint() + step_fn(model, inp) + release_checkpoint() + ret = Record.steps() + return ret diff --git a/colossalai/elixir/tracer/utils.py b/colossalai/elixir/tracer/utils.py new file mode 100644 index 000000000..d9bee8ff1 --- /dev/null +++ b/colossalai/elixir/tracer/utils.py @@ -0,0 +1,96 @@ +from collections import OrderedDict +from copy import copy +from typing import Optional, Set + +import torch +import torch.nn as nn + + +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules. + """ + if memo is None: + memo = set() + if module not in memo: + for name, submodule in module._modules.items(): + if submodule is None: + continue + submodule_prefix = prefix + ('.' 
if prefix else '') + name + for m in _get_dfs_module_list(submodule, memo, submodule_prefix): + yield m + + memo.add(module) + yield prefix, module + + +def _get_shallow_copy_model(model: nn.Module): + """Get a shallow copy of the given model. Each submodule is different from the original submodule. + But the new submodule and the old submodule share all attributes. + """ + old_to_new = dict() + for name, module in _get_dfs_module_list(model): + new_module = copy(module) + new_module._modules = OrderedDict() + for subname, submodule in module._modules.items(): + if submodule is None: + continue + setattr(new_module, subname, old_to_new[submodule]) + old_to_new[module] = new_module + return old_to_new[model] + + +def meta_copy(model: nn.Module, meta_fn: callable): + new_model = _get_shallow_copy_model(model) + old_parameters = dict() + old_buffers = dict() + + for (_, old_module), (_, new_module) in \ + zip(_get_dfs_module_list(model), _get_dfs_module_list(new_model)): + + new_module._parameters = OrderedDict() + for name, param in old_module._parameters.items(): + new_param = None + if param is not None: + param_id = id(param) + if param_id in old_parameters: + new_param = old_parameters.get(param_id) + else: + new_param = meta_fn(param) + old_parameters[param_id] = new_param + setattr(new_module, name, new_param) + + new_module._buffers = OrderedDict() + for name, buffer in old_module._buffers.items(): + new_buffer = None + if buffer is not None: + buffer_id = id(buffer) + if buffer_id in old_buffers: + new_buffer = old_buffers.get(buffer_id) + else: + new_buffer = meta_fn(buffer) + old_buffers[buffer_id] = new_buffer + new_module.register_buffer(name, new_buffer) + + return new_model + + +def get_cuda_allocated(): + return torch.cuda.memory_allocated() + + +def get_cuda_max_allocated(): + return torch.cuda.max_memory_allocated() + + +def model_memory_figure(model: nn.Module): + param_occ = 0 + max_numel = 0 + for name, param in model.named_parameters(): + param_occ += param.numel() * param.element_size() + max_numel = max(max_numel, param.numel()) + + buffer_occ = 0 + for name, buffer in model.named_buffers(): + buffer_occ += buffer.numel() * buffer.element_size() + + return dict(param_occ=param_occ, param_max_numel=max_numel, buffer_occ=buffer_occ) diff --git a/colossalai/elixir/utils.py b/colossalai/elixir/utils.py new file mode 100644 index 000000000..9abab65bf --- /dev/null +++ b/colossalai/elixir/utils.py @@ -0,0 +1,119 @@ +import contextlib +import os +import random + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn + + +@contextlib.contextmanager +def no_dispatch(): + guard = torch._C._DisableTorchDispatch() + try: + yield + finally: + del guard + + +def normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +def seed_all(seed, cuda_deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + + +def init_distributed(): + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) + world_size = int(os.environ['WORLD_SIZE']) + host = os.environ['MASTER_ADDR'] + port = int(os.environ['MASTER_PORT']) + + init_method = f'tcp://[{host}]:{port}' + 
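+    # NCCL backend with a TCP rendezvous built from MASTER_ADDR and MASTER_PORT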
dist.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=world_size) + + # set cuda device + if torch.cuda.is_available(): + # if local rank is not given, calculate automatically + torch.cuda.set_device(local_rank) + + seed_all(1024) + + +def print_rank_0(*args, **kwargs): + if dist.is_initialized(): + if dist.get_rank() == 0: + print(*args, **kwargs) + dist.barrier() + else: + print(*args, **kwargs) + + +def get_model_size(model: nn.Module): + total_numel = 0 + for module in model.modules(): + for p in module.parameters(recurse=False): + total_numel += p.numel() + return total_numel + + +def model_size_formatter(numel: int) -> str: + GB_SIZE = 10**9 + MB_SIZE = 10**6 + KB_SIZE = 10**3 + if numel >= GB_SIZE: + return f'{numel / GB_SIZE:.1f}B' + elif numel >= MB_SIZE: + return f'{numel / MB_SIZE:.1f}M' + elif numel >= KB_SIZE: + return f'{numel / KB_SIZE:.1f}K' + else: + return str(numel) + + +def calc_buffer_size(m: nn.Module, test_dtype: torch.dtype = torch.float): + max_sum_size = 0 + for module in m.modules(): + sum_p_size = 0 + for param in module.parameters(recurse=False): + assert param.dtype == test_dtype + sum_p_size += param.numel() + max_sum_size = max(max_sum_size, sum_p_size) + return max_sum_size + + +def calc_block_usage(): + snap_shot = torch.cuda.memory_snapshot() + + total_sum = 0 + active_sum = 0 + for info_dict in snap_shot: + blocks = info_dict.get('blocks') + for b in blocks: + size = b.get('size') + state = b.get('state') + total_sum += size + if state == 'active_allocated': + active_sum += size + + active_ratio = 1 + if total_sum > 0: + active_ratio = active_sum / total_sum + + print(f'memory snap shot: active ratio {active_ratio:.2f}') diff --git a/colossalai/elixir/wrapper/__init__.py b/colossalai/elixir/wrapper/__init__.py new file mode 100644 index 000000000..b4a592c58 --- /dev/null +++ b/colossalai/elixir/wrapper/__init__.py @@ -0,0 +1,2 @@ +from .module import ElixirModule +from .optimizer import ElixirOptimizer diff --git a/colossalai/elixir/wrapper/module.py b/colossalai/elixir/wrapper/module.py new file mode 100644 index 000000000..d6d0f2ec1 --- /dev/null +++ b/colossalai/elixir/wrapper/module.py @@ -0,0 +1,344 @@ +from collections import defaultdict +from copy import copy +from functools import partial +from typing import Any, Iterable, Mapping + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.utils._pytree import tree_map + +from colossalai.elixir.chunk import Chunk, ChunkFetcher, ChunkGroup, MemoryPool, TensorState +from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.hook import BufferStore, HookParam +from colossalai.elixir.search import SearchResult +from colossalai.elixir.tensor import OutplaceTensor +from colossalai.utils.model.experimental import LazyTensor + + +def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype): + param_data = param_data.to(gpu_device()) + optim_data = param_data.clone() if param_data.dtype == torch.float else param_data.float() + param_data = param_data.to(param_dtype) + return param_data, optim_data + + +class ElixirModule(nn.Module): + """Use this class to wrap your model when using Elixir. Don't know what should be written here. + But some docstring is needed here. 
+ + args: + module: training module + search_result: a SearchResult generated from a search algorithm in `elixir.search` + process_group: the communication group, ussually dp parallel group + prefetch: whether to use prefetch overlaping communication with computation + dtype: the dtype used in training + """ + + def __init__(self, + module: nn.Module, + search_result: SearchResult, + process_group: ProcessGroup, + prefetch: bool = False, + dtype: torch.dtype = torch.float, + reduce_always_fp32: bool = False, + output_fp32: bool = False, + use_fused_kernels: bool = False) -> None: + super().__init__() + + assert dtype in {torch.float, torch.float16} + + self._set_module_outplace(module) + self.module = module + self.dtype = dtype + self.use_amp = (dtype == torch.float16) + self.process_group = process_group + self.prefetch_flag = prefetch + self.reduce_always_fp32 = reduce_always_fp32 + self.output_fp32 = output_fp32 + self.use_fused_kernels = use_fused_kernels + + self.no_grad_state_dict = dict() + self.grad_state_dict = dict() + self.__init_chunk_group(search_result) + self.__init_chunk_fetcher(search_result, prefetch) + self.__init_buffer_storage() + + for name, param in module.named_parameters(): + if not param.requires_grad: + assert name in self.no_grad_state_dict + continue + assert name in self.grad_state_dict + param.register_hook(partial(self._gradient_handler, param=param)) + param.__class__ = HookParam + + def __init_chunk_group(self, sr: SearchResult): + torch.cuda.empty_cache() + state_dict = self.module.state_dict(keep_vars=True) + for name, tensor in state_dict.items(): + if isinstance(tensor, nn.Parameter): + assert tensor.is_floating_point(), 'the dtypes of parameters should be float dtypes' + # deal with parameters + if tensor.requires_grad: + self.grad_state_dict[name] = tensor + else: + self.no_grad_state_dict[name] = tensor + # polish no-grad parameters + tensor.data = tensor.data.to(dtype=self.dtype, device=gpu_device()) + else: + # deal with buffers + self._lazy_init_check(tensor) + to_dtype = self.dtype if tensor.is_floating_point() else tensor.dtype + tensor.data = tensor.data.to(dtype=to_dtype, device=gpu_device()) + + empty_mp = MemoryPool('cuda') + empty_mp.allocate() + + self.param_chunk_group = sr.chunk_group + self.optim_chunk_group = ChunkGroup(empty_mp) + self.param_to_optim = dict() + vis_set = set() + + for plan in sr.param_chunk_plans: + assert plan.chunk_dtype == self.dtype + # optimizer chunks should not be gathered + optim_kwargs = copy(plan.kwargs) + if 'rcache_fused' in optim_kwargs: + optim_kwargs['rcache_fused'] = False + + p_chunk = self.param_chunk_group.open_chunk(chunk_size=plan.chunk_size, + chunk_dtype=plan.chunk_dtype, + process_group=self.process_group, + chunk_config=plan.kwargs) + o_chunk = self.optim_chunk_group.open_chunk(chunk_size=plan.chunk_size, + chunk_dtype=torch.float, + process_group=self.process_group, + chunk_config=optim_kwargs) + + for name in plan.name_list: + param = self.grad_state_dict[name] + self._lazy_init_check(param) + param_data, optim_data = get_param_optim_data(param.data, self.dtype) + param.data = param_data + p_chunk.append_tensor(param) + o_chunk.append_tensor(optim_data) + self.param_to_optim[param] = optim_data + + vis_set.add(param) + + self.param_chunk_group.close_chunk(p_chunk) + self.optim_chunk_group.close_chunk(o_chunk) + p_chunk.init_pair(o_chunk) + + # sanity check: every parameter needed gradient has been initialized + for param in self.module.parameters(): + if param.requires_grad: + assert param 
in vis_set + + def __init_chunk_fetcher(self, sr: SearchResult, prefetch: bool): + scheduler = None + if prefetch: + assert sr.param_called_per_step is not None + + chunk_called_per_step = list() + for step in sr.param_called_per_step: + step_set = set() + for name in step: + param = self.grad_state_dict[name] + chunk = self.param_chunk_group.ten_to_chunk[param] + step_set.add(chunk) + chunk_called_per_step.append(step_set) + + scheduler = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step) + else: + scheduler = FIFOScheduler() + + self.fetcher = ChunkFetcher(scheduler, + self.param_chunk_group, + overlap=prefetch, + reduce_always_fp32=self.reduce_always_fp32) + self.fetcher.reset() + + def __init_buffer_storage(self): + buffer_size = 0 + for submodule in self.modules(): + sum_param_size = 0 + for param in submodule.parameters(recurse=False): + if not param.requires_grad or self.fetcher.is_in_fused(param): + continue + assert param.dtype == self.dtype + sum_param_size += param.numel() + buffer_size = max(buffer_size, sum_param_size) + self.buffer = BufferStore(buffer_size, self.dtype) + print('module buffer', self.buffer) + + def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter): + # create an empty tensor + fake_grad = self.buffer.empty_like(grad) + + with torch._C.DisableTorchFunction(): + chunk = self.fetcher.get_one_chunk(param) + assert self.fetcher.group.is_accessed(chunk) + if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD: + raise RuntimeError() + self.fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) + chunk.copy_tensor_to_chunk_slice(param, grad) + self.fetcher.reduce_chunk(chunk) + + return fake_grad + + def _lazy_init_check(self, tensor: torch.Tensor) -> None: + if isinstance(tensor, LazyTensor): + tensor.materialize() + + def _set_module_outplace(self, m: nn.Module): + # set inplace to False for all modules + for module in m.modules(): + if hasattr(module, 'inplace'): + module.inplace = False + + def _deattach_fetcher(self): + self.fetcher.clear() + HookParam.release_fetcher() + HookParam.disable_fused_kernel() + + def _release_for_inference(self): + torch.cuda.synchronize() + + scheduler = self.fetcher.scheduler + param_group = self.param_chunk_group + while True: + maybe_chunk = scheduler.top() + if maybe_chunk is None: + break + scheduler.remove(maybe_chunk) + param_group.release_chunk(maybe_chunk) + self._deattach_fetcher() + + def forward(self, *args, **kwargs): + if torch.is_grad_enabled(): + inference_mode = False + else: + inference_mode = True + + # reset the fetcher in this step + self.fetcher.reset() + HookParam.attach_fetcher(self.fetcher, self.buffer) + if self.use_fused_kernels: + HookParam.enable_fused_kernel() + + def to_outplace_tensor(t): + if isinstance(t, torch.Tensor): + if t.is_floating_point(): + t = t.to(self.dtype) + t = OutplaceTensor(t) + return t + + args = tree_map(to_outplace_tensor, args) + kwargs = tree_map(to_outplace_tensor, kwargs) + + outputs = self.module(*args, **kwargs) + if self.output_fp32: + outputs = outputs.float() + + if inference_mode: + self._release_for_inference() + + return outputs + + def backward(self, loss: torch.Tensor): + loss.backward() + # reset the fetcher for the next step + self._deattach_fetcher() + # reset all attributes + self.module.zero_grad(set_to_none=True) + + def state_dict(self, + destination=None, + prefix='', + keep_vars=False, + only_rank_0: bool = False, + from_param: bool = False): + assert keep_vars is False, 'state_dict can not keep 
variables in ElixirModule' + # make sure that the variables are kept, we shall detach them later + module_state_dict = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=True) + + tensor_to_names = defaultdict(list) + for name, tensor in module_state_dict.items(): + if isinstance(tensor, nn.Parameter) and tensor.requires_grad: + used_tensor = self.grad_state_dict[name] + if not from_param: + used_tensor = self.param_to_optim.get(used_tensor) + tensor_to_names[used_tensor].append(name) + else: + module_state_dict[name] = tensor.detach() + + def update_state_dict(chunks: Iterable[Chunk]): + for c in chunks: + for op, cp in zip(c.get_tensors(), c.get_cpu_copy(only_rank_0)): + for name in tensor_to_names.get(op): + module_state_dict[name] = cp + + if from_param: + used_group = self.param_chunk_group + else: + used_group = self.optim_chunk_group + + update_state_dict(used_group.fused_chunks) + update_state_dict(used_group.float_chunks) + + return module_state_dict + + def load_state_dict(self, state_dict: Mapping[str, Any], only_rank_0: bool = False): + load_flag = not only_rank_0 or dist.get_rank() == 0 + if not load_flag: + # only rank 0 loads the state dict + assert state_dict is None + + if only_rank_0: + # broadcast the length of the state dict + state_length = len(state_dict) if load_flag else None + comm_list = [state_length] + dist.broadcast_object_list(comm_list) + state_length = comm_list[0] + # broadcast the keys of the state dict + state_keys = state_dict.keys() if load_flag else [None] * state_length + dist.broadcast_object_list(state_keys) + # update the state dict + if not load_flag: + state_dict = {k: None for k in state_keys} + + # init the mapping from optim tensor to load tensor + optim_to_load = dict() + for name, maybe_tensor in state_dict.items(): + if name in self.no_grad_state_dict: + no_grad_param = self.no_grad_state_dict.get(name) + if load_flag: + no_grad_param.copy_(maybe_tensor) + if only_rank_0: + dist.broadcast(no_grad_param, src=0) + elif name in self.grad_state_dict: + grad_param = self.grad_state_dict.get(name) + optim_tensor = self.param_to_optim.get(grad_param) + optim_to_load[optim_tensor] = maybe_tensor + + def use_state_dict(chunks: Iterable[Chunk]): + for c in chunks: + load_tensor_list = list() + has_load = False + for chunk_tensor in c.tensors_info.keys(): + if chunk_tensor in optim_to_load: + has_load = True + load_tensor_list.append(optim_to_load[chunk_tensor]) + else: + load_tensor_list.append(None) + if has_load: + c.load_tensors(load_tensor_list, only_rank_0) + c.paired_chunk.optim_update() + + use_state_dict(self.optim_chunk_group.fused_chunks) + use_state_dict(self.optim_chunk_group.float_chunks) + + return diff --git a/colossalai/elixir/wrapper/optimizer.py b/colossalai/elixir/wrapper/optimizer.py new file mode 100644 index 000000000..8ddc0aa55 --- /dev/null +++ b/colossalai/elixir/wrapper/optimizer.py @@ -0,0 +1,265 @@ +import math +from collections import defaultdict +from enum import Enum +from typing import Dict, Set, Tuple + +import torch +import torch.distributed as dist +from torch.nn import Parameter + +import colossalai.nn.optimizer as colo_optim +from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler, ConstantGradScaler, DynamicGradScaler +from colossalai.elixir.chunk import Chunk +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.hook.storage import BufferStore +from colossalai.logging import get_dist_logger + +from .module import ElixirModule + +_AVAIL_OPTIM_LIST = 
{colo_optim.FusedAdam, colo_optim.CPUAdam, colo_optim.HybridAdam} + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + + +class ElixirOptimizer(colo_optim.ColossalaiOptimizer): + """A wrapper for optimizers. Users should notice that one specific ElixirOptimizer is strictly + corresponding to one ElixirModule. Currently only a group of optimizers are supported in ElixirOptimizer. + The reason is that ElixirOptimizer only support element-wise optimizers now. + We may enlarge the group of supported optimizers later. + + Args: + optim: The torch optimizer instance. + module: The nn.Module instance wrapped as an ElixirModule. + """ + + def __init__(self, + module: ElixirModule, + optimizer: torch.optim.Optimizer, + initial_scale: float = 32768, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**24, + max_norm: float = 0.0, + norm_type: float = 2.0, + init_step=False): + + super().__init__(optimizer) + assert isinstance(module, ElixirModule) + self.scaled_optimizer = False + if type(optimizer) in _AVAIL_OPTIM_LIST: + self.scaled_optimizer = True + + self.module = module + self.param_chunk_group = module.param_chunk_group + self.optim_chunk_group = module.optim_chunk_group + + self.optim_state = OptimState.UNSCALED + self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() + self.param_to_optim_chunk: Dict[Parameter, Chunk] = dict() + self.param_chunk_set: Set[Chunk] = self.param_chunk_group.fused_chunks.union( + self.param_chunk_group.float_chunks) + + self.clipping_flag = max_norm > 0.0 + self.max_norm = max_norm + if self.clipping_flag: + assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now' + + self.__init__optimizer() + + # Grad scaler + self.grad_scaler: BaseGradScaler = None + if module.use_amp: + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + else: + self.grad_scaler = ConstantGradScaler(1.0, verbose=False) + self._comm_buffer: torch.Tensor = torch.zeros(1, dtype=torch.float, device=gpu_device()) + self._logger = get_dist_logger() + + if init_step: + # allocate memory before training + self.__zero_step() + + def __zero_step(self): + torch.cuda.empty_cache() + + cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu') + buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer) + for _, zero_buffer in buffer_dict.items(): + zero_buffer.zeros() + + for group in self.param_groups: + for fake_param in group['params']: + optim_chunk = self.param_to_optim_chunk[fake_param] + begin, end = self.param_to_range[fake_param] + + fake_param.data = buffer_dict.get(optim_chunk.shard_device.type).empty_1d(end - begin) + fake_param.grad = fake_param.data + fake_param.data = optim_chunk.shard[begin:end] + + self.optim.step() + self.zero_grad() + self._update_fp16_params(update_flag=False) + + def _set_grad_ptr(self): + for group in self.param_groups: + for fake_param in group['params']: + optim_chunk = self.param_to_optim_chunk[fake_param] + begin, end = self.param_to_range[fake_param] + param_chunk = optim_chunk.paired_chunk + + fake_param.data = param_chunk.shard[begin:end] + fake_param.grad = fake_param.data + fake_param.data = optim_chunk.shard[begin:end] + + def _update_fp16_params(self, update_flag: bool = True): + none_tensor = 
torch.empty([0]) + for group in self.param_groups: + for fake_param in group['params']: + assert fake_param.grad is None + fake_param.data = none_tensor.to(fake_param.device) + + if update_flag: + for param_chunk in self.param_chunk_set: + param_chunk.optim_update() + + def _check_overflow(self) -> bool: + # calculate the overflow counter + overflow_counter = 0 + for param_chunk in self.param_chunk_set: + overflow_counter += int(param_chunk.overflow) + return overflow_counter > 0 + + def _clear_optim_states(self) -> None: + for param_chunk in self.param_chunk_set: + param_chunk.overflow = False + param_chunk.l2_norm = None + + def _calc_global_norm(self) -> float: + group_to_norm = defaultdict(float) + for param_chunk in self.param_chunk_set: + assert param_chunk.l2_norm is not None + assert not param_chunk.is_replica + + group_to_norm[param_chunk.torch_pg] += param_chunk.l2_norm + + norm_sqr = 0.0 + for group, part_norm in group_to_norm.items(): + self._comm_buffer.fill_(part_norm) + dist.all_reduce(self._comm_buffer, group=group) + norm_sqr += self._comm_buffer.item() + + global_norm = math.sqrt(norm_sqr) + return global_norm + + def _get_combined_scale(self): + loss_scale = 1 + + assert self.optim_state == OptimState.SCALED + loss_scale = self.loss_scale + self.optim_state = OptimState.UNSCALED + + combined_scale = loss_scale + if self.clipping_flag: + total_norm = self._calc_global_norm() + clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + if clip > 1: + combined_scale = clip * loss_scale + + if combined_scale == 1: + return -1 + else: + return combined_scale + + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + + def zero_grad(self, *args, **kwargs): + return self.optim.zero_grad(set_to_none=True) + + def step(self, *args, **kwargs): + self._set_grad_ptr() + found_inf = self._check_overflow() + + if found_inf: + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f'Found overflow. Skip step') + self._clear_optim_states() # clear chunk states used for optimizer update + self.zero_grad() # reset all gradients + self._update_fp16_params() + return + + # get combined scale. 
combined scale = loss scale * clipping norm + # so that gradient = gradient / combined scale + combined_scale = self._get_combined_scale() + self.grad_scaler.update(found_inf) + self._clear_optim_states() + + if not self.scaled_optimizer: + assert combined_scale == -1, 'You should use an optimizer in the available list:\n' \ + f'{_AVAIL_OPTIM_LIST}' + ret = self.optim.step(*args, **kwargs) + else: + ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) + + self.zero_grad() + self._update_fp16_params() + return ret + + def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): + raise NotImplementedError + + def backward(self, loss: torch.Tensor): + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.module.backward(loss) + + def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): + # This function is called except the last stage of pipeline parallel + # It receives the scaled grad from the previous rank + # No need to scale the grad again + # Need to unscale when optimizing + self.optim_state = OptimState.SCALED + self.module.backward_by_grad(tensor, grad) + + def __init__optimizer(self): + + def get_range_pair(local_chunk: Chunk, local_param: Parameter): + param_info = local_chunk.tensors_info[local_param] + begin = max(0, param_info.offset - local_chunk.shard_begin) + end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) + return begin, end + + for group in self.param_groups: + fake_params_list = list() + + for param in group['params']: + if not param.requires_grad: + continue + + param_chunk = self.module.fetcher.get_one_chunk(param) + range_pair = get_range_pair(param_chunk, param) + if range_pair[0] >= range_pair[1]: + continue + + grad_device = param_chunk.shard.device + fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) + self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk + self.param_to_range[fake_param] = range_pair + + fake_params_list.append(fake_param) + + group['params'] = fake_params_list diff --git a/requirements/requirements.txt b/requirements/requirements.txt index b34dc2e22..e1a9e691e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -10,3 +10,5 @@ contexttimer ninja torch>=1.11 safetensors +sortedcontainers +einops diff --git a/setup.py b/setup.py index 5d8f83121..80bdccca9 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from op_builder.utils import ( try: import torch - from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -30,7 +30,11 @@ BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 # a variable to store the op builder -ext_modules = [] +ext_modules = [ + CppExtension(name='colossalai.elixir.simulator', + sources=['colossalai/elixir/simulator.cpp'], + extra_compile_args=['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']) +] # we do not support windows currently if sys.platform == 'win32': diff --git a/tests/test_elixir/__init__.py b/tests/test_elixir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_elixir/compatibility_check.py b/tests/test_elixir/compatibility_check.py new file mode 100644 index 000000000..1932c831a --- /dev/null +++ b/tests/test_elixir/compatibility_check.py @@ -0,0 +1,89 @@ +import torch +import 
torch.distributed as dist + +import colossalai +from colossalai.elixir import ElixirModule, ElixirOptimizer +from colossalai.elixir.search import minimum_waste_search +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def check_elixir_compatibility(early_stop: bool = True): + """check gemini plugin over model zoo + + Args: + early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. + """ + passed_models = [] + failed_info = {} # (model_name, error) pair + + for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + # These models lead to CUDA error + if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', + 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext', + 'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'): + continue + + try: + print(name) + global_size = dist.get_world_size() + global_group = dist.GroupMember.WORLD + + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = { + k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + for k, v in data.items() + } + + sr = minimum_waste_search( + # pre-commit: do not rearrange + m=model, + group_size=global_size, + unified_dtype=torch.float16, + prefetch=False, + verbose=True) + + model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16) + optimizer = ElixirOptimizer(model, optimizer, initial_scale=32) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + optimizer.backward(loss) + optimizer.step() + passed_models.append(name) + + del model, optimizer, criterion, data, output, loss + except Exception as e: + failed_info[name] = e + if early_stop: + raise e + + torch.cuda.empty_cache() + + if dist.get_rank() == 0: + print(f'Passed models({len(passed_models)}): {passed_models}\n\n') + print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') + assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + + +def run_dist(rank, world_size, port, early_stop: bool = True): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_elixir_compatibility(early_stop=early_stop) + + +@rerun_if_address_is_in_use() +def exam_compatibility(early_stop: bool = True): + spawn(run_dist, 2, early_stop=early_stop) + + +if __name__ == '__main__': + exam_compatibility(early_stop=False) diff --git a/tests/test_elixir/test_chunk/__init__.py b/tests/test_elixir/test_chunk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_elixir/test_chunk/fetcher_utils.py b/tests/test_elixir/test_chunk/fetcher_utils.py new file mode 100644 index 000000000..22caedee6 --- /dev/null +++ b/tests/test_elixir/test_chunk/fetcher_utils.py @@ -0,0 +1,72 @@ +from functools import partial + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState +from colossalai.elixir.chunk.scheduler import FIFOScheduler +from colossalai.elixir.hook import BufferStore, HookParam +from colossalai.elixir.tensor import OutplaceTensor + + +def 
to_divide(a: int, b: int): + return a + (-a % b) + + +def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher): + empty_grad = torch.empty_like(grad) + empty_grad.storage().resize_(0) + + with torch._C.DisableTorchFunction(): + chunk = fetcher.get_one_chunk(param) + if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD: + raise RuntimeError() + fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE) + chunk.copy_tensor_to_chunk_slice(param, grad) + fetcher.reduce_chunk(chunk) + + return empty_grad + + +def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo): + pg_size = dist.get_world_size(process_group) + + private_list = list() + for param in model.parameters(): + block_size = to_divide(param.numel(), pg_size) + private_list.append(BlockRequire(block_size, param.dtype)) + + mp = MemoryPool('cuda') + mp.allocate(private_block_list=private_list) + cg = ChunkGroup(rcache=mp) + # allocate chunk group + fused_config = dict(rcache_fused=True) + for param in model.parameters(): + cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config) + # initialize chunk fetcher + scheduler = FIFOScheduler() + fetcher = ChunkFetcher(scheduler, cg) + buffer = BufferStore(0, torch.float32) + # register fetcher and gradient handler + HookParam.attach_fetcher(fetcher, buffer) + for param in model.parameters(): + param.register_hook(partial(grad_handler, param=param, fetcher=fetcher)) + param.__class__ = HookParam + # set inplace to False for all modules + for module in model.modules(): + if hasattr(module, 'inplace'): + module.inplace = False + + def transform_input(self_module, inputs): + fetcher.reset() + input_list = list() + for t in inputs: + if isinstance(t, torch.Tensor): + t = OutplaceTensor(t) + input_list.append(t) + return tuple(input_list) + + model.register_forward_pre_hook(transform_input) + + return model, cg diff --git a/tests/test_elixir/test_chunk/test_block.py b/tests/test_elixir/test_chunk/test_block.py new file mode 100644 index 000000000..bec8bf42b --- /dev/null +++ b/tests/test_elixir/test_chunk/test_block.py @@ -0,0 +1,63 @@ +import torch + +from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock +from colossalai.testing import run_on_environment_flag + + +@run_on_environment_flag('ELX') +def test_block(): + b = PublicBlock(123, torch.float16, 'cuda') + payload_b = b.payload + + assert payload_b.numel() == 123 + assert payload_b.dtype == torch.float16 + assert payload_b.device.type == 'cuda' + assert payload_b.numel() * payload_b.element_size() == b.memo_occ + + c = PrivateBlock(77, torch.float, 'cpu') + payload_c = c.payload + + assert payload_c.numel() == 77 + assert payload_c.dtype == torch.float + assert payload_c.device.type == 'cpu' + assert payload_c.numel() * payload_c.element_size() == c.memo_occ + + print('test_block: ok') + + +@run_on_environment_flag('ELX') +def test_memory_pool(): + mp = MemoryPool(device_type='cuda') + private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)] + mp.allocate(public_block_number=4, private_block_list=private_list) + + block0 = mp.get_public_block() + + assert block0 in mp.public_used_blocks + assert mp.public_used_cnt == 1 + assert mp.public_free_cnt == 3 + + block1 = mp.get_public_block() + + assert block1 in mp.public_used_blocks + assert mp.public_used_cnt == 2 + assert mp.public_free_cnt == 2 + + mp.free_public_block(block0) + mp.free_public_block(block1) + + assert block0 in 
mp.public_free_blocks + assert block1 in mp.public_free_blocks + assert mp.public_used_cnt == 0 + assert mp.public_free_cnt == 4 + + block0 = mp.get_private_block(5, torch.float) + assert block0.numel == 5 + assert block0.dtype == torch.float + + print('test_memory_pool: ok') + + +if __name__ == '__main__': + test_block() + test_memory_pool() diff --git a/tests/test_elixir/test_chunk/test_chunk.py b/tests/test_elixir/test_chunk/test_chunk.py new file mode 100644 index 000000000..f7fb9a0dd --- /dev/null +++ b/tests/test_elixir/test_chunk/test_chunk.py @@ -0,0 +1,155 @@ +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist + +from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState +from colossalai.elixir.utils import init_distributed +from colossalai.testing import run_on_environment_flag + + +def exam_chunk_functions(nproc, group): + a = torch.randn(2, 64, device='cuda') + copy_a = a.clone() + b = torch.randn(2, 2, 128, device='cuda') + copy_b = b.clone() + c = torch.randn(128, device='cuda') + copy_c = c.clone() + d = torch.randn(4, 32, device='cuda') + copy_d = d.clone() + + mp = MemoryPool('cuda') + mp.allocate(public_block_number=1) + + chunk = Chunk(mp, 1024, torch.float, group) + chunk.l2_norm_flag = True + assert chunk.chunk_size == 1024 + assert chunk.chunk_dtype == torch.float + assert chunk.shard_size == 1024 // nproc + + def check_tensors(): + assert torch.equal(a, copy_a) + assert torch.equal(b, copy_b) + assert torch.equal(c, copy_c) + assert torch.equal(d, copy_d) + + chunk.append_tensor(a) + chunk.append_tensor(b) + chunk.append_tensor(c) + chunk.append_tensor(d) + check_tensors() + + chunk.close_chunk() + assert chunk.is_replica is False + # check function: get_cpu_copy + cpu_copys = chunk.get_cpu_copy() + for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys): + assert t_cpu.device.type == 'cpu' + assert torch.equal(t_gpu.cpu(), t_cpu) + # check function: access_chunk + block = mp.get_public_block() + chunk.access_chunk(block) + assert chunk.is_replica + assert chunk.scatter_check + check_tensors() + # check function: release_chunk + chunk.optim_sync_flag = False + block = chunk.release_chunk() + assert block in mp.public_used_blocks + assert chunk.is_replica is False + assert chunk.optim_sync_flag is True + # check function: access_chunk after release_chunk + chunk.access_chunk(block) + check_tensors() + # check function: reduce_chunk + norm = block.payload.float().norm(2)**2 + chunk.reduce_chunk() + assert chunk.is_replica is False + assert chunk.tensor_state_cnter[TensorState.HOLD] == 4 + + test_norm = torch.Tensor([chunk.l2_norm]).cuda() + dist.all_reduce(test_norm) + assert torch.allclose(norm, test_norm) + + torch.cuda.synchronize() + print('chunk functions are ok') + + +def exam_chunk_states(nproc, group): + a = torch.randn(2, 64, device='cuda') + copy_a = a.clone() + b = torch.randn(2, 2, 128, device='cuda') + copy_b = b.clone() + c = torch.randn(128, device='cuda') + copy_c = c.clone() + d = torch.randn(4, 32, device='cuda') + copy_d = d.clone() + + private = [BlockRequire(1024, torch.float)] + mp = MemoryPool('cuda') + mp.allocate(private_block_list=private) + + chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True) + assert chunk.chunk_size == 1024 + assert chunk.chunk_dtype == torch.float + assert chunk.shard_size == 1024 // nproc + + def check_tensors(): + assert torch.equal(a, copy_a) + assert torch.equal(b, copy_b) + assert torch.equal(c, copy_c) + assert 
torch.equal(d, copy_d) + + chunk.append_tensor(a) + chunk.append_tensor(b) + chunk.append_tensor(c) + chunk.append_tensor(d) + check_tensors() + + chunk.close_chunk() + assert chunk.is_replica is False + + chunk.access_chunk() + assert chunk.is_replica + check_tensors() + + assert chunk.tensor_state_cnter[TensorState.HOLD] == 4 + chunk.tensor_trans_state(a, TensorState.COMPUTE) + assert chunk.tensor_state_cnter[TensorState.HOLD] == 3 + assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1 + + tensor_list = [a, b, c, d] + for t in tensor_list: + chunk.tensor_trans_state(t, TensorState.COMPUTE) + chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD) + chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE) + assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4 + assert chunk.reduce_check + + torch.cuda.synchronize() + print('chunk states are ok') + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD) + exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@run_on_environment_flag('ELX') +def test_chunk_functions(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_functions(world_size=4) diff --git a/tests/test_elixir/test_chunk/test_fetcher.py b/tests/test_elixir/test_chunk/test_fetcher.py new file mode 100644 index 000000000..27906b18d --- /dev/null +++ b/tests/test_elixir/test_chunk/test_fetcher.py @@ -0,0 +1,71 @@ +import copy +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +from colossalai.elixir.chunk import ChunkGroup +from colossalai.elixir.utils import init_distributed, seed_all +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.test_chunk.fetcher_utils import hook_transform +from tests.test_elixir.utils import TEST_MODELS, to_cuda + + +def check_gradient(ddp_model, my_model, cg: ChunkGroup): + for chunk in cg.fused_chunks: + cg.access_chunk(chunk) + + for (name, p0), p1 in zip(ddp_model.named_parameters(), my_model.parameters()): + torch.cuda.synchronize() + print(f'checking parameter {name}') + assert_close(p0.grad.data, p1.data) + + +def exam_chunk_fetcher(nproc, group): + model_fn, data_fn = TEST_MODELS.get('resnet') + torch_model = model_fn().cuda() + test_model = copy.deepcopy(torch_model) + + rank = dist.get_rank(group) + # get different data + seed_all(1001 + rank) + data = to_cuda(data_fn()) + + seed_all(1001, cuda_deterministic=True) + ddp_model = DDP(torch_model) + ddp_loss = ddp_model(**data) + ddp_loss.backward() + + hook_model, cg = hook_transform(test_model, group) + my_loss = hook_model(**data) + my_loss.backward() + + assert_close(ddp_loss, my_loss) + check_gradient(ddp_model, hook_model, cg) + print('private chunk fetcher is ok') + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + 
exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@run_on_environment_flag('ELX') +def test_chunk_fetcher(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_fetcher(world_size=2) diff --git a/tests/test_elixir/test_chunk/test_group.py b/tests/test_elixir/test_chunk/test_group.py new file mode 100644 index 000000000..df183a9aa --- /dev/null +++ b/tests/test_elixir/test_chunk/test_group.py @@ -0,0 +1,98 @@ +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist + +from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState +from colossalai.elixir.utils import init_distributed +from colossalai.testing import run_on_environment_flag + + +def exam_chunk_group_functions(nproc, group): + a = torch.randn(3, 64, device='cuda') + copy_a = a.clone() + b = torch.randn(2, 32, device='cuda') + copy_b = b.clone() + c = torch.randn(256, device='cuda') + copy_c = c.clone() + d = torch.randn(2, 2, 64, device='cuda') + copy_d = d.clone() + e = torch.randn(2, 33, device='cuda') + copy_e = e.clone() + + mp = MemoryPool('cuda') + mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)]) + cg = ChunkGroup(rcache=mp) + c0 = cg.allocate_chunk([a, b], 256, torch.float, group) + c1 = cg.allocate_chunk([c], 256, torch.float, group) + c2 = cg.allocate_chunk([d], 256, torch.float, group) + + fused_config = dict(rcache_fused=True) + c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config) + + def check_chunk_0(): + assert torch.equal(a, copy_a) + assert torch.equal(b, copy_b) + + def check_chunk_1(): + assert torch.equal(c, copy_c) + + def check_chunk_2(): + assert torch.equal(d, copy_d) + + def check_chunk_3(): + assert torch.equal(e, copy_e) + + # check tensors_to_chunks + chunks = cg.tensors_to_chunks([e, a]) + assert chunks[0] == c0 + assert chunks[1] == c3 + # check access_chunk for unfused chunks + cg.access_chunk(c0) + cg.access_chunk(c1) + check_chunk_0() + check_chunk_1() + assert not cg.rcache_enough_check(c2) + assert cg.rcache_enough_check(c3) + # check access_chunk for fused chunks + cg.access_chunk(c3) + check_chunk_3() + # check release_chunk for unfused chunks + cg.release_chunk(c1) + assert cg.rcache_enough_check(c2) + # check access_chunk + cg.access_chunk(c2) + check_chunk_2() + + cg.tensor_trans_state(e, TensorState.COMPUTE) + cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD) + cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE) + cg.reduce_chunk(c3) + assert not c3.is_replica + + torch.cuda.synchronize() + print('chunk group functions are ok') + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@run_on_environment_flag('ELX') +def test_chunk_group(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_group(world_size=2) diff --git a/tests/test_elixir/test_chunk/test_scheduler.py 
b/tests/test_elixir/test_chunk/test_scheduler.py new file mode 100644 index 000000000..d0e5a0f47 --- /dev/null +++ b/tests/test_elixir/test_chunk/test_scheduler.py @@ -0,0 +1,130 @@ +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist + +from colossalai.elixir.chunk import Chunk, MemoryPool +from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler +from colossalai.elixir.utils import init_distributed +from colossalai.testing import run_on_environment_flag + + +def exam_fifo(nproc, group): + mp = MemoryPool('cuda') + mp.allocate(public_block_number=1) + c0 = Chunk(mp, 1024, torch.float, group) + c1 = Chunk(mp, 1024, torch.float, group) + c2 = Chunk(mp, 1024, torch.float, group) + + sdl = FIFOScheduler() + sdl.reset() + + sdl.add(c0) + sdl.add(c1) + sdl.add(c2) + sdl.add(c0) # nothing happens here + assert sdl.top() == c0 + + sdl.remove(c0) + assert sdl.top() == c1, f'{sdl.top()}' + sdl.remove(c0) + assert sdl.top() == c1, f'{sdl.top()}' + + sdl.add(c0) + assert sdl.top() == c1 + sdl.remove(c1) + assert sdl.top() == c2 + sdl.remove(c2) + assert sdl.top() == c0 + + +def exam_prefetch(nproc, group): + mp = MemoryPool('cuda') + mp.allocate() + c0 = Chunk(mp, 1024, torch.float, group) + c1 = Chunk(mp, 1024, torch.float, group) + c2 = Chunk(mp, 1024, torch.float, group) + + chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]] + + sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step) + print(sdl.next_step_dict) + sdl.reset() + + sdl.step() + sdl.add(c0) + assert sdl.top() == c0 + + sdl.step() + sdl.add(c1) + assert sdl.top() == c1 + + sdl.step() + sdl.add(c2) + assert sdl.top() == c2 + + sdl.remove(c0) + sdl.step() + sdl.add(c0) + assert sdl.top() == c2 + + sdl.remove(c0) + sdl.step() + sdl.add(c0) + assert sdl.top() == c0 + sdl.remove(c0) # notice here + + sdl.remove(c1) + sdl.step() + sdl.add(c1) + assert sdl.top() == c1 + + sdl.remove(c2) + sdl.step() + sdl.add(c2) + assert sdl.top() == c1 + + sdl.remove(c2) + sdl.step() + sdl.add(c2) + assert sdl.top() == c2 + sdl.remove(c2) # notice here + sdl.add(c0) # notice here + + sdl.remove(c1) + sdl.step() + sdl.add(c1) + assert sdl.top() == c1 + sdl.remove(c1) # notice here + + sdl.remove(c0) + sdl.step() + sdl.add(c0) + assert sdl.top() == c0 + + sdl.remove(c0) + sdl.clear() + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD) + exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@run_on_environment_flag('ELX') +def test_chunk_scheduler(world_size=1): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_chunk_scheduler() diff --git a/tests/test_elixir/test_ctx/test_meta_ctx.py b/tests/test_elixir/test_ctx/test_meta_ctx.py new file mode 100644 index 000000000..99d4ab1ec --- /dev/null +++ b/tests/test_elixir/test_ctx/test_meta_ctx.py @@ -0,0 +1,22 @@ +from colossalai.elixir.ctx import MetaContext +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS + + +@run_on_environment_flag('ELX') +def test_meta_context(): + builder, *_ = TEST_MODELS.get('resnet') + with MetaContext(): + model = 
builder() + + for name, param in model.named_parameters(): + assert param.device.type == 'meta' + print(name, param) + + for name, buffer in model.named_buffers(): + assert buffer.device.type == 'meta' + print(name, buffer) + + +if __name__ == '__main__': + test_meta_context() diff --git a/tests/test_elixir/test_hook.py b/tests/test_elixir/test_hook.py new file mode 100644 index 000000000..d2f5df6c4 --- /dev/null +++ b/tests/test_elixir/test_hook.py @@ -0,0 +1,56 @@ +from copy import deepcopy + +import torch +import torch.nn as nn + +from colossalai.elixir.hook import BufferStore, HookParam +from colossalai.elixir.tensor import FakeTensor + + +def test_hook(): + x = nn.Parameter(torch.randn(4, 4)) + + ori_numel = x.numel() + ori_size = x.size() + ori_stride = x.stride() + ori_offset = x.storage_offset() + + fake_data = FakeTensor(x.data) + x.data = fake_data + x.__class__ = HookParam + + assert x.numel() == ori_numel + assert x.size() == ori_size + assert x.stride() == ori_stride + assert x.storage_offset() == ori_offset + + +def test_store(): + buffer = BufferStore(1024, torch.float16) + print(buffer) + + x = torch.randn(4, 128, dtype=torch.float16, device='cuda') + original_ptr_x = x.data_ptr() + copy_x = deepcopy(x) + + y = torch.randn(512, dtype=torch.float16, device='cuda') + original_ptr_y = y.data_ptr() + copy_y = deepcopy(y) + + offset = 0 + offset = buffer.insert(x, offset) + assert offset == x.numel() + assert torch.equal(x, copy_x) + + offset = buffer.insert(y, offset) + assert offset == 1024 + assert torch.equal(y, copy_y) + + buffer.erase(x) + buffer.erase(y) + assert x.data_ptr() == original_ptr_x + assert y.data_ptr() == original_ptr_y + + +if __name__ == '__main__': + test_store() diff --git a/tests/test_elixir/test_kernels/test_attn.py b/tests/test_elixir/test_kernels/test_attn.py new file mode 100644 index 000000000..0fb91acec --- /dev/null +++ b/tests/test_elixir/test_kernels/test_attn.py @@ -0,0 +1,35 @@ +from copy import deepcopy + +import pytest +from torch.testing import assert_close + +from tests.test_elixir.utils import TEST_MODELS, to_cuda + + +def exam_one_model(model_fn, data_fn): + from colossalai.elixir.kernels.attn_wrapper import wrap_attention + + torch_model = model_fn().cuda() + test_model = deepcopy(torch_model) + test_model = wrap_attention(test_model) + + data = to_cuda(data_fn()) + torch_out = torch_model(**data) + torch_out.backward() + + test_out = test_model(**data) + test_out.backward() + + assert_close(torch_out, test_out) + for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()): + assert_close(p_torch.grad, p_test.grad) + + +@pytest.mark.skip(reason="Need to install xformers") +def test_gpt_atten_kernel(): + exam_one_model(*TEST_MODELS.get('gpt2_micro')) + exam_one_model(*TEST_MODELS.get('opt_micro')) + + +if __name__ == '__main__': + test_gpt_atten_kernel() diff --git a/tests/test_elixir/test_kernels/test_ln.py b/tests/test_elixir/test_kernels/test_ln.py new file mode 100644 index 000000000..5edeeb710 --- /dev/null +++ b/tests/test_elixir/test_kernels/test_ln.py @@ -0,0 +1,58 @@ +import os +from copy import deepcopy +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +from colossalai.elixir.search import simple_search +from colossalai.elixir.utils import init_distributed +from colossalai.elixir.wrapper import ElixirModule + + +def exam_fused_layernorm(nproc, group): + torch_model = nn.LayerNorm(2048) + 
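+    # plain PyTorch LayerNorm is the reference; its deep copy is wrapped with ElixirModule(use_fused_kernels=True) below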
fused_model = deepcopy(torch_model) + + torch_model = torch_model.cuda() + sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True) + fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True) + + data = torch.randn(2, 2048, device='cuda') + + torch_loss = torch_model(data).sum() + torch_loss.backward() + + fused_loss = fused_model(data).sum() + fused_model.backward(fused_loss) + + assert_close(torch_loss, fused_loss) + + grad_state = fused_model.state_dict(from_param=True) + for name, param in torch_model.named_parameters(): + assert_close(param.grad.cpu(), grad_state[name]) + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1]) +@pytest.mark.skip(reason='need to install apex') +def test_fused_layernorm(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_fused_layernorm(world_size=1) diff --git a/tests/test_elixir/test_search/test_mini_waste.py b/tests/test_elixir/test_search/test_mini_waste.py new file mode 100644 index 000000000..ebde6e0e7 --- /dev/null +++ b/tests/test_elixir/test_search/test_mini_waste.py @@ -0,0 +1,38 @@ +from copy import deepcopy + +import torch + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.search import minimum_waste_search +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS + + +def step_fn(model, inp): + model(**inp).backward() + + +@run_on_environment_flag('ELX') +def test_mini_waste_search(): + model_fn, data_fn = TEST_MODELS.get('gpt2_small') + model = model_fn() + data = data_fn() + + sr = minimum_waste_search(model, + 1, + unified_dtype=torch.float16, + cpu_offload=True, + prefetch=True, + verbose=True, + inp=data, + step_fn=step_fn) + + chunk_plans = deepcopy(sr.param_chunk_plans) + for plan in chunk_plans: + assert plan.chunk_dtype == torch.float16 + assert plan.kwargs.get('shard_device') == torch.device('cpu') + assert plan.kwargs.get('cpu_pin_memory') == True + + +if __name__ == '__main__': + test_mini_waste_search() diff --git a/tests/test_elixir/test_search/test_optimal.py b/tests/test_elixir/test_search/test_optimal.py new file mode 100644 index 000000000..56898196c --- /dev/null +++ b/tests/test_elixir/test_search/test_optimal.py @@ -0,0 +1,30 @@ +from copy import deepcopy + +import torch + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.search import optimal_search +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS + + +def step_fn(model, inp): + model(**inp).backward() + + +@run_on_environment_flag('ELX') +def test_optimal_search(): + model_fn, data_fn = TEST_MODELS.get('gpt2_small') + model = model_fn() + data = data_fn() + + sr = optimal_search(model, 1, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=step_fn) + + chunk_plans = deepcopy(sr.param_chunk_plans) + for plan in chunk_plans: + assert plan.chunk_dtype == torch.float16 + assert plan.kwargs.get('shard_device') == gpu_device() + + +if __name__ == '__main__': + test_optimal_search() diff --git a/tests/test_elixir/test_search/test_simple.py 
b/tests/test_elixir/test_search/test_simple.py new file mode 100644 index 000000000..52a2cbc1c --- /dev/null +++ b/tests/test_elixir/test_search/test_simple.py @@ -0,0 +1,48 @@ +from copy import deepcopy + +import torch + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.search import simple_search +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS + + +def step_fn(model, inp): + model(**inp).backward() + + +@run_on_environment_flag('ELX') +def test_simple_search(): + model_fn, data_fn = TEST_MODELS.get('small') + model = model_fn() + data = data_fn() + + sr = simple_search(model, + 1, + split_number=5, + shard_device=gpu_device(), + prefetch=True, + verbose=True, + inp=data, + step_fn=step_fn) + + chunk_plans = deepcopy(sr.param_chunk_plans) + private_plan = chunk_plans.pop(0) + assert private_plan.name_list == ['embed.weight'] + assert private_plan.chunk_size == 320 + assert private_plan.kwargs.get('shard_device') == gpu_device() + + assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias'] + assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias'] + assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias'] + assert chunk_plans[3].name_list == ['norm2.weight'] + assert chunk_plans[4].name_list == ['norm2.bias'] + + for plan in chunk_plans: + assert plan.chunk_size == 1088 + assert plan.kwargs.get('shard_device') == gpu_device() + + +if __name__ == '__main__': + test_simple_search() diff --git a/tests/test_elixir/test_src/test_move.py b/tests/test_elixir/test_src/test_move.py new file mode 100644 index 000000000..74a7cd31a --- /dev/null +++ b/tests/test_elixir/test_src/test_move.py @@ -0,0 +1,13 @@ +from colossalai.elixir.simulator import move_count +from colossalai.testing import run_on_environment_flag + + +@run_on_environment_flag('ELX') +def test_move_count(): + steps = [[0], [1, 2], [3], [3], [1, 2], [0]] + size = 2 + assert move_count(steps, size) == 12 + + +if __name__ == '__main__': + test_move_count() diff --git a/tests/test_elixir/test_tools/test_registry.py b/tests/test_elixir/test_tools/test_registry.py new file mode 100644 index 000000000..8530d8cbb --- /dev/null +++ b/tests/test_elixir/test_tools/test_registry.py @@ -0,0 +1,23 @@ +import pytest +import torch + +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import to_cuda + + +@run_on_environment_flag('ELX') +def test_registry(): + from tests.test_elixir.utils.registry import TEST_MODELS + for name, model_tuple in TEST_MODELS: + torch.cuda.synchronize() + print(f'model `{name}` is in testing') + + model_fn, data_fn = model_tuple + model = model_fn().cuda() + data = to_cuda(data_fn()) + loss = model(**data) + loss.backward() + + +if __name__ == '__main__': + test_registry() diff --git a/tests/test_elixir/test_tracer/test_cuda_profiler.py b/tests/test_elixir/test_tracer/test_cuda_profiler.py new file mode 100644 index 000000000..994abc5ea --- /dev/null +++ b/tests/test_elixir/test_tracer/test_cuda_profiler.py @@ -0,0 +1,46 @@ +import pytest +import torch + +from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS, to_cuda + + +def one_step(model, inp): + loss = model(**inp) + loss.backward() + return loss + + +def try_one_model(model_fn, data_fn): + model = model_fn().cuda() + data = to_cuda(data_fn()) + one_step(model, data) # generate gradients + + 
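+    # baseline measurement: the activation memory of an eager step is the gap between
+    # the memory allocated before the step and the peak allocation during it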
pre_cuda_alc = torch.cuda.memory_allocated() + torch.cuda.reset_peak_memory_stats() + one_step(model, data) + aft_cuda_alc = torch.cuda.max_memory_allocated() + torch_activation_occ = aft_cuda_alc - pre_cuda_alc + model.zero_grad(set_to_none=True) + print('normal', torch_activation_occ) + + before = torch.cuda.memory_allocated() + profiling_dict = cuda_memory_profiling(model, data, one_step) + after = torch.cuda.memory_allocated() + print('profiling', profiling_dict) + assert before == after + assert torch_activation_occ == profiling_dict['activation_occ'] + print('Check is ok.') + + +@run_on_environment_flag('ELX') +def test_cuda_profiler(): + model_list = ['resnet', 'gpt2_micro'] + for name in model_list: + model_fn, data_fn = TEST_MODELS.get(name) + try_one_model(model_fn, data_fn) + + +if __name__ == '__main__': + test_cuda_profiler() diff --git a/tests/test_elixir/test_tracer/test_op_cache.py b/tests/test_elixir/test_tracer/test_op_cache.py new file mode 100644 index 000000000..cfc7cd129 --- /dev/null +++ b/tests/test_elixir/test_tracer/test_op_cache.py @@ -0,0 +1,145 @@ +import pytest +import torch + +from colossalai.elixir.tracer.memory_tracer import MTensor +from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache +from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated +from colossalai.testing import run_on_environment_flag + + +def op_mm(x, y): + u = torch.matmul(x, y) + return u.shape + + +def op_addmm(x, y, z): + u = torch.addmm(x, y, z) + return u.shape + + +def op_bmm(x, y): + u = torch.bmm(x, y) + return u.shape + + +@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16]) +@run_on_environment_flag('ELX') +def test_mm(dtype, size0=(4, 256), size1=(256, 1024)): + torch.cuda.reset_peak_memory_stats() + assert get_cuda_allocated() == 0 + + x = torch.randn(size0, dtype=dtype, device='cuda') + y = torch.randn(size1, dtype=dtype, device='cuda') + torch_pre_alc = get_cuda_allocated() + + torch_z_size = op_mm(x, y) + torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc + + del x + del y + + assert get_cuda_allocated() == 0 + x = MTensor(torch.randn(size0, dtype=dtype, device='cuda')) + y = MTensor(torch.randn(size1, dtype=dtype, device='cuda')) + op1_pre_alc = get_cuda_allocated() + + MTensor.reset_peak_memory() + op1_z_size = op_mm(x, y) + op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc + + assert torch_z_size == op1_z_size + assert torch_pre_alc == op1_pre_alc + assert torch_temp_alc == op1_temp_alc + assert len(mm_cache.temp_memory) > 0 + + MTensor.reset_peak_memory() + op2_z_size = op_mm(x, y) + op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc + + assert torch_z_size == op2_z_size + assert torch_temp_alc == op2_temp_alc + + +@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16]) +@run_on_environment_flag('ELX') +def test_addmm(dtype, size0=(4, 16), size1=(16, 64)): + torch.cuda.reset_peak_memory_stats() + assert get_cuda_allocated() == 0 + + x = torch.randn(size0, dtype=dtype, device='cuda') + y = torch.randn(size1, dtype=dtype, device='cuda') + u = torch.randn(size1[-1], dtype=dtype, device='cuda') + torch_pre_alc = get_cuda_allocated() + + torch_z_size = op_addmm(u, x, y) + torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc + + del x + del y + del u + + assert get_cuda_allocated() == 0 + x = MTensor(torch.randn(size0, dtype=dtype, device='cuda')) + y = MTensor(torch.randn(size1, dtype=dtype, device='cuda')) + u = 
MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda')) + op1_pre_alc = get_cuda_allocated() + + MTensor.reset_peak_memory() + op1_z_size = op_addmm(u, x, y) + op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc + + assert torch_z_size == op1_z_size + assert torch_pre_alc == op1_pre_alc + assert torch_temp_alc == op1_temp_alc + assert len(addmm_cache.temp_memory) > 0 + + MTensor.reset_peak_memory() + op2_z_size = op_addmm(u, x, y) + op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc + + assert torch_z_size == op2_z_size + assert torch_temp_alc == op2_temp_alc + + +@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16]) +@run_on_environment_flag('ELX') +def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)): + torch.cuda.reset_peak_memory_stats() + assert get_cuda_allocated() == 0 + + x = torch.randn(size0, dtype=dtype, device='cuda') + y = torch.randn(size1, dtype=dtype, device='cuda') + torch_pre_alc = get_cuda_allocated() + + torch_z_size = op_bmm(x, y) + torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc + + del x + del y + + assert get_cuda_allocated() == 0 + x = MTensor(torch.randn(size0, dtype=dtype, device='cuda')) + y = MTensor(torch.randn(size1, dtype=dtype, device='cuda')) + op1_pre_alc = get_cuda_allocated() + + MTensor.reset_peak_memory() + op1_z_size = op_bmm(x, y) + op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc + + assert torch_z_size == op1_z_size + assert torch_pre_alc == op1_pre_alc + assert torch_temp_alc == op1_temp_alc + assert len(bmm_cache.temp_memory) > 0 + + bmm_cache.print() + + MTensor.reset_peak_memory() + op2_z_size = op_bmm(x, y) + op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc + + assert torch_z_size == op2_z_size + assert torch_temp_alc == op2_temp_alc + + +if __name__ == '__main__': + test_addmm(dtype=torch.float) diff --git a/tests/test_elixir/test_tracer/test_tf_order.py b/tests/test_elixir/test_tracer/test_tf_order.py new file mode 100644 index 000000000..66daa82f5 --- /dev/null +++ b/tests/test_elixir/test_tracer/test_tf_order.py @@ -0,0 +1,37 @@ +from colossalai.elixir.tracer.param_tracer import generate_tf_order +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS + + +@run_on_environment_flag('ELX') +def test_tf_forward_backward(): + model_fn, data_fn = TEST_MODELS.get('gpt2_micro') + model = model_fn() + data = data_fn() + + def forward_backward_fn(local_model, local_input): + local_model(**local_input).backward() + + # model.gradient_checkpointing_enable() + tf_order = generate_tf_order(model, data, forward_backward_fn) + params_per_step = tf_order['params_per_step'] + assert len(params_per_step) == 32 + + model.gradient_checkpointing_enable() + tf_order = generate_tf_order(model, data, forward_backward_fn) + params_per_step = tf_order['params_per_step'] + checkpoint_info = tf_order['checkpoint_info'] + for i, step in enumerate(params_per_step): + print(f'step {i}: {step}') + for c in checkpoint_info: + print(f'checkpoint info: {c}') + assert len(params_per_step) == 44 + + assert data['input_ids'].device.type == 'cpu' + assert data['attention_mask'].device.type == 'cpu' + for param in model.parameters(): + assert param.device.type == 'cpu' + + +if __name__ == '__main__': + test_tf_forward_backward() diff --git a/tests/test_elixir/test_wrapper/test_amp.py b/tests/test_elixir/test_wrapper/test_amp.py new file mode 100644 index 000000000..eca7f796e --- /dev/null +++ b/tests/test_elixir/test_wrapper/test_amp.py @@ -0,0 +1,94 
@@
+import copy
+import os
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+from apex import amp
+from apex.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+from colossalai.elixir.cuda import gpu_device
+from colossalai.elixir.search import simple_search
+from colossalai.elixir.utils import init_distributed, seed_all
+from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import run_on_environment_flag
+from tests.test_elixir.utils import TEST_MODELS, to_cuda
+
+
+def amp_check_model_states(ddp_optim, test_model):
+    test_states = test_model.state_dict()
+    for (name, _), p in zip(test_model.module.named_parameters(), amp.master_params(ddp_optim)):
+        test_p = test_states[name]
+        copy_p = p.to(test_p.device)
+        print(f'checking parameter `{name}`: {test_p.dtype} {copy_p.dtype}')
+        assert_close(test_p.data, copy_p.data)
+
+
+def exam_amp_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
+    ddp_model = model_fn().cuda()
+    test_model = copy.deepcopy(ddp_model)
+    # important: cast to half here, since apex lazily creates fp32 master weights after the first optimizer step
+    test_model = test_model.half()
+
+    ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
+    ddp_model, ddp_optim = amp.initialize(ddp_model,
+                                          ddp_optim,
+                                          opt_level='O2',
+                                          loss_scale=1.0,
+                                          keep_batchnorm_fp32=False)
+    ddp_model = DDP(ddp_model, message_size=0, allreduce_always_fp32=True)
+
+    test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
+    sr = simple_search(test_model, nproc, shard_device=gpu_device(), unified_dtype=torch.float16, verbose=True)
+    test_model = ElixirModule(test_model, sr, group, dtype=torch.float16, reduce_always_fp32=True, output_fp32=True)
+    test_optim = ElixirOptimizer(test_model, test_optim, initial_scale=1.0)
+
+    # get different data
+    seed_all(exam_seed + dist.get_rank(group), cuda_deterministic=True)
+    for _ in range(2):
+        data = to_cuda(data_fn())
+
+        ddp_optim.zero_grad()
+        ddp_loss = ddp_model(**data)
+        with amp.scale_loss(ddp_loss, ddp_optim) as scaled_loss:
+            scaled_loss.backward()
+        ddp_optim.step()
+
+        test_optim.zero_grad()
+        test_loss = test_model(**data)
+        test_optim.backward(test_loss)
+        test_optim.step()
+
+        assert_close(ddp_loss, test_loss)
+        amp_check_model_states(ddp_optim, test_model)
+
+
+def exam_amp_in_models(nproc, group):
+    model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
+    exam_amp_one_model(model_fn, data_fn, nproc, group)
+
+
+def run_dist(rank, world_size):
+    os.environ['RANK'] = str(rank)
+    os.environ['LOCAL_RANK'] = str(rank)
+    os.environ['WORLD_SIZE'] = str(world_size)
+    os.environ['MASTER_ADDR'] = '127.0.0.1'
+    os.environ['MASTER_PORT'] = str(29512)
+    init_distributed()
+    exam_amp_in_models(nproc=world_size, group=dist.GroupMember.WORLD)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2, 4])
+@run_on_environment_flag('ELX')
+def test_elixir_amp(world_size):
+    run_func = partial(run_dist, world_size=world_size)
+    torch.multiprocessing.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_elixir_amp(world_size=2)
diff --git a/tests/test_elixir/test_wrapper/test_module.py b/tests/test_elixir/test_wrapper/test_module.py
new file mode 100644
index 000000000..800dc0bda
--- /dev/null
+++ b/tests/test_elixir/test_wrapper/test_module.py
@@ -0,0 +1,95 @@
+import copy
+import os
+from functools import partial
+
+import pytest
+import torch +import torch.distributed as dist +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +from colossalai.elixir.search import simple_search +from colossalai.elixir.utils import init_distributed, seed_all +from colossalai.elixir.wrapper import ElixirModule +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS, assert_dict_values, to_cuda + + +def check_gradient(ddp_model: nn.Module, test_model: ElixirModule): + grad_state = test_model.state_dict(from_param=True) + for name, param in ddp_model.named_parameters(): + assert_close(param.grad.cpu(), grad_state[name]) + + +def exam_module_init(nproc, group, grad_flag): + model_fn, data_fn = TEST_MODELS.get('resnet') + torch_model = model_fn().cuda() + test_model = model_fn().cuda() + + for p1, p2 in zip(torch_model.parameters(), test_model.parameters()): + p1.requires_grad = p2.requires_grad = grad_flag + + sr = simple_search(test_model, nproc) + model = ElixirModule(test_model, sr, group) + # check function: ElixirModule.load_state_dict after ElixirModule.__init__ + torch_st = torch_model.state_dict() + if dist.get_rank() != 0: + torch_st = None + test_st = model.load_state_dict(torch_st, only_rank_0=True) + # check function: ElixirModule.state_dict after ElixirModule.__init__ + torch_st = torch_model.state_dict() + test_st = model.state_dict() + assert_dict_values(torch_st, test_st, fn=torch.equal) + + +def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2261): + ddp_model = model_fn().cuda() + test_model = copy.deepcopy(ddp_model) + sr = simple_search(test_model, nproc, allocate_factor=0.6) + test_model = ElixirModule(test_model, sr, group) + + # get different data + seed_all(exam_seed + dist.get_rank(group)) + data = data_fn() + data = to_cuda(data) + + seed_all(exam_seed, cuda_deterministic=True) + ddp_model = DDP(ddp_model) + ddp_loss = ddp_model(**data) + ddp_loss.backward() + + test_loss = test_model(**data) + test_model.backward(test_loss) + + assert_close(ddp_loss, test_loss) + check_gradient(ddp_model.module, test_model) + + +def exam_modules_fwd_bwd(nproc, group): + model_fn, data_fn = TEST_MODELS.get('resnet') + exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group) + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=False) + exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=True) + exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@run_on_environment_flag('ELX') +def test_elixir_module(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_elixir_module(world_size=2) diff --git a/tests/test_elixir/test_wrapper/test_optimizer.py b/tests/test_elixir/test_wrapper/test_optimizer.py new file mode 100644 index 000000000..8588d62a2 --- /dev/null +++ b/tests/test_elixir/test_wrapper/test_optimizer.py @@ -0,0 +1,77 @@ +import copy +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import 
DistributedDataParallel as DDP +from torch.testing import assert_close + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.search import simple_search +from colossalai.elixir.utils import init_distributed, seed_all +from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS, allclose, assert_dict_values, to_cuda + + +def exam_optimizer_one_model(model_fn, data_fn, nproc, group, exam_seed=2261): + ddp_model = model_fn().cuda() + test_model = copy.deepcopy(ddp_model) + + ddp_model = DDP(ddp_model) + ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0) + + test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0) + sr = simple_search(test_model, nproc, shard_device=gpu_device()) + test_model = ElixirModule(test_model, sr, group) + test_optim = ElixirOptimizer(test_model, test_optim) + + # get different data + seed_all(exam_seed + dist.get_rank(group)) + data = to_cuda(data_fn()) + + seed_all(exam_seed, cuda_deterministic=True) + ddp_optim.zero_grad() + ddp_loss = ddp_model(**data) + ddp_loss.backward() + ddp_optim.step() + + test_optim.zero_grad() + test_loss = test_model(**data) + test_optim.backward(test_loss) + test_optim.step() + + assert_close(ddp_loss, test_loss) + torch_st = ddp_model.module.state_dict() + test_st = test_model.state_dict() + assert_dict_values(torch_st, test_st, fn=partial(allclose, rtol=2e-6, atol=2e-5)) + + +def exam_optimizer_in_models(nproc, group): + model_fn, data_fn = TEST_MODELS.get('resnet') + exam_optimizer_one_model(model_fn, data_fn, nproc, group) + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_optimizer_in_models(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@run_on_environment_flag('ELX') +def test_elixir_optimizer(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_elixir_optimizer(world_size=4) diff --git a/tests/test_elixir/test_wrapper/test_prefetch.py b/tests/test_elixir/test_wrapper/test_prefetch.py new file mode 100644 index 000000000..395b737f3 --- /dev/null +++ b/tests/test_elixir/test_wrapper/test_prefetch.py @@ -0,0 +1,89 @@ +import copy +import os +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +from colossalai.elixir.cuda import gpu_device +from colossalai.elixir.search import simple_search +from colossalai.elixir.utils import init_distributed, seed_all +from colossalai.elixir.wrapper import ElixirModule +from colossalai.testing import run_on_environment_flag +from tests.test_elixir.utils import TEST_MODELS, to_cuda + + +def check_gradient(ddp_model: nn.Module, test_model: ElixirModule): + grad_state = test_model.state_dict(from_param=True) + for name, param in ddp_model.named_parameters(): + assert_close(param.grad.cpu(), grad_state[name]) + + +def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2263): + + def one_step(local_model, 
local_input): + loss = local_model(**local_input) + loss.backward() + return loss + + ddp_model = model_fn().cuda() + test_model = copy.deepcopy(ddp_model) + + # get different data + seed_all(exam_seed + dist.get_rank(group)) + data = to_cuda(data_fn()) + + # wrap as DDP model + ddp_model = DDP(ddp_model) + # search how to initialize chunks + sr = simple_search(test_model, + nproc, + shard_device=gpu_device(), + prefetch=True, + verbose=True, + inp=data, + step_fn=one_step) + test_model = ElixirModule(test_model, sr, group, prefetch=True) + + seed_all(exam_seed, cuda_deterministic=True) + ddp_loss = one_step(ddp_model, data) + + with torch.no_grad(): + test_loss = test_model(**data) + assert_close(ddp_loss, test_loss) + + test_loss = test_model(**data) + test_model.backward(test_loss) + assert_close(ddp_loss, test_loss) + check_gradient(ddp_model.module, test_model) + + +def exam_modules_fwd_bwd(nproc, group): + model_fn, data_fn = TEST_MODELS.get('resnet') + exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group) + + +def run_dist(rank, world_size): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(29512) + init_distributed() + exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@run_on_environment_flag('ELX') +def test_module_prefetch(world_size): + run_func = partial(run_dist, world_size=world_size) + torch.multiprocessing.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_module_prefetch(world_size=2) diff --git a/tests/test_elixir/utils/__init__.py b/tests/test_elixir/utils/__init__.py new file mode 100644 index 000000000..d68a6ab25 --- /dev/null +++ b/tests/test_elixir/utils/__init__.py @@ -0,0 +1,41 @@ +import torch +from torch.testing import assert_close +from torch.utils._pytree import tree_map + +from . 
import gpt, mlp, opt, resnet, small +from .registry import TEST_MODELS + + +def to_cuda(input_dict): + + def local_fn(t): + if isinstance(t, torch.Tensor): + t = t.cuda() + return t + + ret = tree_map(local_fn, input_dict) + return ret + + +def allclose(ta, tb, **kwargs): + assert_close(ta, tb, **kwargs) + return True + + +def assert_dict_keys(test_dict, keys): + assert len(test_dict) == len(keys) + for k in keys: + assert k in test_dict + + +def assert_dict_values(da, db, fn): + assert len(da) == len(db) + for k, v in da.items(): + assert k in db + if not torch.is_tensor(v): + continue + u = db.get(k) + if u.device != v.device: + v = v.to(u.device) + # print(f"checking key {k}: {u.shape} vs {v.shape}") + assert fn(u.data, v.data), f'max diff {torch.max(torch.abs(u.data - v.data))}' diff --git a/tests/test_elixir/utils/gpt.py b/tests/test_elixir/utils/gpt.py new file mode 100644 index 000000000..f603ccda9 --- /dev/null +++ b/tests/test_elixir/utils/gpt.py @@ -0,0 +1,79 @@ +from functools import partial + +import torch +import torch.nn as nn +from transformers import GPT2Config, GPT2LMHeadModel + +from tests.test_elixir.utils.registry import TEST_MODELS + +MICRO_VS = 128 +MICRO_BS = 4 +MICRO_SL = 64 + +MACRO_VS = 50257 +MACRO_BS = 2 +MACRO_SL = 1024 + + +def micro_data_fn(): + input_ids = torch.randint(low=0, high=MICRO_VS, size=(MICRO_BS, MICRO_SL)) + attn_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attn_mask) + + +def small_data_fn(): + input_ids = torch.randint(low=0, high=MACRO_VS, size=(MACRO_BS, MACRO_SL)) + attn_mask = torch.ones_like(input_ids) + return dict(input_ids=input_ids, attention_mask=attn_mask) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +class GPTLMModel(nn.Module): + + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): + super().__init__() + self.enable_gc = False + self.config = GPT2Config( + # pre-commit: do not rearrange + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0) + self.module = GPT2LMHeadModel(config=self.config) + self.criterion = GPTLMLoss() + + def gradient_checkpointing_enable(self): + self.module.gradient_checkpointing_enable() + self.enable_gc = True + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + output = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=(not self.enable_gc))[0] + loss = self.criterion(output, input_ids) + return loss + + +gpt2_micro = partial(GPTLMModel, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128) +gpt2_small = GPTLMModel +gpt2_base = partial(GPTLMModel, hidden_size=1024, num_layers=24, num_attention_heads=16) + +TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn) +TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn) +TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn) diff --git a/tests/test_elixir/utils/mlp.py b/tests/test_elixir/utils/mlp.py new file mode 100644 index 000000000..49369ee17 --- /dev/null +++ b/tests/test_elixir/utils/mlp.py @@ 
-0,0 +1,34 @@ +import torch +import torch.nn as nn + +from tests.test_elixir.utils.registry import TEST_MODELS + + +def mlp_data_fn(): + return dict(x=torch.randn(4, 16)) + + +class MlpModule(nn.Module): + + def __init__(self, hidden_dim: int = 16) -> None: + super().__init__() + self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim) + self.act = nn.GELU() + self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim) + + def forward(self, x): + return x + (self.proj2(self.act(self.proj1(x)))) + + +class MlpModel(nn.Module): + + def __init__(self, hidden_dim: int = 16) -> None: + super().__init__() + self.mlp = MlpModule(hidden_dim) + + def forward(self, x): + output = self.mlp(x) + return output.sum() + + +TEST_MODELS.register('mlp', MlpModel, mlp_data_fn) diff --git a/tests/test_elixir/utils/opt.py b/tests/test_elixir/utils/opt.py new file mode 100644 index 000000000..3622f9c40 --- /dev/null +++ b/tests/test_elixir/utils/opt.py @@ -0,0 +1,46 @@ +import torch.nn as nn +from transformers import OPTConfig, OPTForCausalLM + +from tests.test_elixir.utils.registry import TEST_MODELS + +from .gpt import micro_data_fn + + +class OPTLMModel(nn.Module): + + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.module = OPTForCausalLM(config=config) + self.enable_gc = False + + def gradient_checkpointing_enable(self): + self.module.gradient_checkpointing_enable() + self.enable_gc = True + + def forward(self, input_ids, attention_mask): + loss = self.module( + # pre-commit: do not rearrange + input_ids=input_ids, + attention_mask=attention_mask, + labels=input_ids, + use_cache=(not self.enable_gc))['loss'] + return loss + + +def opt_micro(): + opt_config = OPTConfig( + # pre-commit: do not rearrange + vocab_size=128, + activation_dropout=0.0, + dropout=0, + hidden_size=32, + num_hidden_layers=4, + ffn_dim=128, + num_attention_heads=4, + word_embed_proj_dim=32, + output_projection=True) + return OPTLMModel(opt_config) + + +TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn) diff --git a/tests/test_elixir/utils/registry.py b/tests/test_elixir/utils/registry.py new file mode 100644 index 000000000..ac17890b6 --- /dev/null +++ b/tests/test_elixir/utils/registry.py @@ -0,0 +1,26 @@ +from collections import OrderedDict +from typing import Callable + + +class Registry(object): + + def __init__(self) -> None: + super().__init__() + self._registry_dict = OrderedDict() + + def register(self, name: str, model_fn: Callable, data_fn: Callable): + assert name not in self._registry_dict + + model_tuple = (model_fn, data_fn) + self._registry_dict[name] = model_tuple + + def get(self, name: str): + return self._registry_dict[name] + + def __iter__(self): + return iter(self._registry_dict.items()) + + +TEST_MODELS = Registry() + +__all__ = [TEST_MODELS] diff --git a/tests/test_elixir/utils/resnet.py b/tests/test_elixir/utils/resnet.py new file mode 100644 index 000000000..7de3196df --- /dev/null +++ b/tests/test_elixir/utils/resnet.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +from torchvision.models import resnet18 + +from tests.test_elixir.utils.registry import TEST_MODELS + + +def resnet_data_fn(): + return dict(x=torch.randn(4, 3, 32, 32)) + + +class ResNetModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.r = resnet18() + + def forward(self, x): + output = self.r(x) + return output.sum() + + +TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn) diff --git a/tests/test_elixir/utils/small.py b/tests/test_elixir/utils/small.py new file mode 
100644 index 000000000..c18e5879f --- /dev/null +++ b/tests/test_elixir/utils/small.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + +from tests.test_elixir.utils.mlp import MlpModule +from tests.test_elixir.utils.registry import TEST_MODELS + + +def small_data_fn(): + return dict(x=torch.randint(low=0, high=20, size=(4, 8))) + + +class SmallModel(nn.Module): + + def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None: + super().__init__() + self.embed = nn.Embedding(num_embeddings, hidden_dim) + self.norm1 = nn.LayerNorm(hidden_dim) + self.mlp = MlpModule(hidden_dim=hidden_dim) + self.norm2 = nn.LayerNorm(hidden_dim) + self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False) + self.proj.weight = self.embed.weight + + def forward(self, x): + x = self.embed(x) + x = x + self.norm1(self.mlp(x)) + x = self.proj(self.norm2(x)) + x = x.mean(dim=-2) + return x.sum() + + +TEST_MODELS.register('small', SmallModel, small_data_fn)
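+
+# A registered entry is consumed by the tests above roughly like this (sketch):
+#   model_fn, data_fn = TEST_MODELS.get('small')
+#   loss = model_fn()(**data_fn())
+#   loss.backward()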