[elixir] add elixir and its unit tests (#3835)

* [elixir] add elixir

* [elixir] add unit tests

* remove useless code

* fix python 3.8 issue

* fix typo

* add test skip

* add docstrings

* add docstrings

* add readme

* fix typo
pull/3864/head
Haichen Huang 2023-05-29 09:32:37 +08:00 committed by GitHub
parent 34966378e8
commit 206280408a
86 changed files with 6627 additions and 2 deletions

View File

@ -0,0 +1,71 @@
# Elixir (Gemini 2.0)
Elixir, also known as Gemini, is a technology designed to facilitate the training of large models on small GPU clusters.
Its goal is to eliminate data redundancy and leverage CPU memory to accommodate very large models.
In addition, Elixir automatically profiles each training step before execution and selects the optimal configuration for the redundancy ratio and the placement device of each parameter.
This repository is used to benchmark the performance of Elixir.
Elixir will be integrated into ColossalAI to make it easier to use.
## Environment
This version is a beta release, so the running environment is somewhat restrictive.
We only describe the environment we used here, as compatibility with other setups has not been tested yet.
We use CUDA `11.6` and PyTorch `1.13.1+cu116`.
## Examples
Here is a simple example of wrapping your model and optimizer for [fine-tuning](https://github.com/hpcaitech/Elixir/tree/main/example/fine-tune).
```python
import torch
import torch.distributed as dist
from transformers import BertForSequenceClassification
from elixir.search import minimum_waste_search
from elixir.wrapper import ElixirModule, ElixirOptimizer
# use the default communication group and its size (see the advanced example below)
world_group = dist.GroupMember.WORLD
world_size = dist.get_world_size()
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-8)
sr = minimum_waste_search(model, world_size)
model = ElixirModule(model, sr, world_group)
optimizer = ElixirOptimizer(model, optimizer)
```
Here is a more advanced example tuned for performance, which is used in our [benchmark](https://github.com/hpcaitech/Elixir/blob/main/example/common/elx.py).
```python
import torch
import torch.distributed as dist
from colossalai.nn.optimizer import HybridAdam
from elixir.search import optimal_search
from elixir.wrapper import ElixirModule, ElixirOptimizer
# get the world communication group
global_group = dist.GroupMember.WORLD
# get the communication world size
global_size = dist.get_world_size()
# initialize the model on CPU
model = get_model(model_name)
# HybridAdam allows part of the parameters to be updated on CPU and the rest on GPU
optimizer = HybridAdam(model.parameters(), lr=1e-3)
sr = optimal_search(
model,
global_size,
unified_dtype=torch.float16, # enable for FP16 training
overlap=True, # enable for overlapping communications
verbose=True, # print detailed processing information
inp=data, # provide example input data as a dictionary
step_fn=train_step # provide an example step function
)
model = ElixirModule(
model,
sr,
global_group,
prefetch=True, # prefetch chunks to overlap communications
dtype=torch.float16, # use AMP
use_fused_kernels=True # enable fused kernels in Apex
)
optimizer = ElixirOptimizer(
model,
optimizer,
initial_scale=64, # loss scale used in AMP
init_step=True # enable for the stability of training
)
```

View File

@ -0,0 +1 @@
from .wrapper import ElixirModule, ElixirOptimizer

View File

@ -0,0 +1,2 @@
from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
from .fetcher import ChunkFetcher

View File

@ -0,0 +1,4 @@
from .chunk import Chunk
from .group import ChunkGroup
from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState

View File

@ -0,0 +1,567 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.tensor import FakeTensor
from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState, ts_update_sanity_check
class ChunkFullError(Exception):
pass
@dataclass
class TensorInfo:
state: TensorState
fake_data: FakeTensor
offset: int
end: int
class Chunk:
"""Chunk is a type of data structure to store tensors.
It allows us to store a sequence of tensors into one continuous memory block.
Moreover, Chunk manages the storage of tensors in a distributed way.
Normally, a chunk is scattered across its process group.
When a tensor in this chunk should be used later, the chunk can be gathered by access_chunk.
When the training is done, the chunk can be scattered by reduce_chunk.
args:
rcache: the memory pool to store replicated chunks
chunk_size: the size of the chunk
chunk_dtype: the dtype of the chunk
process_group: the torch communication group of the chunk
temp_device: the device to store the temporary chunk when initializing
shard_device: the device to store the shard of the scattered chunk
rcache_fused: whether this chunk is fused in rcache without eviction
cpu_pin_memory: whether this chunk uses pinned CPU memory for its shard
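
Example (an illustrative sketch of the chunk lifecycle; `rcache`, `pg`, `p1` and `p2` are placeholders):
    chunk = Chunk(rcache, chunk_size=1024, chunk_dtype=torch.float, process_group=pg)
    chunk.append_tensor(p1)
    chunk.append_tensor(p2)
    chunk.close_chunk()                  # scatter the temporary building buffer into per-rank shards
    block = rcache.get_public_block()
    chunk.access_chunk(block)            # gather the chunk into the rcache block before compute
    chunk.reduce_chunk()                 # reduce-scatter the gradients after backward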
"""
total_count = 0
def __init__(
self,
rcache: MemoryPool,
chunk_size: int,
chunk_dtype: torch.dtype,
process_group: ProcessGroup,
temp_device: Optional[torch.device] = None,
shard_device: Optional[torch.device] = None,
rcache_fused: bool = False, # whether this chunk is used in ZeRO2
cpu_pin_memory: bool = False # whether this chunk has a permanent copy in cpu
) -> None:
self.chunk_id: int = Chunk.total_count
Chunk.total_count += 1
# set replicated cache pool
self.rcache: MemoryPool = rcache
self.chunk_size: int = chunk_size
self.chunk_dtype: torch.dtype = chunk_dtype
self.utilized_size: int = 0
self.torch_pg: ProcessGroup = process_group
self.pg_size: int = dist.get_world_size(self.torch_pg)
self.pg_rank: int = dist.get_rank(self.torch_pg)
# the chunk size should be divisible by the dp degree
assert chunk_size % self.pg_size == 0
self.shard_size: int = chunk_size // self.pg_size
self.shard_begin: int = self.shard_size * self.pg_rank
self.shard_end: int = self.shard_begin + self.shard_size
self.valid_end: int = self.shard_size + 1 # set to an illegal number
# notice: release blocks reserved by Pytorch
torch.cuda.empty_cache()
# rcache block, the global replicated chunk in R cache
self.rcb: Optional[TensorBlock] = None
self.rcache_fused: bool = rcache_fused
self._my_block = None
self.is_replica: bool = True
# allocate a private block for fused chunks
if self.rcache_fused:
self._my_block = rcache.get_private_block(chunk_size, chunk_dtype)
temp_device: torch.device = temp_device or gpu_device()
# chunk_temp is a global chunk, which only exists during building the chunks.
# keep all elements to zero
self.chunk_temp: Optional[torch.Tensor] = None
if rcache_fused:
self.chunk_temp = self._my_block.payload
torch.zero_(self.chunk_temp)
else:
self.chunk_temp = torch.zeros(chunk_size, dtype=chunk_dtype, device=temp_device)
# configure the init device of the shard
# no-offload default: fp16, fp32 -> CUDA
# offload default: fp16, fp32 -> CPU
shard_device: torch.device = shard_device or torch.device('cpu')
pin_flag: bool = cpu_pin_memory and shard_device.type == 'cpu'
# chunk.shard is a local chunk
# it is designed to exist permanently
self.shard: torch.Tensor = torch.empty(self.shard_size,
dtype=chunk_dtype,
device=shard_device,
pin_memory=pin_flag)
# calculate the memory occupation of the chunk and the shard
self.chunk_memo: int = self.chunk_size * self.chunk_temp.element_size()
self.shard_memo: int = self.chunk_memo // self.pg_size
# each tensor is associated with a TensorInfo to track its meta info
# (state, shape, offset, end)
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
# the total number of tensors in the chunk
self.num_tensors: int = 0
# Record the number of tensors in different states
self.tensor_state_cnter: Dict[TensorState, int] = dict()
for state in TensorState:
self.tensor_state_cnter[state] = 0
# we introduce the paired chunk here
# it refers to another chunk having the same parameters
# but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
self.paired_chunk = None
# if this chunk is synchronized with the optimizer, the flag is True
self.optim_sync_flag = True
# whether to record l2 norm for the gradient clipping calculation
self.l2_norm_flag = False
self.l2_norm = None
# whether it overflows after the reduction
self.overflow = False
@property
def prepared_block(self):
return self._my_block
@property
def is_init(self):
return self.chunk_temp is not None
@property
def in_rcache(self):
return self.rcb is not None
@property
def shard_device(self):
return self.shard.device
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
cpu_memory = 0
# this chunk is not closed
if self.is_init:
if self.chunk_temp.device.type == 'cuda':
cuda_memory += self.chunk_memo
else:
cpu_memory += self.chunk_memo
# this chunk is on the rcache
if self.in_rcache:
cuda_memory += self.rcb.memo_occ
# calculate the occupation of the chunk shard
if self.shard_device.type == 'cuda':
cuda_memory += self.shard_memo
elif self.shard_device.type == 'cpu':
cpu_memory += self.shard_memo
else:
raise NotImplementedError
return dict(cuda=cuda_memory, cpu=cpu_memory)
@property
def payload(self) -> torch.Tensor:
if self.is_init:
return self.chunk_temp
if self.in_rcache:
return self.rcb.payload
else:
return self.shard
@property
def shard_move_check(self) -> bool:
return not self.in_rcache
def _not_compute_number(self):
total = 0
state_list = [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE]
for state in state_list:
total += self.tensor_state_cnter[state]
return total
@property
def scatter_check(self) -> bool:
if self.rcache_fused:
return False
return self._not_compute_number() == self.num_tensors
@property
def reduce_check(self):
return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors
def set_overflow_flag(self, valid_tensor: torch.Tensor) -> None:
assert not self.overflow
self.overflow = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def set_l2_norm(self, valid_tensor: torch.Tensor) -> None:
assert self.l2_norm is None, 'you are calculating the l2 norm twice'
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2
def append_tensor(self, tensor: torch.Tensor):
# sanity check
assert self.is_init
assert tensor.dtype == self.chunk_dtype
new_utilized_size = self.utilized_size + tensor.numel()
# raise exception when the chunk size is exceeded
if new_utilized_size > self.chunk_size:
raise ChunkFullError
self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
fake_data = FakeTensor(tensor.data)
# record all the information about the tensor
self.num_tensors += 1
tensor_state = TensorState.HOLD
self.tensor_state_cnter[tensor_state] += 1
self.tensors_info[tensor] = TensorInfo(state=tensor_state,
fake_data=fake_data,
offset=self.utilized_size,
end=new_utilized_size)
self.utilized_size = new_utilized_size
def close_chunk(self):
# sanity check
assert self.is_init
# calculate the valid end for each shard
if self.utilized_size <= self.shard_begin:
self.valid_end = 0
elif self.utilized_size < self.shard_end:
self.valid_end = self.utilized_size - self.shard_begin
self.__remove_tensors_ptr()
self.__update_shard(self.chunk_temp, self.shard)
self.is_replica = False
self.chunk_temp = None
def replicate(self):
assert not self.is_replica
self.is_replica = True
this_shard = self.shard if self.optim_sync_flag else self.__paired_shard()
self.__update_replica(self.rcb.payload, this_shard)
self.__update_tensors_ptr()
def scatter(self):
assert not self.rcache_fused
assert self.is_replica
self.__remove_tensors_ptr()
if not self.optim_sync_flag:
self.__update_shard(self.rcb.payload, self.shard)
self.optim_sync_flag = True
self.is_replica = False
def reduce(self, always_fp32: bool = False):
assert self.is_replica
self.__remove_tensors_ptr()
if self.pg_size > 1:
cast_to_fp32 = False
if always_fp32 and self.chunk_dtype != torch.float:
cast_to_fp32 = True
# cast the payload to fp32
reduce_buffer = self.rcb.payload.to(dtype=torch.float)
else:
# otherwise, use the same payload
reduce_buffer = self.rcb.payload
# divide the reduce buffer by the size of the process group
reduce_buffer /= self.pg_size
# try to use inplace reduce scatter
# notice: pytorch does not allow true inplace reduce scatter
# because pytorch will allocate a continuous memory space for collective communications
shard_buffer = reduce_buffer[self.shard_begin:self.shard_end]
dist.reduce_scatter_tensor(shard_buffer, reduce_buffer, group=self.torch_pg)
# the result should be moved to payload for norm calculating
if cast_to_fp32:
calc_buffer = self.rcb.payload[self.shard_begin:self.shard_end]
calc_buffer.copy_(shard_buffer)
else:
# if process group size equals to 1, do not communicate
reduce_buffer = self.rcb.payload
self.__update_shard(reduce_buffer, self.shard)
self.is_replica = False
def access_chunk(self, block: Optional[TensorBlock] = None):
# sanity check
assert not self.is_init
assert not self.is_replica
if self.rcache_fused:
assert block is None
self.rcb = self._my_block
else:
assert block in self.rcache.public_used_blocks
assert self.rcb is None
self.rcb = block
self.replicate()
def release_chunk(self) -> TensorBlock:
# sanity check
assert not self.is_init
assert self.is_replica
if self.rcache_fused:
raise RuntimeError
self.scatter()
block = self.rcb
self.rcb = None
return block
def update_extra_reduce_info(self, block: Optional[TensorBlock]):
if self.rcache_fused:
assert block is None
block = self._my_block
else:
assert block is not None
buffer = block.payload[self.shard_begin:self.shard_end]
valid_tensor = buffer[:self.valid_end]
self.set_overflow_flag(valid_tensor)
if self.l2_norm_flag:
self.set_l2_norm(valid_tensor)
def reduce_chunk(self, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
"""Reduce scatter all the gradients. It's an operation done in CUDA.
"""
# sanity check
assert not self.is_init
assert self.is_replica
self.reduce(always_fp32=always_fp32)
self.__update_tensors_state(TensorState.HOLD)
# reset the rcb pointer
block = self.rcb
self.rcb = None
if self.rcache_fused:
block = None
if sync:
self.update_extra_reduce_info(block)
return block
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
prev_state = self.tensors_info[tensor].state
if prev_state == tensor_state:
return
if ts_update_sanity_check(prev_state, tensor_state):
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
# sanity check
assert self.is_replica
info = self.tensors_info[tensor]
payload = self.rcb.payload
payload[info.offset:info.end].copy_(data_slice.data.flatten())
tensor.data = payload[info.offset:info.end].view(tensor.shape)
def init_pair(self, friend_chunk: 'Chunk') -> None:
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
self.paired_chunk = friend_chunk
friend_chunk.paired_chunk = self
else:
assert self.paired_chunk is friend_chunk
assert friend_chunk.paired_chunk is self
def optim_update(self) -> None:
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
"""
# sanity check
assert self.paired_chunk is not None
friend_chunk: Chunk = self.paired_chunk
assert not friend_chunk.is_replica
# gradient and optimizer should be on the same device
assert self.shard_device.type == friend_chunk.shard_device.type
if self.shard_device.type == 'cuda':
self.shard.copy_(friend_chunk.shard)
self.optim_sync_flag = True
elif self.shard_device.type == 'cpu':
# optim_sync_flag is set to False
# see shard_move function for more details
self.optim_sync_flag = False
else:
raise NotImplementedError
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def get_cpu_copy(self, only_rank_0: bool = False) -> List[torch.Tensor]:
assert not self.is_init
if self.is_replica:
# use the payload directly when being replica
temp_buffer = self.rcb.payload
else:
# otherwise, create a temporary buffer
temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
# cheat the assertion in __update_replica
self.is_replica = True
self.__update_replica(temp_buffer, self.shard)
self.is_replica = False
cpu_copys = [None] * self.num_tensors
if not only_rank_0 or self.pg_rank == 0:
for i, (t, info) in enumerate(self.tensors_info.items()):
t_copy = temp_buffer[info.offset:info.end].view(t.shape).cpu()
cpu_copys[i] = t_copy
# synchronize
dist.barrier()
return cpu_copys
def load_tensors(self, tensor_list: List[Optional[torch.Tensor]], only_rank_0: bool = False) -> bool:
assert not self.is_replica
assert not self.is_init
temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
# cheat the assertion in __update_replica
self.is_replica = True
self.__update_replica(temp_buffer, self.shard)
self.is_replica = False
if not only_rank_0 or self.pg_rank == 0:
for (_, c_info), load_tensor in zip(self.tensors_info.items(), tensor_list):
if load_tensor is None:
continue
temp_buffer[c_info.offset:c_info.end].copy_(load_tensor.data.flatten())
# synchronize
dist.barrier()
if only_rank_0:
dist.broadcast(temp_buffer, src=0, group=self.torch_pg)
# cheat the assertion in __update_shard
self.is_replica = True
self.__update_shard(temp_buffer, self.shard)
self.is_replica = False
def __update_replica(self, replica: torch.Tensor, shard: torch.Tensor):
assert self.is_replica
assert replica.numel() == self.chunk_size
assert shard.numel() == self.shard_size
buffer = replica[self.shard_begin:self.shard_end]
buffer.copy_(shard)
dist.all_gather_into_tensor(replica, buffer, group=self.torch_pg)
def __update_shard(self, replica: torch.Tensor, shard: torch.Tensor):
assert self.is_replica
assert replica.numel() == self.chunk_size
assert shard.numel() == self.shard_size
shard.copy_(replica[self.shard_begin:self.shard_end])
def __paired_shard(self):
assert self.paired_chunk is not None, 'chunks should be paired before training'
optim_chunk: Chunk = self.paired_chunk
assert self.chunk_size == optim_chunk.chunk_size
# only be called when optimizer state is in CPU memory
# the grad and param should be in the same device
assert self.shard_device.type == 'cpu'
return optim_chunk.shard.to(gpu_device())
def __remove_tensors_ptr(self) -> None:
# sanity check
# each tensor should point to its fake data before scatter
assert self.is_replica
for tensor, info in self.tensors_info.items():
tensor.data = info.fake_data
def __update_tensors_ptr(self) -> None:
# sanity check
# the chunk should be replicated to get the correct pointer
assert self.is_replica
payload = self.rcb.payload
for tensor, info in self.tensors_info.items():
tensor.data = payload[info.offset:info.end].view(tensor.shape)
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
self.tensor_state_cnter[tensor_info.state] -= 1
tensor_info.state = next_state
self.tensor_state_cnter[tensor_info.state] += 1
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
if prev_state is None or tensor_info.state == prev_state:
self.__update_one_tensor_info(tensor_info, next_state)
def __hash__(self) -> int:
return self.chunk_id
def __lt__(self, other: object) -> bool:
return self.chunk_id < other.chunk_id
def __eq__(self, other: object) -> bool:
return self.chunk_id == other.chunk_id
def __repr__(self, detailed: bool = True):
if self.is_init:
state = 'initialization'
elif self.in_rcache:
state = 'replicated'
else:
state = 'scattered'
output = [
f'Chunk {self.chunk_id} details: state -> {state}\n',
f' length: {self.chunk_size}, dtype: {self.chunk_dtype}, group_size: {self.pg_size}, tensors: {self.num_tensors}\n'
f' utilized size: {self.utilized_size}, utilized percentage: {100 * (self.utilized_size / self.chunk_size):.0f}%\n'
]
memory_info = self.memory_usage
output.append(' memory usage: (cuda -> {}, cpu -> {})\n'.format(memory_info['cuda'], memory_info['cpu']))
def print_tensor(name, tensor, prefix=''):
output.append(f'{prefix}{name}: (shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device})\n')
if self.is_init:
print_tensor(name='temp', tensor=self.chunk_temp, prefix=' ')
if self.in_rcache:
print_tensor(name='block', tensor=self.rcb.payload, prefix=' ')
if self.shard is not None:
print_tensor(name='shard', tensor=self.shard, prefix=' ')
if detailed:
output.append(' tensor state monitor:\n')
for st in TensorState:
output.append(' # of {}: {}\n'.format(st, self.tensor_state_cnter[st]))
return ''.join(output)

View File

@ -0,0 +1,180 @@
from typing import Dict, List, Optional, Set
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from .chunk import Chunk
from .memory_pool import MemoryPool, TensorBlock
from .states import TensorState
class ChunkGroup(object):
"""ChunkGroup manages a group of chunks and their memory pool.
Commonly, one model has one chunk group.
It supports chunk allocation, chunk access, and chunk release.
Users are responsible for memory management before calling its APIs, e.g. allocating the memory pool and checking `rcache_enough_check`.
args:
rcache: A memory pool to instantiate chunks.
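
Example (an illustrative sketch; `params` and `pg` are placeholders):
    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)
    group = ChunkGroup(rcache=mp)
    chunk = group.allocate_chunk(params, chunk_size=1024, chunk_dtype=torch.float, process_group=pg)
    group.access_chunk(chunk)      # gather the chunk into rCache before compute
    group.release_chunk(chunk)     # scatter it back to shards once it is releasable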
"""
def __init__(self, rcache: MemoryPool) -> None:
super().__init__()
self.rcache = rcache
self.fused_chunks: Set[Chunk] = set()
self.float_chunks: Set[Chunk] = set()
self.ten_to_chunk: Dict[torch.Tensor, Chunk] = dict()
self.accessed_fused_chunks: Set[Chunk] = set()
self.accessed_float_chunks: Set[Chunk] = set()
def __add_to_accset(self, chunk: Chunk):
if chunk.rcache_fused:
self.accessed_fused_chunks.add(chunk)
else:
self.accessed_float_chunks.add(chunk)
def __remove_from_accset(self, chunk: Chunk):
if chunk.rcache_fused:
self.accessed_fused_chunks.remove(chunk)
else:
self.accessed_float_chunks.remove(chunk)
def __check_new_float_chunk(self, size: int, dtype: torch.dtype):
# if the public space is 0, there is no access operations
if self.rcache.public_space == 0:
return
# otherwise, check its size and dtype
assert size == self.rcache.public_block_size
assert dtype == self.rcache.public_dtype
def inside_check(self, chunk: Chunk) -> None:
"""Check whether the chunk is in this ChunkGroup"""
if chunk.rcache_fused:
assert chunk in self.fused_chunks
else:
assert chunk in self.float_chunks
def is_accessed(self, chunk: Chunk) -> bool:
"""Chech whether the chunk is accessed."""
# sanity check
self.inside_check(chunk)
if chunk.rcache_fused:
return (chunk in self.accessed_fused_chunks)
else:
return (chunk in self.accessed_float_chunks)
def open_chunk(self,
chunk_size: int,
chunk_dtype: torch.dtype,
process_group: ProcessGroup,
chunk_config: Optional[Dict] = None) -> Chunk:
"""Open a chunk to store parameters."""
if chunk_config is None:
chunk_config = {}
chunk = Chunk(rcache=self.rcache,
chunk_size=chunk_size,
chunk_dtype=chunk_dtype,
process_group=process_group,
**chunk_config)
# sanity check
if not chunk.rcache_fused:
self.__check_new_float_chunk(chunk_size, chunk_dtype)
return chunk
def close_chunk(self, chunk: Chunk) -> bool:
"""Close the chunk during the allocation."""
chunk.close_chunk()
# add the new chunk to the set of allocated chunks
if chunk.rcache_fused:
self.fused_chunks.add(chunk)
else:
self.float_chunks.add(chunk)
# add the new chunk to the mapping
for t in chunk.get_tensors():
assert t not in self.ten_to_chunk
self.ten_to_chunk[t] = chunk
return True
def allocate_chunk(self,
tensor_list: List[torch.Tensor],
chunk_size: int,
chunk_dtype: torch.dtype,
process_group: ProcessGroup,
chunk_config: Optional[Dict] = None) -> Chunk:
"""Allocate a chunk for a list of parameters."""
chunk = self.open_chunk(chunk_size, chunk_dtype, process_group, chunk_config)
# append tensors
for t in tensor_list:
chunk.append_tensor(t)
self.close_chunk(chunk)
return chunk
def tensors_to_chunks(self, tensor_list: List[torch.Tensor]) -> List[Chunk]:
"""Get the chunks of a gevien list of tensors."""
chunk_list = list()
for tensor in tensor_list:
chunk = self.ten_to_chunk.get(tensor)
if chunk not in chunk_list:
chunk_list.append(chunk)
chunk_list.sort(key=lambda c: c.chunk_id)
return chunk_list
def rcache_enough_check(self, chunk: Chunk) -> bool:
"""Check whether the rcache has enough blocks to store the gathered chunk."""
if chunk.rcache_fused:
return True
return self.rcache.public_free_cnt > 0
def access_chunk(self, chunk: Chunk) -> bool:
"""Access a chunk into rCache."""
self.inside_check(chunk)
# if this chunk is accessed already, return False
if self.is_accessed(chunk):
return False
if chunk.rcache_fused:
block = None
else:
block = self.rcache.get_public_block()
chunk.access_chunk(block)
self.__add_to_accset(chunk)
return True
def release_chunk(self, chunk: Chunk) -> bool:
"""Release a chunk from rCache."""
self.inside_check(chunk)
assert self.is_accessed(chunk)
assert chunk.scatter_check
block = chunk.release_chunk()
if block:
self.rcache.free_public_block(block)
self.__remove_from_accset(chunk)
return True
def reduce_chunk(self, chunk: Chunk, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
"""Reduce and scatter a gradient chunk from rCache."""
self.inside_check(chunk)
assert self.is_accessed(chunk)
assert chunk.reduce_check
block = chunk.reduce_chunk(always_fp32=always_fp32, sync=sync)
if block and sync:
# if synchronized, free the block into rcache
self.rcache.free_public_block(block)
block = None
self.__remove_from_accset(chunk)
return block
def tensor_trans_state(self, tensor: torch.Tensor, state: TensorState):
"""Transform the state of a tensor."""
chunk = self.ten_to_chunk.get(tensor)
chunk.tensor_trans_state(tensor, state)

View File

@ -0,0 +1,172 @@
from abc import ABC
from collections import defaultdict
from typing import Iterable, NamedTuple
import torch
from torch.autograd.profiler_util import _format_memory
class BlockRequire(NamedTuple):
numel: int
dtype: torch.dtype
class TensorBlock(ABC):
"""TensorBlock is the memory unit of memory pool.
It is a continuous memory block used to store tensors.
Each chunk needs a corresponding TensorBlock to store its data during training.
args:
numel: the number of elements in the block
dtype: the data type of the block
device_type: the device type of the block
"""
total_count: int = 0
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
self.block_id = TensorBlock.total_count
TensorBlock.total_count += 1
self.device_type = device_type
self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
self.memo_occ: int = self.payload.numel() * self.payload.element_size()
@property
def numel(self):
return self.payload.numel()
@property
def dtype(self):
return self.payload.dtype
@property
def device(self):
return self.payload.device
def __hash__(self) -> int:
return self.block_id
def __eq__(self, other: object) -> bool:
return self.block_id == other.block_id
def __repr__(self) -> str:
return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
class PublicBlock(TensorBlock):
"""Public blocks have the same length.
Chunks of the same length can share the same public block.
"""
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
super().__init__(numel, dtype, device_type)
self.block_type = 'public'
def __repr__(self) -> str:
return f'PublicBlock{super().__repr__()}'
class PrivateBlock(TensorBlock):
"""Private blocks may have different lengths.
Each private chunk should use its own private block.
"""
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
super().__init__(numel, dtype, device_type)
self.block_type = 'private'
def __repr__(self) -> str:
return f'PrivateBlock{super().__repr__()}'
class MemoryPool(object):
"""A memory pool consists of public blocks and private blocks.
rCache uses the memory pool to manage memory blocks.
Users should allocate memory blocks before using the pool.
args:
device_type: the device type of the memory pool
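
Example (an illustrative sketch):
    mp = MemoryPool('cuda')
    mp.allocate(public_dtype=torch.float16,
                public_block_size=1024,
                public_block_number=2,
                private_block_list=[BlockRequire(512, torch.float16)])
    block = mp.get_public_block()     # borrow a public block for a gathered chunk
    mp.free_public_block(block)       # return it once the chunk is scattered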
"""
def __init__(self, device_type: str) -> None:
self.device_type: str = device_type
self.public_space: int = 0
self.public_block_size: int = 0
self.public_dtype: torch.dtype = None
self.public_free_blocks: list = None
self.public_used_blocks: set = None
self.public_free_cnt: int = 0
self.public_used_cnt: int = 0
self.private_space: int = 0
self.private_blocks: list = None
self.private_lookup_dict: dict[BlockRequire, list] = None
self.__allocate_flag = False
def allocate(self,
public_dtype: torch.dtype = torch.float,
public_block_size: int = 1024,
public_block_number: int = 0,
private_block_list: Iterable[BlockRequire] = ()):
assert self.__allocate_flag is False
assert public_block_number >= 0
self.public_free_blocks = list()
self.public_used_blocks = set()
for _ in range(public_block_number):
block = PublicBlock(public_block_size, public_dtype, self.device_type)
self.public_free_blocks.append(block)
if public_block_number <= 0:
self.public_space = 0
else:
self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
self.public_block_size = public_block_size
self.public_dtype = public_dtype
self.public_free_cnt = public_block_number
self.public_used_cnt = 0
self.private_space = 0
self.private_blocks = list()
self.private_lookup_dict = defaultdict(list)
for require in private_block_list:
block = PrivateBlock(require.numel, require.dtype, self.device_type)
self.private_space += block.memo_occ
self.private_blocks.append(block)
self.private_lookup_dict[require].append(block)
self.__allocate_flag = True
def __repr__(self) -> str:
return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
def get_private_block(self, numel: int, dtype: torch.dtype):
block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
return block_list.pop()
def get_public_block(self):
self.public_free_cnt -= 1
self.public_used_cnt += 1
block = self.public_free_blocks.pop()
self.public_used_blocks.add(block)
return block
def free_public_block(self, block: TensorBlock):
assert isinstance(block, PublicBlock)
assert block in self.public_used_blocks
self.public_free_cnt += 1
self.public_used_cnt -= 1
self.public_used_blocks.remove(block)
self.public_free_blocks.append(block)
return block

View File

@ -0,0 +1,25 @@
from enum import Enum
class TensorState(Enum):
FREE = 0
COMPUTE = 1
HOLD = 2
HOLD_AFTER_BWD = 3
READY_FOR_REDUCE = 4
# expected: free -> hold -> compute -> hold ->
# -> compute -> hold_after_bwd -> ready_for_reduce
legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
(TensorState.READY_FOR_REDUCE, TensorState.HOLD)]
def ts_update_sanity_check(old_state, new_state) -> bool:
if (old_state, new_state) not in legal_ts_update_list:
raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}')
return True
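# Example (illustrative): during one iteration a parameter typically moves through
#   FREE -> HOLD -> COMPUTE (forward) -> HOLD -> COMPUTE (backward)
#   -> HOLD_AFTER_BWD -> READY_FOR_REDUCE -> HOLD
# ts_update_sanity_check(TensorState.HOLD, TensorState.COMPUTE)            # returns True
# ts_update_sanity_check(TensorState.FREE, TensorState.READY_FOR_REDUCE)   # raises RuntimeError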

View File

@ -0,0 +1,210 @@
from typing import List, Optional
import torch
from .core import Chunk, ChunkGroup, TensorBlock, TensorState
from .scheduler import ChunkScheduler
class ChunkFetcher(object):
"""ChunkFetcher is responsible for fetching and reducing chunks during training.
Any operations on chunks should be done through ChunkFetcher.
args:
scheduler: A ChunkScheduler to schedule evictable chunks.
group: A ChunkGroup to manage chunks.
overlap: Whether to overlap communications.
reduce_always_fp32: Whether to reduce gradients in FP32.
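
Example (an illustrative sketch around one compute operator; `params` is a placeholder):
    chunks = fetcher.trans_to_compute(params)   # mark the parameters as COMPUTE
    fetcher.fetch_chunks(chunks)                # make sure their chunks are on rCache
    ...                                         # run the operator
    fetcher.trans_to_hold(params, phase='f')    # 'f' after forward, 'b' after backward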
"""
def __init__(self,
scheduler: ChunkScheduler,
group: ChunkGroup,
overlap: bool = False,
reduce_always_fp32: bool = False) -> None:
self.scheduler: ChunkScheduler = scheduler
self.group: ChunkGroup = group
self.reduce_always_fp32 = reduce_always_fp32
self.current_step = -1
self.overlap_flag = overlap
self.main_stream = torch.cuda.current_stream()
self.predict_next_chunk: Optional[Chunk] = None
self.is_fetching: bool = False
self.prefetch_stream = torch.cuda.Stream()
self.reduced_chunk: Optional[Chunk] = None
self.reduced_block: Optional[TensorBlock] = None
self.reduce_stream = torch.cuda.Stream()
def reset(self):
"""Reset the fetcher to the initial state.
Users should call this function before training."""
self.scheduler.reset()
self.current_step = -1
def clear(self):
"""Clear the fetcher.
Users should call this function after training."""
if self.overlap_flag:
torch.cuda.synchronize()
self.predict_next_chunk = None
self.is_fetching = False
if self.reduced_chunk is not None:
self.reduce_call_back()
self.reduced_chunk = None
self.reduced_block = None
self.scheduler.clear()
def trans_to_compute(self, tensors: List[torch.Tensor]):
"""Transform tensors to COMPUTE state.
This function should be called before the compute operators."""
# update tensor states
for t in tensors:
self.group.tensor_trans_state(t, TensorState.COMPUTE)
# chunk operations
chunks = self.group.tensors_to_chunks(tensors)
for chunk in chunks:
self.scheduler.remove(chunk)
return chunks
def trans_to_hold(self, tensors: List[torch.Tensor], phase: str):
"""Transform tensors to HOLD state.
This function should be called after the compute operators."""
assert phase in ('f', 'b')
next_state = TensorState.HOLD if phase == 'f' else TensorState.HOLD_AFTER_BWD
# update tensor states
for t in tensors:
self.group.tensor_trans_state(t, next_state)
# chunk operations
chunks = self.group.tensors_to_chunks(tensors)
for chunk in chunks:
if chunk.scatter_check:
self.scheduler.add(chunk)
def get_one_chunk(self, tensor: torch.Tensor) -> Chunk:
"""Get the chunk of the given tensor."""
return self.group.ten_to_chunk.get(tensor)
def get_chunks(self, tensors: List[torch.Tensor]) -> List[Chunk]:
"""Get the chunks of the given tensors."""
return self.group.tensors_to_chunks(tensors)
def is_in_fused(self, tensor: torch.Tensor):
"""Check whether the given tensor is in a fused chunk."""
chunk = self.get_one_chunk(tensor)
return chunk.rcache_fused
def filter_chunks(self, chunks: List[Chunk]):
"""Filter the accessed chunks, since they are already on the rCache."""
return list(filter(lambda c: not self.group.is_accessed(c), chunks))
def fetch_chunks(self, chunks: List[Chunk]):
"""Fetch chunks needed for this compute operator.
The chunks should be in the COMPUTE state first."""
# make step + 1
self.step()
predict_hit = False
# try to prefetch the next chunk
if self.predict_next_chunk is not None and self.predict_next_chunk in chunks:
if self.is_fetching:
# prefetch hit, wait async prefetch
self.main_stream.wait_stream(self.prefetch_stream)
# clear prefetch information
self.predict_next_chunk = None
self.is_fetching = False
predict_hit = True
# filter accessed chunks
scattered = self.filter_chunks(chunks)
# sanity check: upload should wait for prefetch
if self.predict_next_chunk is not None:
assert len(scattered) == 0
# all chunks are on the rcache
if len(scattered) == 0:
# prefetch if there is a hit above
if predict_hit:
self.prefetch(chunks)
return
for chunk in scattered:
# if the rcache is not enough, just release a chunk
if not self.group.rcache_enough_check(chunk):
maybe_chunk = self.scheduler.top()
# print(f'Evicting {chunk.chunk_id} -> {maybe_chunk.chunk_id}')
if maybe_chunk is None:
raise RuntimeError('R cache is not enough. Try to allocate more.')
self.scheduler.remove(maybe_chunk)
self.group.release_chunk(maybe_chunk)
# print('Accessing', chunk.chunk_id)
self.group.access_chunk(chunk)
if self.overlap_flag:
assert self.predict_next_chunk is None
self.prefetch(chunks)
def reduce_call_back(self):
self.reduced_chunk.update_extra_reduce_info(self.reduced_block)
if self.reduced_block is not None:
self.group.rcache.free_public_block(self.reduced_block)
def reduce_chunk(self, chunk: Chunk):
"""Reduce and scatter the given gradient chunk."""
if not chunk.reduce_check:
return False
self.scheduler.remove(chunk)
if not self.overlap_flag:
# reduce the chunk if not overlapped
self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=True)
else:
# wait main stream for its computation
self.reduce_stream.wait_stream(self.main_stream)
# main stream should wait reduce stream
# if there is a block recycle
if self.reduced_chunk is not None:
self.main_stream.wait_stream(self.reduce_stream)
self.reduce_call_back()
with torch.cuda.stream(self.reduce_stream):
self.reduced_chunk = chunk
self.reduced_block = self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=False)
def prefetch(self, chunks: List[Chunk]):
"""Prefetch the next used chunk."""
next_chunk = self.scheduler.get_next_chunk(chunks)
self.predict_next_chunk = next_chunk
# return if there is no next scattered chunk
if next_chunk is None or self.group.is_accessed(next_chunk):
return
evict_chunk = None
if not self.group.rcache_enough_check(next_chunk):
maybe_chunk = self.scheduler.top()
# if there is no chunk can be evicted, just return
if maybe_chunk is None:
return
# otherwise, release this chunk
self.scheduler.remove(maybe_chunk)
evict_chunk = maybe_chunk
with torch.cuda.stream(self.prefetch_stream):
# wait main stream
self.prefetch_stream.wait_stream(self.main_stream)
self.is_fetching = True
if evict_chunk is not None:
self.group.release_chunk(evict_chunk)
self.group.access_chunk(next_chunk)
def step(self):
"""Update the scheduler."""
self.scheduler.step()
self.current_step += 1

View File

@ -0,0 +1,3 @@
from .base import ChunkScheduler
from .fifo import FIFOScheduler
from .prefetch import PrefetchScheduler

View File

@ -0,0 +1,50 @@
from abc import ABC, abstractmethod
from typing import Optional
from colossalai.elixir.chunk.core import Chunk
class ChunkScheduler(ABC):
"""The base class of all chunk schedulers.
A chunk scheduler stores all releasable chunks.
It provides APIs to add, remove, and inspect releasable chunks.
"""
def __init__(self) -> None:
super().__init__()
self.releasable_set: Optional[set] = None
self.current_step = -1
@abstractmethod
def reset(self) -> None:
self.releasable_set = set()
self.current_step = -1
@abstractmethod
def clear(self) -> None:
# ensure the set is empty now
assert not bool(self.releasable_set)
@abstractmethod
def top(self) -> Optional[Chunk]:
# return None if the releasable set is empty
if not self.releasable_set:
return False
return True
@abstractmethod
def add(self, chunk: Chunk) -> bool:
if chunk in self.releasable_set:
return False
self.releasable_set.add(chunk)
return True
@abstractmethod
def remove(self, chunk: Chunk) -> bool:
if chunk not in self.releasable_set:
return False
self.releasable_set.remove(chunk)
return True
def step(self, *args, **kwargs):
self.current_step += 1

View File

@ -0,0 +1,40 @@
from typing import Optional
from .base import Chunk, ChunkScheduler
class FIFOScheduler(ChunkScheduler):
"""The FIFO chunk scheduler.
It stores all releasable chunks in a FIFO queue.
"""
def __init__(self) -> None:
super().__init__()
self.fifo_dict: Optional[dict] = None
def reset(self) -> None:
super().reset()
self.fifo_dict = dict()
def clear(self) -> None:
super().clear()
self.fifo_dict = None
def top(self) -> Optional[Chunk]:
if not super().top():
return None
dict_iter = iter(self.fifo_dict)
ret = next(dict_iter)
return ret
def add(self, chunk: Chunk) -> bool:
if not super().add(chunk):
return False
self.fifo_dict[chunk] = True
return True
def remove(self, chunk: Chunk) -> bool:
if not super().remove(chunk):
return False
self.fifo_dict.pop(chunk)
return True
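# Example (illustrative): the scheduler evicts the oldest releasable chunk first.
#   scheduler = FIFOScheduler()
#   scheduler.reset()
#   scheduler.add(chunk_a)
#   scheduler.add(chunk_b)
#   scheduler.top()      # -> chunk_a, the earliest chunk that became releasable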

View File

@ -0,0 +1,85 @@
from collections import defaultdict
from typing import Iterable, List, Optional
import torch
from sortedcontainers import SortedSet
from .base import Chunk, ChunkScheduler
class PrefetchScheduler(ChunkScheduler):
"""The prefetch chunk scheduler.
Its top function returns the chunk that will be used furthest in the future.
"""
def __init__(self, chunk_called_per_step: List[Iterable[Chunk]]) -> None:
super().__init__()
self.chunk_mapping = None
self.evict_set = None
self.search_step = -1
self.chunks_per_step = chunk_called_per_step
self.total_steps = len(chunk_called_per_step)
self.next_step_dict = defaultdict(list)
# initialize the next_step dictionary
for i, c_list in enumerate(chunk_called_per_step):
for c in c_list:
self.next_step_dict[c].append(i)
def _get_next_step(self, chunk: Chunk):
step_list = self.next_step_dict[chunk]
for i in step_list:
if i > self.current_step:
return i
return self.total_steps
def reset(self) -> None:
super().reset()
self.chunk_mapping = dict()
self.evict_set = SortedSet()
self.search_step = -1
def clear(self) -> None:
super().clear()
if torch.is_grad_enabled():
assert self.current_step == self.total_steps - 1
self.chunk_mapping = None
self.evict_set = None
self.search_step = -1
def top(self) -> Optional[Chunk]:
if not super().top():
return None
next_step, chunk = self.evict_set[-1]
return chunk
def add(self, chunk: Chunk) -> bool:
if not super().add(chunk):
return False
value = (self._get_next_step(chunk), chunk)
self.chunk_mapping[chunk] = value
self.evict_set.add(value)
return True
def remove(self, chunk: Chunk) -> bool:
if not super().remove(chunk):
return False
value = self.chunk_mapping[chunk]
self.evict_set.remove(value)
self.chunk_mapping.pop(chunk)
return True
def step(self, *args, **kwargs):
super().step(*args, **kwargs)
if self.current_step >= self.total_steps:
raise RuntimeError('exceeded the number of simulated steps, please modify your profiling `step_fn`')
def get_next_chunk(self, chunks: List[Chunk]):
self.search_step = max(self.search_step, self.current_step + 1)
while self.search_step < self.total_steps:
c_list = self.chunks_per_step[self.search_step]
for c in c_list:
if c not in chunks:
return c
self.search_step += 1
return None

View File

@ -0,0 +1,32 @@
import torch
tensor_creation_methods = dict(tensor=torch.tensor,
sparse_coo_tensor=torch.sparse_coo_tensor,
asarray=torch.asarray,
as_tensor=torch.as_tensor,
as_strided=torch.as_strided,
from_numpy=torch.from_numpy,
from_dlpack=torch.from_dlpack,
frombuffer=torch.frombuffer,
zeros=torch.zeros,
zeros_like=torch.zeros_like,
ones=torch.ones,
ones_like=torch.ones_like,
arange=torch.arange,
range=torch.range,
linspace=torch.linspace,
logspace=torch.logspace,
eye=torch.eye,
empty=torch.empty,
empty_like=torch.empty_like,
empty_strided=torch.empty_strided,
full=torch.full,
full_like=torch.full_like,
quantize_per_tensor=torch.quantize_per_tensor,
quantize_per_channel=torch.quantize_per_channel,
dequantize=torch.dequantize,
complex=torch.complex,
polar=torch.polar,
heaviside=torch.heaviside)
from .meta_ctx import MetaContext

View File

@ -0,0 +1,34 @@
import torch
from colossalai.elixir.ctx import tensor_creation_methods
class MetaContext(object):
"""A context manager that wraps all tensor creation methods in torch.
By default, all tensors will be created on the meta device.
args:
device_type: The device type of the tensors to be created.
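
Example (an illustrative sketch; `MyModel` is a placeholder):
    with MetaContext():
        model = MyModel()    # parameters are created on the meta device, so no real memory is used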
"""
def __init__(self, device_type: str = 'meta') -> None:
super().__init__()
self.device_type = device_type
return None
def __enter__(self):
def meta_wrap(func):
def wrapped_func(*args, **kwargs):
kwargs['device'] = self.device_type
return func(*args, **kwargs)
return wrapped_func
for name, method in tensor_creation_methods.items():
setattr(torch, name, meta_wrap(method))
def __exit__(self, exc_type, exc_val, exc_tb):
for name, method in tensor_creation_methods.items():
setattr(torch, name, method)

colossalai/elixir/cuda.py
View File

@ -0,0 +1,28 @@
import functools
import torch
from torch.cuda._utils import _get_device_index
elixir_cuda_fraction = dict()
@functools.lru_cache()
def gpu_device():
return torch.device(torch.cuda.current_device())
def set_memory_fraction(fraction, device=None):
torch.cuda.set_per_process_memory_fraction(fraction, device)
if device is None:
device = torch.cuda.current_device()
device = _get_device_index(device)
elixir_cuda_fraction[device] = fraction
def get_allowed_memory(device=None):
total_memory = torch.cuda.get_device_properties(device).total_memory
if device is None:
device = torch.cuda.current_device()
device = _get_device_index(device)
fraction = elixir_cuda_fraction.get(device, 1.0)
return int(fraction * total_memory)
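# Example (illustrative): cap Elixir's CUDA usage at 80% of the current device's memory.
#   set_memory_fraction(0.8)
#   budget_in_bytes = get_allowed_memory()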

View File

@ -0,0 +1,2 @@
from .parameter import HookParam
from .storage import BufferStore

View File

@ -0,0 +1,73 @@
import torch
from colossalai.elixir.chunk import ChunkFetcher
from .storage import BufferStore
def prefwd_postbwd_function(fetcher: ChunkFetcher, store: BufferStore):
class PreFwdPostBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, *args):
with torch._C.DisableTorchFunction():
ctx.params = params
chunks = fetcher.trans_to_compute(params)
fetcher.fetch_chunks(chunks)
offset = 0
for p in ctx.params:
if not fetcher.is_in_fused(p):
# we should add parameters to buffer
# because their blocks may be changed
offset = store.insert(p, offset)
return args
@staticmethod
def backward(ctx, *grads):
with torch._C.DisableTorchFunction():
fetcher.trans_to_hold(ctx.params, phase='b')
for p in ctx.params:
if not fetcher.is_in_fused(p):
store.erase(p)
return (None, *grads)
return PreFwdPostBwd.apply
def postfwd_prebwd_function(fetcher: ChunkFetcher, store: BufferStore):
class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, *args):
with torch._C.DisableTorchFunction():
ctx.params = params
fetcher.trans_to_hold(ctx.params, phase='f')
for p in ctx.params:
if not fetcher.is_in_fused(p):
store.erase(p)
return args
@staticmethod
def backward(ctx, *grads):
with torch._C.DisableTorchFunction():
chunks = fetcher.trans_to_compute(ctx.params)
fetcher.fetch_chunks(chunks)
offset = 0
for p in ctx.params:
if not fetcher.is_in_fused(p):
# we should add parameters to buffer
# because their blocks may be changed
offset = store.insert(p, offset)
return (None, *grads)
return PostFwdPreBwd.apply

View File

@ -0,0 +1,102 @@
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map
from colossalai.elixir.chunk import ChunkFetcher
from colossalai.elixir.kernels import fused_torch_functions
from colossalai.elixir.tensor import OutplaceTensor, is_no_hook_op, to_outplace_tensor
from .functions import postfwd_prebwd_function, prefwd_postbwd_function
from .storage import BufferStore
class HookParam(OutplaceTensor, nn.Parameter):
"""HookParam is a special type of tensor that is used to triggered hooks on parameters.
HookParam adds chunk fetching before torch functions.
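
Typical setup (an illustrative sketch; `fetcher` and `store` are placeholders):
    HookParam.attach_fetcher(fetcher, store)   # install the chunk-fetching hooks
    ...                                        # run forward and backward
    HookParam.release_fetcher()                # detach the hooks after training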
"""
pre_fwd_func = None
post_fwd_func = None
use_fused_kernel = False
@staticmethod
def attach_fetcher(fetcher: ChunkFetcher, store: BufferStore):
HookParam.pre_fwd_func = prefwd_postbwd_function(fetcher, store)
HookParam.post_fwd_func = postfwd_prebwd_function(fetcher, store)
@staticmethod
def release_fetcher():
HookParam.pre_fwd_func = None
HookParam.post_fwd_func = None
@staticmethod
def enable_fused_kernel():
HookParam.use_fused_kernel = True
@staticmethod
def disable_fused_kernel():
HookParam.use_fused_kernel = False
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if is_no_hook_op(func):
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
return ret
params_to_index = OrderedDict()
params_index = 0
def append_param(x):
nonlocal params_index
if isinstance(x, HookParam):
params_to_index[x] = params_index
params_index += 1
tree_map(append_param, args)
tree_map(append_param, kwargs)
params = tuple(params_to_index.keys())
new_params = HookParam.pre_fwd_func(params, *params)
def replace_param(x):
if isinstance(x, HookParam):
return new_params[params_to_index[x]]
return x
with torch._C.DisableTorchFunction():
if HookParam.use_fused_kernel and func in fused_torch_functions:
func = fused_torch_functions.get(func)
ret = func(*tree_map(replace_param, args), **tree_map(replace_param, kwargs))
if not isinstance(ret, tuple):
ret = (ret,)
ptr_set = set()
for p in new_params:
ptr_set.add(p.data_ptr())
def clone_inplace_tensor(x):
if isinstance(x, torch.Tensor):
start_point = x.data_ptr() - x.element_size() * x.storage_offset()
if start_point in ptr_set:
return x.clone()
return x
ret = tree_map(clone_inplace_tensor, ret)
ret = HookParam.post_fwd_func(params, *ret)
def convert(t):
if isinstance(t, torch.Tensor):
t = to_outplace_tensor(t)
return t
ret = tree_map(convert, ret)
if len(ret) == 1:
return ret[0]
else:
return ret

View File

@ -0,0 +1,56 @@
import torch
from torch.autograd.profiler_util import _format_memory
from colossalai.elixir.cuda import gpu_device
class BufferStore(object):
"""A place to store parameters temporarily when computing.
Parameters should be stored into the buffer before computations.
args:
buffer_size: The size of the buffer.
buffer_dtype: The dtype of the buffer.
device_str: The device to store the buffer.
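
Example (an illustrative sketch; `p` is a parameter placeholder):
    store = BufferStore(buffer_size=1024, buffer_dtype=torch.float16)
    offset = store.insert(p, 0)   # p.data now points into the buffer
    store.erase(p)                # restore p.data after the computation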
"""
def __init__(self, buffer_size: int, buffer_dtype: torch.dtype, device_str: str = 'cuda') -> None:
super().__init__()
self.buffer_size = buffer_size
self.buffer_dtype = buffer_dtype
self.buffer: torch.Tensor = torch.empty(buffer_size, dtype=buffer_dtype, device=device_str)
self.buffer_occ = buffer_size * self.buffer.element_size()
self.record_dict = dict()
def zeros(self):
torch.zero_(self.buffer)
def insert(self, t: torch.Tensor, offset: int) -> int:
assert t not in self.record_dict
end = offset + t.numel()
assert end <= self.buffer_size, f'buffer size is {self.buffer_size} but needs {end}'
new_data = self.buffer[offset:end].view(t.shape)
new_data.copy_(t.data)
self.record_dict[t] = t.data
t.data = new_data
return end
def erase(self, t: torch.Tensor):
assert t in self.record_dict
new_data = self.record_dict.pop(t)
t.data = new_data
return
def empty_like(self, t: torch.Tensor):
return self.buffer[:t.numel()].view(t.shape)
def empty_1d(self, size: int):
return self.buffer[:size]
def __repr__(self) -> str:
return f'Buffer(size={self.buffer_size}, dtype={self.buffer_dtype}, memo_occ={_format_memory(self.buffer_occ)})'

View File

@ -0,0 +1,17 @@
import torch
import torch.nn.functional as F
fused_torch_functions = {F.layer_norm: F.layer_norm}
def register_fused_layer_norm():
try:
from .layernorm import ln_func
fused_torch_functions[F.layer_norm] = ln_func
print('Registered fused layer norm from apex successfully.')
except:
print('Cannot import fused layer norm, please install apex from source.')
pass
register_fused_layer_norm()

View File

@ -0,0 +1,43 @@
import torch
import xformers.ops as xops
from torch.utils._pytree import tree_map
from colossalai.elixir.tracer.memory_tracer.memory_tensor import MTensor
def lower_triangular_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, p: float = 0.0):
args = (query, key, value)
meta_flag = False
for x in args:
if x.device.type == 'meta':
meta_flag = True
break
if meta_flag:
atten = query @ key.transpose(-2, -1)
output = atten @ value
return output
profile_flag = False
def to_torch_tensor(x):
if isinstance(x, MTensor):
nonlocal profile_flag
profile_flag = True
return x.elem
return x
args = tree_map(to_torch_tensor, args)
query, key, value = args
output = xops.memory_efficient_attention(query=query,
key=key,
value=value,
p=p,
attn_bias=xops.LowerTriangularMask())
if profile_flag:
output = MTensor(output)
return output

View File

@ -0,0 +1,20 @@
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder
from .gpt_attention import XGPT2Attention, XGPT2Model
from .opt_attention import XOPTAttention, XOPTDecoder
def wrap_attention(model: nn.Module):
for name, module in model.named_modules():
if isinstance(module, GPT2Model):
module.__class__ = XGPT2Model
elif isinstance(module, GPT2Attention):
module.__class__ = XGPT2Attention
elif isinstance(module, OPTAttention):
module.__class__ = XOPTAttention
elif isinstance(module, OPTDecoder):
module.__class__ = XOPTDecoder
return model

View File

@ -0,0 +1,45 @@
import warnings
import einops
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
from .attention import lower_triangular_attention
class XGPT2Attention(GPT2Attention):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
assert self.scale_attn_weights
assert not self.is_cross_attention
assert not self.scale_attn_by_inverse_layer_idx
assert not self.reorder_and_upcast_attn
b_size, h_size, m_size, k_size = query.size()
assert self.bias.size(-1) == m_size
query = einops.rearrange(query, 'b h m k -> b m h k')
key = einops.rearrange(key, 'b h m k -> b m h k')
value = einops.rearrange(value, 'b h m k -> b m h k')
drop_rate = self.attn_dropout.p
output = lower_triangular_attention(query, key, value, p=drop_rate)
ret = einops.rearrange(output, 'b m h k -> b h m k')
return ret, None
class XGPT2Model(GPT2Model):
def forward(self, *args, **kwargs):
assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg'
attn_mask = kwargs.get('attention_mask')
# assert torch.all(attn_mask == 1), 'only accept no padding mask'
head_mask = kwargs.get('head_mask', None)
assert head_mask is None, 'head mask should be None'
output_attn = kwargs.get('output_attentions', False)
if output_attn:
warnings.warn('output_attentions is not supported for XGPT2Model')
return super().forward(*args, **kwargs)

View File

@ -0,0 +1,10 @@
from apex.normalization.fused_layer_norm import fused_layer_norm, fused_layer_norm_affine
def ln_func(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is None:
assert bias is None
return fused_layer_norm(input, normalized_shape, eps)
else:
assert weight is not None and bias is not None
return fused_layer_norm_affine(input, weight, bias, normalized_shape, eps)

View File

@ -0,0 +1,86 @@
from typing import Optional, Tuple
import warnings
import einops
import torch
import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder
from .attention import lower_triangular_attention
class XOPTAttention(OPTAttention):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
assert is_cross_attention is False
assert past_key_value is None
assert layer_head_mask is None
# assert output_attentions is False
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
# get key, value proj
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
src_len = key_states.size(1)
assert tgt_len == src_len
attn_output = lower_triangular_attention(query=query_states, key=key_states, value=value_states, p=self.dropout)
if attn_output.size() != (bsz, tgt_len, self.num_heads, self.head_dim):
raise ValueError(f'`attn_output` should be of size {(bsz, tgt_len, self.num_heads, self.head_dim)}, but is'
f' {attn_output.size()}')
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
class XOPTDecoder(OPTDecoder):
def forward(self, *args, **kwargs):
assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg'
attn_mask = kwargs.get('attention_mask')
# assert torch.all(attn_mask == 1), 'only accept no padding mask'
head_mask = kwargs.get('head_mask', None)
assert head_mask is None, 'head mask should be None'
output_attn = kwargs.get('output_attentions', False)
if output_attn:
# Warning(...) only constructs an exception object; emit a real warning instead
import warnings
warnings.warn('output_attentions is not supported for XOPTDecoder')
return super().forward(*args, **kwargs)

View File

@ -0,0 +1,4 @@
from .mini_waste import minimum_waste_search
from .optimal import optimal_search
from .result import ChunkPlan, SearchResult
from .simple import simple_search

View File

@ -0,0 +1,135 @@
from abc import ABC, abstractmethod
from functools import partial
from typing import List, Tuple
import torch
import torch.nn as nn
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool
from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.elixir.tracer.utils import meta_copy
from colossalai.elixir.utils import print_rank_0
from .result import ChunkPlan
from .utils import to_meta_tensor
class SearchBase(ABC):
"""A basic class for search algorithms.
args:
module: the model to be searched
dtype: the unified dtype of all parameters
prefetch: whether to prefetch chunks during training
verbose: whether to print search details
inp: a dictionary, the example input of the model
step_fn: the example step function of the model
"""
def __init__(self,
module: nn.Module,
dtype: torch.dtype = torch.float,
prefetch: bool = False,
verbose: bool = False,
inp=None,
step_fn=None) -> None:
self.unified_dtype = dtype
self.meta_module = meta_copy(module, partial(to_meta_tensor, dtype=self.unified_dtype))
self.prefetch_flag = prefetch
self.verbose = verbose
self.param_to_name = {param: name for name, param in self.meta_module.named_parameters()}
self.public_block_size = 1024
self.public_block_number = 0
self.param_per_step = None
self.max_checkpoint_size = 0
if self.prefetch_flag:
assert inp is not None and step_fn is not None
tf_running_info = generate_tf_order(self.meta_module, inp, step_fn, dtype)
self.param_per_step = tf_running_info.get('params_per_step')
if self.verbose:
print_rank_0('Prefetch enabled: the call order of parameters')
for i, step in enumerate(self.param_per_step):
print_rank_0(f'step {i}: {step}')
name_to_param = {name: param for name, param in self.meta_module.named_parameters()}
for checkpoint in tf_running_info.get('checkpoint_info'):
sum_numel = 0
for i in range(*checkpoint):
for name in self.param_per_step[i]:
param = name_to_param[name]
sum_numel += param.numel()
self.max_checkpoint_size = max(self.max_checkpoint_size, sum_numel)
if self.verbose:
print_rank_0(f'checkpoint information: from-to -> {checkpoint}, numel -> {sum_numel}')
@abstractmethod
def private_truncate(self, param: nn.Parameter) -> int:
"""A function used to truncate the length of a private chunk,
which only contains one parameter.
"""
pass
@abstractmethod
def public_trucate(self, length: int) -> int:
"""A function used to truncate the length of all public chunks.
"""
pass
@abstractmethod
def search(self, *args, **kwargs) -> Tuple:
"""The core search function. It returns a tuple of a private group and public groups.
"""
pass
def generate_chunk_plans(self, private_group, public_groups) -> List[ChunkPlan]:
plans = list()
for param in private_group:
chunk_size = self.private_truncate(param)
chunk_dtype = param.dtype
chunk_kwargs = dict(rcache_fused=True)
chunk_plan = ChunkPlan(name_list=[self.param_to_name[param]],
chunk_size=chunk_size,
chunk_dtype=chunk_dtype,
kwargs=chunk_kwargs)
plans.append(chunk_plan)
self.public_block_size = self.public_trucate(self.public_block_size)
public_chunk_size = self.public_block_size
public_chunk_dtype = self.unified_dtype
for group in public_groups:
chunk_kwargs = {}
chunk_plan = ChunkPlan(name_list=[self.param_to_name[p] for p in group],
chunk_size=public_chunk_size,
chunk_dtype=public_chunk_dtype,
kwargs=chunk_kwargs)
plans.append(chunk_plan)
if self.verbose:
print_rank_0(f'Chunk plans: total {len(plans)} chunks')
for i, plan in enumerate(plans):
print_rank_0(f'plan {i}: {plan}')
return plans
def allocate_chunk_group(self, chunk_plans: List[ChunkPlan]) -> ChunkGroup:
block_require_list = list()
for plan in chunk_plans:
kwargs = plan.kwargs
if kwargs.get('rcache_fused', False):
block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype))
mp = MemoryPool('cuda')
mp.allocate(public_dtype=self.unified_dtype,
public_block_size=self.public_block_size,
public_block_number=self.public_block_number,
private_block_list=block_require_list)
if self.verbose:
print_rank_0(
f'Memory pool (rcache): {mp}\n\tblock size -> {mp.public_block_size}, block number -> {mp.public_free_cnt}'
)
return ChunkGroup(mp)

View File

@ -0,0 +1,167 @@
import math
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.utils import print_rank_0
from .base import SearchBase
from .result import SearchResult
from .utils import find_minimum_waste_size, find_search_range, get_multi_used_params, to_divide
dtype_to_es = {torch.float16: 2, torch.float32: 4, torch.float64: 8}
class SearchMiniWaste(SearchBase):
"""Search the best chunk size to minimize the waste of memory.
args:
module: the module to be searched
default_group_size: the default group size of communications
dtype: the data type of the parameters
prefetch: whether to prefetch the parameters
verbose: whether to print the search details
inp: a dictionary, the example input of the model
step_fn: the example step function of training
"""
def __init__(self,
module: nn.Module,
default_group_size: int,
dtype: torch.dtype = torch.float,
prefetch: bool = False,
verbose: bool = False,
inp=None,
step_fn=None) -> None:
super().__init__(module, dtype, prefetch, verbose, inp, step_fn)
self.default_group_size = default_group_size
def private_truncate(self, param: nn.Parameter) -> int:
return to_divide(param.numel(), self.default_group_size)
def public_trucate(self, length: int) -> int:
return to_divide(length, self.default_group_size)
def search(self) -> Tuple:
min_chunk_size, max_chunk_size, search_interval = find_search_range(self.meta_module)
# get multi-used parameters
private_params = get_multi_used_params(self.meta_module)
# get parameters used only one time
public_params = [p for p in self.meta_module.parameters() if p not in private_params]
# collect the number of elements of each parameter
public_numels = [p.numel() for p in public_params]
# calculate the total number of elements of all public parameters
total_size = sum(public_numels)
if total_size <= min_chunk_size:
public_block_size = total_size
waste_size = 0
else:
public_block_size, waste_size = find_minimum_waste_size(
# pre-commit: do not rearrange
numel_group_list=[public_numels],
min_range=min_chunk_size,
max_range=max_chunk_size,
interval=search_interval)
if self.verbose:
if total_size == 0:
waste_percentage = 0
else:
waste_percentage = 100 * waste_size / total_size
print_rank_0(
f'Minimum waste search result: chunk size = {public_block_size}, waste percentage = {waste_percentage: .1f} %'
)
# initialize the mapping from parameters to chunks
param_to_chunk_id = dict()
chunk_id = 0
# deal with private parameters
for p in private_params:
param_to_chunk_id[p] = chunk_id
chunk_id += 1
# record the upper bound
private_id_upperbound = chunk_id
# deal with public parameters
last_left = 0
for p in public_params:
p_size = p.numel()
if last_left < p_size:
last_left = public_block_size
chunk_id += 1
assert last_left >= p_size
last_left -= p_size
param_to_chunk_id[p] = chunk_id
# initialize public groups
public_number_chunks = chunk_id - private_id_upperbound
public_groups = [[] for _ in range(public_number_chunks)]
for p in public_params:
public_chunk_id = param_to_chunk_id[p] - private_id_upperbound - 1
public_groups[public_chunk_id].append(p)
# calculate the number of minimum chunks allocated in R cache
max_lived_chunks = 0
for module in self.meta_module.modules():
param_set = set()
for param in module.parameters(recurse=False):
param_set.add(param_to_chunk_id[param])
max_lived_chunks = max(max_lived_chunks, len(param_set))
# allocate more chunks for prefetch
if self.prefetch_flag:
max_lived_chunks = min(max_lived_chunks + 4, public_number_chunks)
if total_size == 0:
max_lived_chunks = 0
self.public_block_size = public_block_size
self.public_block_number = max_lived_chunks
return (private_params, public_groups)
def minimum_waste_search(m: nn.Module,
group_size: int,
unified_dtype: torch.dtype = torch.float,
cpu_offload: bool = False,
prefetch: bool = False,
verbose: bool = False,
pin_memory: bool = True,
inp=None,
step_fn=None) -> SearchResult:
search_class = SearchMiniWaste(
# pre-commit: do not rearrange
module=m,
default_group_size=group_size,
dtype=unified_dtype,
prefetch=prefetch,
verbose=verbose,
inp=inp,
step_fn=step_fn)
private_group, public_groups = search_class.search()
chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
# assign shard device
if cpu_offload:
shard_device = torch.device('cpu')
else:
shard_device = gpu_device()
for plan in chunk_plans:
plan.kwargs['shard_device'] = shard_device
if cpu_offload:
plan.kwargs['cpu_pin_memory'] = pin_memory
chunk_group = search_class.allocate_chunk_group(chunk_plans)
return SearchResult(chunk_group=chunk_group,
chunk_plans=chunk_plans,
param_called_per_step=search_class.param_per_step)

View File

@ -0,0 +1,284 @@
import math
from typing import Tuple
import torch
import torch.nn as nn
from torch.autograd.profiler_util import _format_memory
from colossalai.elixir.cuda import get_allowed_memory, gpu_device
from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling
from colossalai.elixir.utils import calc_buffer_size, print_rank_0
from .base import SearchBase
from .result import SearchResult
from .simulator import find_optimal_chunk_size, rcache_prioirity_check
from .utils import find_search_range, get_multi_used_params, to_divide
dtype_to_es = {torch.float16: 2, torch.float32: 4, torch.float64: 8}
class SearchOptimal(SearchBase):
"""Search the best chunk size to maximize the training throughput.
Users should provide the example input data and step function of training.
args:
module: the module to be searched
default_group_size: the default group size of communications
activation_fragment_factor: the factor to estimate the total activation memory usage
allocation_fragment_factor: the factor used to estimate the fraction of device memory that Elixir can effectively use
driver_usage: the memory usage of the cuda driver
dtype: the data type of the parameters
verbose: whether to print the search details
overlap: whether to overlap the communication and computation
inp: a dictionary, the example input of the model
step_fn: the example step function of training
"""
def __init__(self,
module: nn.Module,
default_group_size: int,
activation_fragment_factor: float = 1.25,
allocation_fragment_factor: float = 0.95,
driver_usage: float = 2 * 1024**3,
dtype: torch.dtype = torch.float,
verbose: bool = False,
overlap: bool = False,
inp=None,
step_fn=None) -> None:
# as for optimal search, we must profile the model first
super().__init__(module, dtype, True, verbose, inp, step_fn)
# profile cuda memory usage
memo_usage = cuda_memory_profiling(model=self.meta_module, inp=inp, step_fn=step_fn, dtype=dtype)
torch.cuda.empty_cache()
buffer_occ = memo_usage['buffer_occ']
# get the maximum memory usage of activation
predict_activation = memo_usage['activation_occ']
# calculate the total capacity of the current device
gpu_memory = get_allowed_memory()
# allowed capacity = allocation_fragment_factor * (total capacity - activation_fragment_factor * activation)
self.cuda_capacity = int(
allocation_fragment_factor *
(gpu_memory - driver_usage - buffer_occ - activation_fragment_factor * predict_activation))
hook_buffer_store_size = calc_buffer_size(m=self.meta_module, test_dtype=self.unified_dtype)
self.cuda_elements = self.cuda_capacity // dtype_to_es.get(dtype) - hook_buffer_store_size
if self.cuda_elements < 0:
raise RuntimeError('optimal search: activation is too large, please reduce batch size')
if self.verbose:
print_rank_0('Predicted memory usage:')
for k, v in memo_usage.items():
print_rank_0(f'{k}: {_format_memory(v)}')
print_rank_0(f'allowed allocation space: {_format_memory(self.cuda_capacity)}')
print_rank_0(f'hook buffer store size: {hook_buffer_store_size}')
print_rank_0(f'allowed {dtype} elements: {self.cuda_elements}')
self.default_group_size = default_group_size
self.comm_overlap = overlap
def private_truncate(self, param: nn.Parameter) -> int:
return to_divide(param.numel(), self.default_group_size)
def public_trucate(self, length: int) -> int:
return to_divide(length, self.default_group_size)
def search(self) -> Tuple:
min_chunk_size, max_chunk_size, search_interval = find_search_range(self.meta_module)
# get multi-used parameters
private_params = get_multi_used_params(self.meta_module)
# subtract the footprint of fused parameters
for param in private_params:
self.cuda_elements -= param.numel()
if self.cuda_elements < 0:
raise RuntimeError('optimal search: not enough space for fused parameters')
# initialize public params in the called order
public_params = list()
public_param_set = set()
name_to_param = {name: param for name, param in self.meta_module.named_parameters()}
for name_set in self.param_per_step:
for name in name_set:
param = name_to_param.get(name)
if param in private_params or param in public_param_set:
continue
public_params.append(param)
public_param_set.add(param)
del name_to_param
del public_param_set
# collect the number of elements of each parameter
public_numels = [p.numel() for p in public_params]
# calculate the total number of elements of all public parameters
total_size = sum(public_numels)
# collect the name for each public parameters
public_param_names = [self.param_to_name[p] for p in public_params]
if total_size <= min_chunk_size:
public_block_size = total_size
n_blocks = 1
waste_size = 0
else:
public_block_size, n_blocks, waste_size = find_optimal_chunk_size(
# pre-commit: do not rearrange
param_per_step=self.param_per_step,
param_names=public_param_names,
param_numels=public_numels,
cuda_elements=self.cuda_elements,
overlap=self.comm_overlap,
min_range=min_chunk_size,
max_range=max_chunk_size,
interval=search_interval)
# truncate the size of public blocks
public_block_size = self.public_trucate(public_block_size)
if self.cuda_elements < n_blocks * public_block_size:
raise RuntimeError('not enough space for unfused parameters')
if self.verbose:
if total_size == 0:
waste_percentage = 0
else:
waste_percentage = 100 * waste_size / total_size
print_rank_0(
f'Optimal search result: chunk size = {public_block_size}, waste percentage = {waste_percentage: .1f} %'
)
# initialize the mapping from parameters to chunks
param_to_chunk_id = dict()
chunk_id = 0
# deal with private parameters
for p in private_params:
param_to_chunk_id[p] = chunk_id
chunk_id += 1
# record the upper bound
private_id_upperbound = chunk_id
# deal with public parameters
last_left = 0
for p in public_params:
p_size = p.numel()
if last_left < p_size:
last_left = public_block_size
chunk_id += 1
assert last_left >= p_size
last_left -= p_size
param_to_chunk_id[p] = chunk_id
# initialize public groups
public_number_chunks = chunk_id - private_id_upperbound
public_groups = [[] for _ in range(public_number_chunks)]
for p in public_params:
public_chunk_id = param_to_chunk_id[p] - private_id_upperbound - 1
public_groups[public_chunk_id].append(p)
if total_size == 0:
n_blocks = 0
self.public_block_size = public_block_size
self.public_block_number = n_blocks
return (private_params, public_groups)
def configure_rcache_size(self, chunk_plans: list, os_factor: int):
element_os = 4
if self.unified_dtype == torch.float16:
element_pa = 2
elif self.unified_dtype == torch.float:
element_pa = 4
else:
raise NotImplementedError
priority = rcache_prioirity_check(n=self.default_group_size, r_os=os_factor, e_p=element_pa, e_o=element_os)
if self.verbose:
print_rank_0(f'rCache Priority Check: {priority}')
if not priority:
n_cache_blocks = max(4, math.ceil(self.max_checkpoint_size / self.public_block_size) + 1)
if self.comm_overlap:
n_cache_blocks += 2
n_cache_blocks = min(n_cache_blocks, self.public_block_number)
self.cuda_elements -= n_cache_blocks * self.public_block_size
if self.verbose:
print_rank_0(f'n_cache_block is set to {n_cache_blocks}')
else:
self.cuda_elements -= self.public_block_number * self.public_block_size
def try_move_chunk_to_cuda(fused: bool):
for (i, plan) in enumerate(chunk_plans):
rcache_fused = plan.kwargs.get('rcache_fused', False)
if not fused and rcache_fused:
continue
elif fused and not rcache_fused:
break
param_os_size = os_factor * plan.chunk_size // self.default_group_size
if self.cuda_elements >= param_os_size:
plan.kwargs['shard_device'] = gpu_device()
self.cuda_elements -= param_os_size
else:
plan.kwargs['shard_device'] = torch.device('cpu')
plan.kwargs['cpu_pin_memory'] = True
if self.verbose:
print_rank_0(f"chunk {i}: shard device -> {plan.kwargs['shard_device']}")
# check chunks that are not fused on rCache
try_move_chunk_to_cuda(False)
# check chunks that are fused on rCache
try_move_chunk_to_cuda(True)
if not priority:
extra_blocks = math.floor(self.cuda_elements / self.public_block_size)
extra_blocks = min(extra_blocks, self.public_block_number - n_cache_blocks)
self.cuda_elements -= extra_blocks * self.public_block_size
self.public_block_number = n_cache_blocks + extra_blocks
if self.verbose:
print_rank_0(f'n_extra_blocks is set to {extra_blocks}')
return chunk_plans
def optimal_search(
# pre-commit: do not rearrange
m: nn.Module,
group_size: int,
unified_dtype: torch.dtype = torch.float,
optimizer_type: str = 'Adam',
overlap: bool = False,
verbose: bool = False,
inp=None,
step_fn=None) -> SearchResult:
search_class = SearchOptimal(
# pre-commit: do not rearrange
module=m,
default_group_size=group_size,
dtype=unified_dtype,
verbose=verbose,
overlap=overlap,
inp=inp,
step_fn=step_fn)
private_group, public_groups = search_class.search()
chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
if unified_dtype == torch.float16:
master_weight_factor = 2
elif unified_dtype == torch.float:
master_weight_factor = 1
else:
raise NotImplementedError
if optimizer_type == 'SGD':
extra_store_factor = 1
elif optimizer_type == 'Adam':
extra_store_factor = 2
else:
raise NotImplementedError
os_factor = 1 + (1 + extra_store_factor) * master_weight_factor
chunk_plans = search_class.configure_rcache_size(chunk_plans, os_factor)
chunk_group = search_class.allocate_chunk_group(chunk_plans)
return SearchResult(chunk_group=chunk_group,
chunk_plans=chunk_plans,
param_called_per_step=search_class.param_per_step)

View File

@ -0,0 +1,32 @@
from typing import Dict, List, NamedTuple
import torch
from colossalai.elixir.chunk import ChunkGroup
class ChunkPlan(NamedTuple):
"""ChunkPlan is a type of configuration used to instruct the initialization of a chunk.
args:
name_list: contains the names of parameters that should be pushed into this chunk
chunk_size: the size of this chunk
chunk_dtype: the dtype of this chunk
kwargs: a dictionary of extra arguments passed to the __init__ function of Chunk
"""
name_list: List[str]
chunk_size: int
chunk_dtype: torch.dtype
kwargs: Dict
class SearchResult(object):
def __init__(self,
chunk_group: ChunkGroup,
chunk_plans: List[ChunkPlan],
param_called_per_step: List[List[str]] = None) -> None:
super().__init__()
self.chunk_group = chunk_group
self.param_chunk_plans = chunk_plans
self.param_called_per_step = param_called_per_step
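
A short sketch (hypothetical values) of building a `ChunkPlan` by hand, mirroring what `generate_chunk_plans` emits for a fused private chunk; the import path follows the package layout in this PR:

```python
import torch

from colossalai.elixir.search import ChunkPlan

plan = ChunkPlan(
    name_list=['embed.weight'],      # hypothetical parameter name
    chunk_size=16 * 1024**2,         # number of elements held by the chunk
    chunk_dtype=torch.float16,
    kwargs=dict(rcache_fused=True),  # forwarded to Chunk.__init__
)
print(plan.chunk_size, plan.chunk_dtype)
```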

View File

@ -0,0 +1,120 @@
import math
from typing import Tuple
import torch
import torch.nn as nn
from .base import SearchBase
from .result import SearchResult
from .utils import get_multi_used_params, to_divide
class SearchSimple(SearchBase):
"""The simple search algorithm used for unit tests.
Developers can specify the number of chunks used.
args:
module: the module to be searched
default_group_size: the default group size of communications
dtype: the data type of the parameters
prefetch: whether to prefetch the chunks
verbose: whether to print the search process
inp: the example input of the model
step_fn: the example step function of training
"""
def __init__(self,
module: nn.Module,
default_group_size: int,
dtype: torch.dtype = torch.float,
prefetch: bool = False,
verbose: bool = False,
inp=None,
step_fn=None) -> None:
super().__init__(module, dtype, prefetch, verbose, inp, step_fn)
self.default_group_size = default_group_size
def private_truncate(self, param: nn.Parameter) -> int:
return to_divide(param.numel(), self.default_group_size)
def public_trucate(self, length: int) -> int:
return to_divide(length, self.default_group_size)
def search(self, split_number: int, allocate_factor: float) -> Tuple:
# get multi-used parameters
private_params = get_multi_used_params(self.meta_module)
# get parameters used only one time
public_params = [p for p in self.meta_module.parameters() if p not in private_params]
# calculate the size of each group
len_public = len(public_params)
split_number = min(len_public, split_number)
# allocate a list for groups
public_groups = list()
if split_number > 0:
average_size = len_public // split_number
left_size = len_public % split_number
# set the size of each segment
pack_size_list = [average_size] * split_number
for i in range(split_number):
if left_size > 0:
pack_size_list[i] += 1
left_size -= 1
# split public parameters
for i in range(split_number):
p_list = list()
for _ in range(pack_size_list[i]):
p = public_params.pop(0)
p_list.append(p)
public_groups.append(p_list)
assert len(public_params) == 0
# calculate the maximum summarized size
max_sum_size = 0
for p_list in public_groups:
sum_size = sum([p.numel() for p in p_list])
max_sum_size = max(max_sum_size, sum_size)
else:
max_sum_size = 0
self.public_block_size = max_sum_size
self.public_block_number = math.ceil(split_number * allocate_factor)
return (private_params, public_groups)
def simple_search(m: nn.Module,
group_size: int,
split_number: int = 10,
allocate_factor: float = 0.6,
unified_dtype: torch.dtype = torch.float,
shard_device: torch.device = torch.device('cpu'),
prefetch: bool = False,
verbose: bool = False,
inp=None,
step_fn=None) -> SearchResult:
search_class = SearchSimple(
# pre-commit: do not rearrange
module=m,
default_group_size=group_size,
dtype=unified_dtype,
prefetch=prefetch,
verbose=verbose,
inp=inp,
step_fn=step_fn)
private_group, public_groups = search_class.search(split_number, allocate_factor)
chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
# assign shard device
for plan in chunk_plans:
plan.kwargs['shard_device'] = shard_device
chunk_group = search_class.allocate_chunk_group(chunk_plans)
return SearchResult(chunk_group=chunk_group,
chunk_plans=chunk_plans,
param_called_per_step=search_class.param_per_step)
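
A hedged usage sketch of `simple_search` on a toy model (single process, so `group_size=1`; a CUDA device is assumed because the chunk group allocates its memory pool there). The keyword values simply follow the signature above, and the import path follows the package layout in this PR:

```python
import torch
import torch.nn as nn

from colossalai.elixir.search import simple_search

# toy model used only for illustration
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))

sr = simple_search(
    model,
    group_size=1,                      # no sharding across processes in this sketch
    split_number=4,                    # ask for at most 4 public chunks
    allocate_factor=0.6,               # fraction of public chunks backed by rCache blocks
    unified_dtype=torch.float,
    shard_device=torch.device('cpu'),
)
print(len(sr.param_chunk_plans))
```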

View File

@ -0,0 +1,112 @@
import math
from .utils import to_divide
def calc_move_times(param_per_step: list, param_to_chunk: dict, n_blocks: int):
from colossalai.elixir.simulator import move_count
chunk_per_step = list()
for param_set in param_per_step:
id_set = set()
for name in param_set:
# continue if the parameter is ignored
if name not in param_to_chunk:
continue
id_set.add(param_to_chunk[name])
if len(id_set) > 0:
chunk_per_step.append(list(id_set))
return move_count(chunk_per_step, n_blocks)
def find_optimal_chunk_size(
# pre-commit: do not rearrange
param_per_step: list,
param_names: list,
param_numels: list,
cuda_elements: int,
overlap: bool,
min_range: int,
max_range: int,
interval: int):
max_numel = 0
for numel in param_numels:
max_numel = max(max_numel, numel)
test_size = to_divide(max(max_numel, min_range), interval)
# round cuda_elements down to the nearest multiple of interval
cuda_elements = to_divide(cuda_elements - interval + 1, interval)
max_range = min(max_range, cuda_elements)
min_move_elements = float('+inf')
best_size = test_size
best_number_blocks = 0
best_waste = 0
def dispatch_chunks(param_to_chunk: dict, block_size: int) -> int:
chunk_id = 0
acc = 0
left = 0
for (name, numel) in zip(param_names, param_numels):
if numel > left:
acc += left
chunk_id += 1
left = block_size
left -= numel
param_to_chunk[name] = chunk_id
return (chunk_id, left + acc)
assert test_size <= max_range, 'max_numel or min_range is larger than max_range or cuda capacity'
while test_size <= max_range:
# calculate the number of blocks
number_blocks = int(cuda_elements // test_size)
# if communication overlap is enabled, reserve two blocks for in-flight chunks
if overlap:
number_blocks -= 2
if number_blocks <= 0:
# larger test sizes only shrink the block count further, so stop searching
break
# initialize the chunk id for each parameter
param_to_chunk = dict()
number_chunks, current_waste = dispatch_chunks(param_to_chunk, test_size)
number_blocks = min(number_blocks, number_chunks)
# calculate the minimum number of movements
move_times = calc_move_times(param_per_step, param_to_chunk, number_blocks)
current_move_elements = move_times * test_size
# print("test", test_size, current_move_elements)
if current_move_elements < min_move_elements:
min_move_elements = current_move_elements
best_size = test_size
best_number_blocks = number_blocks
best_waste = current_waste
test_size += interval
if min_move_elements == float('inf'):
raise RuntimeError('optimal search: can not find a valid solution')
return best_size, best_number_blocks, best_waste
def bandwidth_c2g(n: int):
return 16.3 * n + 8.7
def bandwidth_g2c(n: int):
return 15.8 * n + 2.3
def velocity_gpu(n: int):
return 50 * n
def velocity_cpu(n: int):
return 1.66 * math.log(n) + 5.15
def rcache_prioirity_check(n: int, r_os: int, e_p: int, e_o: int):
In = e_p / bandwidth_c2g(n) + e_p / bandwidth_g2c(n)
Jn = (n / r_os) * (e_o / bandwidth_c2g(n) + In + e_p / bandwidth_g2c(n) + 1.0 / velocity_cpu(n) -
1.0 / velocity_gpu(n))
return In > Jn
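
For reference, a small sketch of how the priority check above is driven by `SearchOptimal.configure_rcache_size` (illustrative numbers): `n` is the communication group size, `e_p`/`e_o` are the element sizes of parameters and optimizer states, and `r_os` is the optimizer-state factor computed in `optimal_search`. The import path is inferred from the package layout in this PR:

```python
from colossalai.elixir.search.simulator import rcache_prioirity_check

# illustrative numbers: 8-way group, fp16 parameters, fp32 optimizer states;
# r_os = 1 + (1 + 2) * 2 = 7 for Adam with fp16 training (see optimal_search)
n, r_os, e_p, e_o = 8, 7, 2, 4
keep_resident = rcache_prioirity_check(n=n, r_os=r_os, e_p=e_p, e_o=e_o)
print('prefer rCache residency:', keep_resident)
```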

View File

@ -0,0 +1,108 @@
from typing import List, Set
import torch
import torch.nn as nn
def to_divide(a: int, b: int):
return a + (-a % b)
def to_meta_tensor(t: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
# only float tensors need dtype change
if t.is_floating_point() and dtype is not None:
meta_dtype = dtype
else:
meta_dtype = t.dtype
# we shall not use t.data.to here, since t might be a fake tensor
meta_t = torch.empty(t.size(), dtype=meta_dtype, device='meta')
# pack it if t is a parameter
# we should filter parameters with no grad
if isinstance(t, nn.Parameter) and t.requires_grad:
meta_t = nn.Parameter(meta_t)
return meta_t
def get_multi_used_params(m: nn.Module) -> Set[torch.Tensor]:
multi_used_set = set()
visit = dict()
for module in m.modules():
for param in module.parameters(recurse=False):
if param not in visit:
visit[param] = True
else:
multi_used_set.add(param)
return multi_used_set
def find_minimum_waste_size(numel_group_list: List[List[int]], min_range: int, max_range: int, interval: int):
max_per_group = list()
for n_list in numel_group_list:
max_per_group.append(max(n_list))
max_numel = max(max_per_group)
test_size = to_divide(max(max_numel, min_range), interval)
best_size = test_size
min_waste = float('+inf')
def calc_waste(numel_list: List[int], block_size: int):
acc = 0
left = 0
for s in numel_list:
if s > left:
acc += left
left = block_size
left -= s
return left + acc
assert test_size <= max_range, 'max_numel or min_range is larger than max_range'
while test_size <= max_range:
current_waste = 0
for n_list in numel_group_list:
current_waste += calc_waste(n_list, test_size)
if current_waste < min_waste:
best_size = test_size
min_waste = current_waste
test_size += interval
return best_size, min_waste
def find_search_range(m: nn.Module):
ele_size = 0
for param in m.parameters():
if ele_size == 0:
ele_size = param.element_size()
else:
assert param.element_size() == ele_size
def next_2_pow(x: int):
y = 1
while y < x:
y <<= 1
return y
private_params = get_multi_used_params(m)
params = [p for p in m.parameters() if p not in private_params]
memo_list = [p.numel() * p.element_size() for p in params]
max_memo = max(memo_list)
# minimum chunk memory is 32 MiB
default_min = 32 * 1024**2
while default_min < max_memo:
default_min <<= 1
default_max = int(3 * default_min)
# * 2 for forward and backward
length = 2 * next_2_pow(len(params))
default_iter_times = 16 * 1024**2
default_search_times = default_iter_times // length
gap = default_max - default_min
# minimum search interval is 1024
if default_search_times > (gap // 1024):
interval = 1024
else:
interval = gap // default_search_times
return (default_min // ele_size, default_max // ele_size, interval // ele_size)
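
A small worked sketch of the packing helpers in this file (toy numbers): `to_divide` rounds up to the nearest multiple, and `find_minimum_waste_size` scans candidate chunk sizes in the given range and returns the one with the least packing waste. The import path follows the package layout in this PR:

```python
from colossalai.elixir.search.utils import find_minimum_waste_size, to_divide

# to_divide rounds a up to the nearest multiple of b
assert to_divide(10, 4) == 12
assert to_divide(8, 4) == 8

# one group of parameters with 5, 3 and 4 elements; candidate sizes 6, 8, 10, 12
best_size, waste = find_minimum_waste_size(
    numel_group_list=[[5, 3, 4]],
    min_range=6,
    max_range=12,
    interval=2,
)
# with these toy numbers, a chunk of 12 elements packs all three parameters with zero waste
print(best_size, waste)
```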

View File

@ -0,0 +1,61 @@
#include <Python.h>
#include <bits/stdc++.h>
#include <torch/extension.h>
int move_count_impl(std::vector<std::vector<int>> &steps, int n_blocks) {
int n_steps = steps.size();
std::unordered_map<int, int> my_map;
std::map<std::pair<int, int>, int> next_map;
for (auto i = n_steps - 1; ~i; --i) {
auto ids = steps.at(i);
for (auto c_id : ids) {
auto iter = my_map.find(c_id);
auto nxt = n_steps;
if (iter != my_map.end()) nxt = iter->second;
next_map.emplace(std::make_pair(i, c_id), nxt);
my_map[c_id] = i;
}
}
// reuse this map
for (auto iter : my_map) my_map[iter.first] = 0;
int cache_size = 0, count = 0;
std::priority_queue<std::pair<int, int>> cache;
for (auto i = 0; i < n_steps; ++i) {
auto ids = steps.at(i);
assert(n_blocks >= ids.size());
int not_in = 0;
for (auto c_id : ids)
if (my_map[c_id] == 0) ++not_in;
while (cache_size + not_in > n_blocks) {
std::pair<int, int> q_top = cache.top();
cache.pop();
assert(q_top.first > i);
assert(my_map[q_top.second] == 1);
my_map[q_top.second] = 0;
--cache_size;
++count;
}
for (auto c_id : ids) {
auto iter = next_map.find(std::make_pair(i, c_id));
cache.push(std::make_pair(iter->second, c_id));
if (my_map[c_id] == 0) {
my_map[c_id] = 1;
++cache_size;
}
}
}
return (count + cache_size) << 1;
}
int move_count(std::vector<std::vector<int>> &steps, int n_blocks) {
return move_count_impl(steps, n_blocks);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("move_count", &move_count, "Count the number of moves.");
}

colossalai/elixir/tensor.py Normal file
View File

@ -0,0 +1,112 @@
import torch
from torch.utils._pytree import tree_map
debug_flag = False
white_list = {torch.Tensor.__getitem__}
fake_allowed = {
# pre-commit: don't move
torch.Tensor.numel,
torch.Tensor.size,
torch.Tensor.stride,
torch.Tensor.storage_offset,
torch.Tensor.is_floating_point
}
inpalce_mapping = {
torch.Tensor.add_: torch.Tensor.add,
torch.Tensor.sub_: torch.Tensor.sub,
torch.Tensor.mul_: torch.Tensor.mul,
torch.Tensor.div_: torch.Tensor.div
}
def is_no_hook_op(func) -> bool:
if func.__name__.startswith('__') and func not in white_list:
return True
if func in fake_allowed:
return True
return False
class FakeTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad)
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise NotImplementedError
def to_outplace_tensor(t):
if isinstance(t, OutplaceTensor):
return t
assert type(t) is torch.Tensor, f'type: {type(t)}'
t.__class__ = OutplaceTensor
return t
class OutplaceTensor(torch.Tensor):
# TODO: rename this class
def __new__(cls, tensor):
rt = tensor.as_subclass(cls)
return rt
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
# in order to trigger the pre-op hook in the forward of a checkpoint module,
# we have to capture the `backward` function
# and make sure that it is not executed inside a `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1    # backward is called with a single positional tensor
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
# return a tensor if the output needs to be a torch.Tensor (such as Tensor.data.__get__)
if is_no_hook_op(func):
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
return ret
# debug inplace operations
if debug_flag:
if func.__name__.endswith('_'):
print(f'found inplace operation {func.__name__}')
# replace the in-place function
if func in inpalce_mapping:
func = inpalce_mapping[func]
# set the 'inplace' kwargs to False
if 'inplace' in kwargs:
kwargs['inplace'] = False
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
if not isinstance(ret, tuple):
ret = (ret,)
def convert(t):
if isinstance(t, torch.Tensor):
t = to_outplace_tensor(t)
return t
ret = tree_map(convert, ret)
if len(ret) == 1:
ret = ret[0]
return ret
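
A minimal sketch of what `to_outplace_tensor` does in practice, following the `__torch_function__` override above: in-place arithmetic on the wrapped tensor is redirected to its out-of-place counterpart, so the original storage is left untouched. The import path follows the package layout in this PR:

```python
import torch

from colossalai.elixir.tensor import to_outplace_tensor

t = to_outplace_tensor(torch.zeros(4))
out = t.add_(1)       # dispatched to torch.Tensor.add instead of add_
print(type(out))      # OutplaceTensor
print(t)              # still all zeros: the original tensor was not modified
```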

View File

View File

@ -0,0 +1,2 @@
from .cuda_profiler import cuda_memory_profiling
from .memory_tensor import MTensor

View File

@ -0,0 +1,77 @@
import gc
from typing import Callable, Dict, Tuple, Union
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map
from colossalai.elixir.tracer.utils import get_cuda_allocated, meta_copy, model_memory_figure
from colossalai.elixir.utils import print_rank_0
from .memory_tensor import MTensor
def grad_cleaner(grad):
empty_grad = torch.empty_like(grad.elem)
grad.elem = None
empty_grad.storage().resize_(0)
return empty_grad
def cuda_memory_profiling(model: nn.Module, inp: Dict, step_fn: Callable, dtype=torch.float):
assert isinstance(inp, dict), 'the example input should be a dictionary'
print_rank_0(f'You are profiling cuda memory with dtype `{dtype}`')
def tensor_trans(t: torch.Tensor):
# set dtype for tensors
meta_dtype = dtype if t.is_floating_point() else t.dtype
meta_t = torch.empty_like(t.data, device='meta', dtype=meta_dtype)
# pack parameters
if isinstance(t, nn.Parameter):
meta_t = nn.Parameter(meta_t)
return meta_t
# first, transform the model into one dtype
model = meta_copy(model, tensor_trans)
# get the memory figure of the model
memo_dict = model_memory_figure(model)
# initialize an empty pool for parameters
pool = torch.zeros(memo_dict['param_max_numel'], device='cuda', dtype=dtype)
def tensor_to_cuda(t):
if isinstance(t, nn.Parameter):
fake_data = pool[:t.numel()].view(t.shape)
return nn.Parameter(fake_data)
else:
fake_data = torch.zeros(t.shape, device='cuda', dtype=t.dtype)
return fake_data
# move all parameters to CUDA and point them at the same storage pool
model = meta_copy(model, tensor_to_cuda)
# add hooks to clean gradients
for param in model.parameters():
param.register_hook(grad_cleaner)
def input_trans(t):
if isinstance(t, torch.Tensor):
cuda_dtype = dtype if t.is_floating_point() else t.dtype
cuda_t = t.data.clone()
cuda_t = cuda_t.to(dtype=cuda_dtype, device='cuda')
cuda_t.requires_grad = t.requires_grad
return MTensor(cuda_t)
return t
inp = tree_map(input_trans, inp)
# reset all collected peak memory states
MTensor.reset_peak_memory()
before_cuda_alc = get_cuda_allocated()
step_fn(model, inp)
after_cuda_alc = MTensor.current_peak_memory()
activation_occ = after_cuda_alc - before_cuda_alc
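# gradients are assumed to occupy the same amount of memory as the parameters, so grad_occ reuses param_occ below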
return dict(param_occ=memo_dict['param_occ'],
buffer_occ=memo_dict['buffer_occ'],
grad_occ=memo_dict['param_occ'],
activation_occ=activation_occ)
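
A hedged usage sketch of `cuda_memory_profiling` with a toy model and step function (both hypothetical; a CUDA device is assumed). The step function only needs to run one forward/backward pass given the model and the example input dictionary, and the import path follows the package layout in this PR:

```python
import torch
import torch.nn as nn

from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))
inp = dict(x=torch.randn(16, 128))

def train_step(model, inp):
    # hypothetical one-step training function: forward then backward
    out = model(inp['x'])
    out.sum().backward()

usage = cuda_memory_profiling(model, inp, train_step, dtype=torch.float16)
print(usage)   # param_occ, buffer_occ, grad_occ and activation_occ in bytes
```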

View File

@ -0,0 +1,99 @@
import contextlib
from typing import Iterator
import torch
from torch.utils._pytree import tree_map
from colossalai.elixir.tracer.utils import get_cuda_max_allocated
from .op_cache import wrapped_mm_ops
aten = torch.ops.aten
mm_ops_list = [aten.mm.default, aten.addmm.default, aten.bmm.default, aten.addbmm.default, aten.baddbmm.default]
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
class MTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ['elem']
peak_memory_allocated: int = 0
@staticmethod
def reset_peak_memory():
torch.cuda.reset_peak_memory_stats()
MTensor.peak_memory_allocated = 0
@staticmethod
def update_peak_memory(new_peak):
MTensor.peak_memory_allocated = max(MTensor.peak_memory_allocated, new_peak)
@staticmethod
def current_peak_memory():
cur_peak = get_cuda_max_allocated()
return max(MTensor.peak_memory_allocated, cur_peak)
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
# TODO: clone strides and storage aliasing
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad)
r.elem = elem
return r
def __repr__(self):
return f'MTensor({self.elem})'
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def print_tensor(x):
if isinstance(x, torch.Tensor):
print(x.shape)
# tree_map(print_tensor, args)
# tree_map(print_tensor, kwargs)
def unwrap(x):
return x.elem if isinstance(x, MTensor) else x
def wrap(x):
return MTensor(x) if isinstance(x, torch.Tensor) else x
if func in mm_ops_list:
res, pre_max = wrapped_mm_ops(func, *tree_map(unwrap, args), **tree_map(unwrap, kwargs))
MTensor.update_peak_memory(pre_max)
else:
with no_dispatch():
res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
outs = normalize_tuple(res)
res = tree_map(wrap, outs)
if len(res) == 1:
return res[0]
else:
return res

View File

@ -0,0 +1,135 @@
import contextlib
from typing import Dict, Iterator, Tuple
import torch
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
from .output_shape import addmm_output, bmm_output, check_cuda_mm, mm_output
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def tensor_info(x: torch.Tensor):
# returns the meta information used for CUDA kernels
return (x.shape, x.stride(), x.layout, x.dtype)
def get_args_info(*args):
# returns a tuple containing the meta information of all tensor inputs
# non-tensor arguments are skipped
info_list = []
for x in args:
if isinstance(x, torch.Tensor):
info_list.append(tensor_info(x))
return tuple(info_list)
class OpCache(object):
def __init__(self, name: str) -> None:
super().__init__()
self.name = name
self.temp_memory: Dict[Tuple, int] = dict()
def reset(self):
self.temp_memory.clear()
def get(self, info):
if info in self.temp_memory:
return True, self.temp_memory[info]
else:
return False, None
def add(self, info, memo):
self.temp_memory[info] = memo
def print(self):
print(f'OpCache {self.name} information:')
for k, v in self.temp_memory.items():
print(f'key: {k}\ntemp_memo:{v}')
aten = torch.ops.aten
addmm_cache = OpCache('aten.addmm.default')
bmm_cache = OpCache('aten.bmm.default')
mm_cache = OpCache('aten.mm.default')
op_mapping = {
aten.mm.default: {
'cache': mm_cache,
'output': mm_output
},
aten.addmm.default: {
'cache': addmm_cache,
'output': addmm_output
},
aten.bmm.default: {
'cache': bmm_cache,
'output': bmm_output
}
}
def reset_caches():
addmm_cache.reset()
bmm_cache.reset()
mm_cache.reset()
def fake_cuda_output(temp_memo, output_shape, dtype):
ret = torch.empty(output_shape, dtype=dtype, device='cuda')
sub = temp_memo - ret.numel() * ret.element_size()
if sub > 0:
# allocate a temp empty tensor block to simulate the computation in kernels
temp = torch.empty(sub, dtype=torch.int8, device='cuda')
# release this tensor block
del temp
return ret
def real_cuda_output(func, *args, **kwargs):
cur_alc = get_cuda_allocated()
# save the peak memory usage
pre_max_alc = get_cuda_max_allocated()
# the peak memory history is cleared here
torch.cuda.reset_peak_memory_stats()
with no_dispatch():
ret = func(*args, **kwargs)
max_alc = get_cuda_max_allocated()
# calculate the temporary memory allocation
temp_memo = max_alc - cur_alc
return ret, temp_memo, pre_max_alc
def wrapped_mm_ops(func, *args, **kwargs):
check_cuda_mm(*args)
if func not in op_mapping:
raise RuntimeError(f'Unsupported mm operation {func}')
args_info = get_args_info(*args)
cache = op_mapping[func]['cache']
cached_flag, temp_memo = cache.get(args_info)
if cached_flag:
output_fn = op_mapping[func]['output']
out_shape = output_fn(*args)
ret = fake_cuda_output(temp_memo=temp_memo, output_shape=out_shape, dtype=args[0].dtype)
return ret, 0
else:
ret, temp_memo, pre_max_alc = real_cuda_output(func, *args, **kwargs)
cache.add(args_info, temp_memo)
return ret, pre_max_alc

View File

@ -0,0 +1,48 @@
import torch
# Output functions come from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
def check_cuda_mm(*args):
for x in args:
assert isinstance(x, torch.Tensor)
assert x.device.type == 'cuda'
def mm_output(a, b):
assert a.dim() == 2, 'a must be 2D'
assert b.dim() == 2, 'b must be 2D'
N, M1 = a.shape
M2, P = b.shape
assert M1 == M2, 'a and b must have same reduction dim'
return (N, P)
def addmm_output(bias, x, y):
return mm_output(x, y)
def common_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
assert batch1.dim() == 3, 'batch1 must be a 3D tensor'
assert batch2.dim() == 3, 'batch2 must be a 3D tensor'
batch1_sizes = batch1.size()
batch2_sizes = batch2.size()
bs = batch1_sizes[0]
contraction_size = batch1_sizes[2]
res_rows = batch1_sizes[1]
res_cols = batch2_sizes[2]
output_size = (bs, res_rows, res_cols)
assert batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size
if not is_bmm and self_baddbmm is not None:
assert self_baddbmm.dim() == 3, 'self must be a 3D tensor'
assert self_baddbmm.size() == output_size, \
f'Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}'
return output_size
def bmm_output(mat1, mat2):
return common_baddbmm_bmm(mat1, mat2, True)

View File

@ -0,0 +1,76 @@
import torch
import torch.distributed as dist
aten = torch.ops.aten
__all__ = [
'TorchFactoryMethod', 'TorchOverrideableFactoryMethod', 'TorchNonOverrideableFactoryMethod', 'TensorPropertyMethod',
'DistCommMethod', 'AliasATen', 'InplaceATen', 'MaybeInplaceAten', 'SameStorageAten'
]
TorchOverrideableFactoryMethod = [
'empty',
'eye',
'full',
'ones',
'rand',
'randn',
'zeros',
]
TorchNonOverrideableFactoryMethod = [
'arange',
'finfo',
'linspace',
'logspace',
'randint',
'randperm',
'tensor',
]
TorchFactoryMethod = TorchOverrideableFactoryMethod + TorchNonOverrideableFactoryMethod
TensorPropertyMethod = ['dtype', 'shape', 'device', 'requires_grad', 'grad', 'grad_fn', 'data']
DistCommMethod = [
'all_gather',
'all_reduce',
'all_to_all',
'broadcast',
'gather',
'reduce',
'reduce_scatter',
'scatter',
]
AliasATen = [
aten.detach.default,
aten.detach_.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
InplaceATen = [
aten.add_.Tensor,
aten.add_.Scalar,
aten.sub_.Tensor,
aten.sub_.Scalar,
aten.mul_.Tensor,
aten.mul_.Scalar,
aten.div_.Tensor,
aten.div_.Scalar,
aten.pow_.Tensor,
aten.pow_.Scalar,
]
MaybeInplaceAten = [
aten.diagonal.default,
aten.select.int,
aten.slice.Tensor,
aten.as_strided.default,
]
SameStorageAten = AliasATen + InplaceATen + MaybeInplaceAten

View File

@ -0,0 +1,3 @@
from .fx_order import generate_fx_order
from .td_order import generate_td_order
from .tf_order import generate_tf_order

View File

@ -0,0 +1,62 @@
from typing import Dict, List
import torch
import torch.nn as nn
from torch.fx import GraphModule, Node, symbolic_trace
from colossalai.elixir.tracer.utils import meta_copy
def generate_fx_order(model: nn.Module) -> List[Dict[str, nn.Parameter]]:
fxf_name_mark = '_fxf_name'
fxf_param_mark = '_fxf_param'
def tensor_trans(t):
meta_t = t.data.to('meta')
if isinstance(t, nn.Parameter):
meta_t = nn.Parameter(meta_t)
return meta_t
meta_model = meta_copy(model, tensor_trans)
# attach names for parameters
for name, param in meta_model.named_parameters():
setattr(param, fxf_name_mark, name)
fx_forward_order: List[Dict[str, nn.Parameter]] = list()
gm: GraphModule = symbolic_trace(meta_model)
for node in gm.graph.nodes:
if node.op in ('output', 'placeholder'):
continue
step_dict = None
if node.op == 'get_attr':
maybe_param = getattr(gm, node.target)
# mark this node as a parameter
if maybe_param is not None:
setattr(node, fxf_param_mark, maybe_param)
continue
elif node.op == 'call_module':
target_module = gm.get_submodule(node.target)
step_dict = dict()
# collect all parameters in the module
for maybe_param in target_module.parameters():
if maybe_param is not None:
param_name = getattr(maybe_param, fxf_name_mark)
step_dict[param_name] = maybe_param
elif node.op in ('call_function', 'call_method'):
step_dict = dict()
for pre in node.args:
if hasattr(pre, fxf_param_mark):
param = getattr(pre, fxf_param_mark)
param_name = getattr(param, fxf_name_mark)
step_dict[param_name] = param
else:
raise RuntimeError(f'Unsupported node op {node.op}!')
if step_dict is not None and len(step_dict) > 0:
fx_forward_order.append(step_dict)
return fx_forward_order
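
A hedged sketch of calling `generate_fx_order` on a small, symbolically traceable model; each entry of the returned list maps parameter names to the parameters touched at that forward step. The import path follows the package layout in this PR:

```python
import torch.nn as nn

from colossalai.elixir.tracer.param_tracer import generate_fx_order

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 4))
forward_order = generate_fx_order(model)
for step, params in enumerate(forward_order):
    print(step, list(params.keys()))   # e.g. step 0 -> ['0.weight', '0.bias']
```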

View File

@ -0,0 +1,156 @@
import contextlib
import uuid
from typing import Callable, Dict, Iterator, List, Tuple, Union
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map
from colossalai.elixir.tracer.ops import SameStorageAten
from colossalai.elixir.tracer.utils import meta_copy
@contextlib.contextmanager
def no_dispatch() -> Iterator[None]:
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def register_storage(x):
assert isinstance(x, nn.Parameter)
assert x.data_ptr() == 0
data_ptr = uuid.uuid1()
x.data_ptr = lambda: data_ptr
class ATensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ['elem']
data_ptr_dict: Dict[int, Tuple[str, nn.Parameter]] = None
order_list: List[Dict] = None
@staticmethod
def reset():
ATensor.data_ptr_dict = dict()
ATensor.order_list = list()
@staticmethod
def clear():
ATensor.data_ptr_dict = None
ATensor.order_list = None
@staticmethod
def add_data_ptr(name: str, param: nn.Parameter):
data_ptr = param.data_ptr()
if data_ptr not in ATensor.data_ptr_dict:
ATensor.data_ptr_dict[data_ptr] = (name, param)
else:
name_in, param_in = ATensor.data_ptr_dict[data_ptr]
if name != name_in or id(param) != id(param_in):
raise RuntimeError('Got two different parameters with the same data ptr')
@staticmethod
def get_param(data_ptr: int):
if data_ptr in ATensor.data_ptr_dict:
return ATensor.data_ptr_dict.get(data_ptr)
else:
return None, None
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass(
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
# TODO: clone strides and storage aliasing
dtype=elem.dtype,
layout=elem.layout,
device=elem.device,
requires_grad=elem.requires_grad)
r.elem = elem
return r
def __repr__(self):
return f'ATensor({self.elem})'
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
step_dict = dict()
def record_param(x):
if isinstance(x, torch.Tensor):
name, param = ATensor.get_param(x.data_ptr())
if name is not None:
step_dict[name] = param
def debug_tensor(x):
if isinstance(x, torch.Tensor):
print(type(x), x.shape, x.data_ptr(), id(x))
if x.grad_fn:
print(x.grad_fn)
tree_map(record_param, args)
if len(step_dict) > 0:
ATensor.order_list.append(step_dict)
del step_dict
def unwrap(x):
return x.elem if isinstance(x, ATensor) else x
def wrap(x):
return ATensor(x) if isinstance(x, torch.Tensor) else x
with no_dispatch():
res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
outs = normalize_tuple(res)
res = tree_map(wrap, outs)
if func in SameStorageAten:
for x in res:
if isinstance(x, torch.Tensor):
x.data_ptr = args[0].data_ptr
if len(res) == 1:
return res[0]
else:
return res
def generate_td_order(model: nn.Module, inp: Union[torch.Tensor, Tuple], step_fn: Callable):
ATensor.reset()
def tensor_trans(t):
meta_t = ATensor(t.data.to('meta'))
if isinstance(t, nn.Parameter):
meta_t = nn.Parameter(meta_t)
return meta_t
model = meta_copy(model, tensor_trans)
for name, param in model.named_parameters():
register_storage(param)
ATensor.add_data_ptr(name, param)
# convert all input data to meta_tensor
if not isinstance(inp, tuple):
inp = (inp,)
inp = tree_map(lambda t: ATensor(torch.empty_like(t, device='meta', requires_grad=t.requires_grad)), inp)
step_fn(model, inp)
ret = ATensor.order_list
ATensor.clear()
return ret

View File

@ -0,0 +1,178 @@
from typing import Callable, Dict, List
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.utils._pytree import tree_map
from colossalai.elixir.tensor import is_no_hook_op
from colossalai.elixir.tracer.utils import meta_copy
from colossalai.elixir.utils import no_dispatch, normalize_tuple
torch_checkpoint_function = torch.utils.checkpoint.checkpoint
def attach_checkpoint():
default_in_checkpoint = False
def inner_checkpoint_function(function, *args, use_reentrant: bool = True, **kwargs):
nonlocal default_in_checkpoint
prev_in_checkpoint = default_in_checkpoint
default_in_checkpoint = True
# record the step at which we enter the checkpoint region
if not prev_in_checkpoint:
Record.record_in_checkpoint()
# use original torch checkpoint function
global torch_checkpoint_function
ret = torch_checkpoint_function(function, *args, use_reentrant=use_reentrant, **kwargs)
# roll back
default_in_checkpoint = prev_in_checkpoint
if not default_in_checkpoint:
Record.record_out_checkpoint()
return ret
torch.utils.checkpoint.checkpoint = inner_checkpoint_function
def release_checkpoint():
global torch_checkpoint_function
torch.utils.checkpoint.checkpoint = torch_checkpoint_function
class PostFwdPreBwd(torch.autograd.Function):
@staticmethod
def forward(ctx, params, *args):
ctx.params = params
return args
@staticmethod
def backward(ctx, *grads):
Record.record_params(ctx.params)
return (None, *grads)
class Record(nn.Parameter):
record_steps: List = None
checkpoint_info: List = None
in_checkpoint_step: int = -1
def __new__(cls, elem):
assert elem.device.type == 'meta', f'device type: {elem.device.type}'
r = torch.Tensor._make_subclass(cls, elem)
return r
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if is_no_hook_op(func):
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
return ret
params = list()
def append_param(x):
if isinstance(x, nn.Parameter):
assert isinstance(x, Record)
params.append(x)
tree_map(append_param, args)
tree_map(append_param, kwargs)
Record.record_params(params)
with torch._C.DisableTorchFunction():
ret = normalize_tuple(func(*args, **kwargs))
ret = PostFwdPreBwd.apply(params, *ret)
def clone(t):
if isinstance(t, torch.Tensor):
t = t.clone()
return t
ret = tree_map(clone, ret)
if len(ret) == 1:
return ret[0]
else:
return ret
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# notice: we should disable __torch_function__ here
# otherwise, unexpected operations are called inside meta kernels
with torch._C.DisableTorchFunction():
with no_dispatch():
return func(*args, **kwargs)
@staticmethod
def reset():
Record.record_steps = list()
Record.checkpoint_info = list()
Record.in_checkpoint_step = -1
@staticmethod
def record_in_checkpoint():
assert Record.in_checkpoint_step == -1
Record.in_checkpoint_step = len(Record.record_steps)
@staticmethod
def record_out_checkpoint():
assert Record.in_checkpoint_step != -1
value_pair = (Record.in_checkpoint_step, len(Record.record_steps))
Record.checkpoint_info.append(value_pair)
Record.in_checkpoint_step = -1
@staticmethod
def steps():
ret = dict(params_per_step=Record.record_steps, checkpoint_info=Record.checkpoint_info)
Record.record_steps = None
Record.checkpoint_info = None
return ret
@staticmethod
def record_params(params):
record_dict = {p.param_name for p in params}
Record.record_steps.append(record_dict)
def generate_tf_order(model: nn.Module, inp: Dict, step_fn: Callable, dtype: torch.dtype = torch.float):
assert isinstance(inp, dict), 'The example input should be a dictionary'
Record.reset()
def mtensor_trans(t: torch.Tensor):
if t.is_floating_point():
meta_dtype = dtype
else:
meta_dtype = t.dtype
meta_t = torch.empty_like(t, dtype=meta_dtype, device='meta')
if isinstance(t, nn.Parameter):
meta_t = Record(meta_t)
meta_t.requires_grad = t.requires_grad
return meta_t
model = meta_copy(model, mtensor_trans)
for name, param in model.named_parameters():
param.param_name = name
def input_trans(t):
if isinstance(t, torch.Tensor):
if t.is_floating_point():
meta_dtype = dtype
else:
meta_dtype = t.dtype
meta_t = torch.empty_like(t, dtype=meta_dtype, device='meta', requires_grad=t.requires_grad)
return meta_t
return t
inp = tree_map(input_trans, inp)
attach_checkpoint()
step_fn(model, inp)
release_checkpoint()
ret = Record.steps()
return ret
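
A hedged sketch of `generate_tf_order` (toy model, hypothetical step function): like the memory profiler, it takes a dictionary example input and a step function, replays them on meta tensors, and returns the parameter names touched per step together with the activation-checkpoint ranges. The import path follows the package layout in this PR:

```python
import torch
import torch.nn as nn

from colossalai.elixir.tracer.param_tracer import generate_tf_order

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 4))
inp = dict(x=torch.randn(8, 32))

def train_step(model, inp):
    # hypothetical one-step function: forward then backward, replayed on meta tensors
    out = model(inp['x'])
    out.sum().backward()

info = generate_tf_order(model, inp, train_step, dtype=torch.float16)
print(info['params_per_step'])
print(info['checkpoint_info'])
```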

View File

@ -0,0 +1,96 @@
from collections import OrderedDict
from copy import copy
from typing import Optional, Set
import torch
import torch.nn as nn
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules.
"""
if memo is None:
memo = set()
if module not in memo:
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
yield m
memo.add(module)
yield prefix, module
def _get_shallow_copy_model(model: nn.Module):
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
But the new submodule and the old submodule share all attributes.
"""
old_to_new = dict()
for name, module in _get_dfs_module_list(model):
new_module = copy(module)
new_module._modules = OrderedDict()
for subname, submodule in module._modules.items():
if submodule is None:
continue
setattr(new_module, subname, old_to_new[submodule])
old_to_new[module] = new_module
return old_to_new[model]
def meta_copy(model: nn.Module, meta_fn: callable):
new_model = _get_shallow_copy_model(model)
old_parameters = dict()
old_buffers = dict()
for (_, old_module), (_, new_module) in \
zip(_get_dfs_module_list(model), _get_dfs_module_list(new_model)):
new_module._parameters = OrderedDict()
for name, param in old_module._parameters.items():
new_param = None
if param is not None:
param_id = id(param)
if param_id in old_parameters:
new_param = old_parameters.get(param_id)
else:
new_param = meta_fn(param)
old_parameters[param_id] = new_param
setattr(new_module, name, new_param)
new_module._buffers = OrderedDict()
for name, buffer in old_module._buffers.items():
new_buffer = None
if buffer is not None:
buffer_id = id(buffer)
if buffer_id in old_buffers:
new_buffer = old_buffers.get(buffer_id)
else:
new_buffer = meta_fn(buffer)
old_buffers[buffer_id] = new_buffer
new_module.register_buffer(name, new_buffer)
return new_model
def get_cuda_allocated():
return torch.cuda.memory_allocated()
def get_cuda_max_allocated():
return torch.cuda.max_memory_allocated()
def model_memory_figure(model: nn.Module):
param_occ = 0
max_numel = 0
for name, param in model.named_parameters():
param_occ += param.numel() * param.element_size()
max_numel = max(max_numel, param.numel())
buffer_occ = 0
for name, buffer in model.named_buffers():
buffer_occ += buffer.numel() * buffer.element_size()
return dict(param_occ=param_occ, param_max_numel=max_numel, buffer_occ=buffer_occ)
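
A hedged sketch of `meta_copy`: the module tree is shallow-copied, and every parameter and buffer is rebuilt by the supplied `meta_fn`, here moving tensors to the meta device as the tracers above do. The import path follows the package layout in this PR:

```python
import torch
import torch.nn as nn

from colossalai.elixir.tracer.utils import meta_copy

def to_meta(t: torch.Tensor):
    # move the data to the meta device, re-wrapping parameters
    meta_t = t.data.to('meta')
    if isinstance(t, nn.Parameter):
        meta_t = nn.Parameter(meta_t)
    return meta_t

model = nn.Linear(16, 16)
meta_model = meta_copy(model, to_meta)
print(next(meta_model.parameters()).device)   # meta
print(next(model.parameters()).device)        # the original model is untouched
```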

colossalai/elixir/utils.py Normal file
View File

@ -0,0 +1,119 @@
import contextlib
import os
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
@contextlib.contextmanager
def no_dispatch():
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def seed_all(seed, cuda_deterministic=False):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if cuda_deterministic: # slower, more reproducible
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
else:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
def init_distributed():
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
init_method = f'tcp://[{host}]:{port}'
dist.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
# set cuda device
if torch.cuda.is_available():
# bind this process to its local GPU
torch.cuda.set_device(local_rank)
seed_all(1024)
def print_rank_0(*args, **kwargs):
if dist.is_initialized():
if dist.get_rank() == 0:
print(*args, **kwargs)
dist.barrier()
else:
print(*args, **kwargs)
def get_model_size(model: nn.Module):
total_numel = 0
for module in model.modules():
for p in module.parameters(recurse=False):
total_numel += p.numel()
return total_numel
def model_size_formatter(numel: int) -> str:
GB_SIZE = 10**9
MB_SIZE = 10**6
KB_SIZE = 10**3
if numel >= GB_SIZE:
return f'{numel / GB_SIZE:.1f}B'
elif numel >= MB_SIZE:
return f'{numel / MB_SIZE:.1f}M'
elif numel >= KB_SIZE:
return f'{numel / KB_SIZE:.1f}K'
else:
return str(numel)
def calc_buffer_size(m: nn.Module, test_dtype: torch.dtype = torch.float):
max_sum_size = 0
for module in m.modules():
sum_p_size = 0
for param in module.parameters(recurse=False):
assert param.dtype == test_dtype
sum_p_size += param.numel()
max_sum_size = max(max_sum_size, sum_p_size)
return max_sum_size
def calc_block_usage():
snap_shot = torch.cuda.memory_snapshot()
total_sum = 0
active_sum = 0
for info_dict in snap_shot:
blocks = info_dict.get('blocks')
for b in blocks:
size = b.get('size')
state = b.get('state')
total_sum += size
if state == 'active_allocated':
active_sum += size
active_ratio = 1
if total_sum > 0:
active_ratio = active_sum / total_sum
print(f'memory snapshot: active ratio {active_ratio:.2f}')
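
A short sketch of the helpers above, assuming the script is launched with torchrun so that the RANK/LOCAL_RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables read by `init_distributed` are present; the toy model is only for illustration.

```python
import torch.nn as nn

# hypothetical toy model, used only to demonstrate the helpers above
model = nn.Linear(1024, 1024)

# assumes the script was launched with torchrun, so RANK, LOCAL_RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT are already in the environment
init_distributed()

seed_all(42, cuda_deterministic=True)    # reproducible but slower kernels
numel = get_model_size(model)
print_rank_0(f'model size: {model_size_formatter(numel)}')    # e.g. "1.0M"
```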

View File

@ -0,0 +1,2 @@
from .module import ElixirModule
from .optimizer import ElixirOptimizer

View File

@ -0,0 +1,344 @@
from collections import defaultdict
from copy import copy
from functools import partial
from typing import Any, Iterable, Mapping
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.utils._pytree import tree_map
from colossalai.elixir.chunk import Chunk, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.search import SearchResult
from colossalai.elixir.tensor import OutplaceTensor
from colossalai.utils.model.experimental import LazyTensor
def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):
param_data = param_data.to(gpu_device())
optim_data = param_data.clone() if param_data.dtype == torch.float else param_data.float()
param_data = param_data.to(param_dtype)
return param_data, optim_data
class ElixirModule(nn.Module):
"""Use this class to wrap your model when using Elixir. Don't know what should be written here.
But some docstring is needed here.
args:
module: training module
search_result: a SearchResult generated from a search algorithm in `elixir.search`
process_group: the communication group, ussually dp parallel group
prefetch: whether to use prefetch overlaping communication with computation
dtype: the dtype used in training
"""
def __init__(self,
module: nn.Module,
search_result: SearchResult,
process_group: ProcessGroup,
prefetch: bool = False,
dtype: torch.dtype = torch.float,
reduce_always_fp32: bool = False,
output_fp32: bool = False,
use_fused_kernels: bool = False) -> None:
super().__init__()
assert dtype in {torch.float, torch.float16}
self._set_module_outplace(module)
self.module = module
self.dtype = dtype
self.use_amp = (dtype == torch.float16)
self.process_group = process_group
self.prefetch_flag = prefetch
self.reduce_always_fp32 = reduce_always_fp32
self.output_fp32 = output_fp32
self.use_fused_kernels = use_fused_kernels
self.no_grad_state_dict = dict()
self.grad_state_dict = dict()
self.__init_chunk_group(search_result)
self.__init_chunk_fetcher(search_result, prefetch)
self.__init_buffer_storage()
for name, param in module.named_parameters():
if not param.requires_grad:
assert name in self.no_grad_state_dict
continue
assert name in self.grad_state_dict
param.register_hook(partial(self._gradient_handler, param=param))
param.__class__ = HookParam
def __init_chunk_group(self, sr: SearchResult):
torch.cuda.empty_cache()
state_dict = self.module.state_dict(keep_vars=True)
for name, tensor in state_dict.items():
if isinstance(tensor, nn.Parameter):
assert tensor.is_floating_point(), 'the dtypes of parameters should be float dtypes'
# deal with parameters
if tensor.requires_grad:
self.grad_state_dict[name] = tensor
else:
self.no_grad_state_dict[name] = tensor
# cast no-grad parameters to the training dtype and move them to the GPU
tensor.data = tensor.data.to(dtype=self.dtype, device=gpu_device())
else:
# deal with buffers
self._lazy_init_check(tensor)
to_dtype = self.dtype if tensor.is_floating_point() else tensor.dtype
tensor.data = tensor.data.to(dtype=to_dtype, device=gpu_device())
empty_mp = MemoryPool('cuda')
empty_mp.allocate()
self.param_chunk_group = sr.chunk_group
self.optim_chunk_group = ChunkGroup(empty_mp)
self.param_to_optim = dict()
vis_set = set()
for plan in sr.param_chunk_plans:
assert plan.chunk_dtype == self.dtype
# optimizer chunks should not be gathered
optim_kwargs = copy(plan.kwargs)
if 'rcache_fused' in optim_kwargs:
optim_kwargs['rcache_fused'] = False
p_chunk = self.param_chunk_group.open_chunk(chunk_size=plan.chunk_size,
chunk_dtype=plan.chunk_dtype,
process_group=self.process_group,
chunk_config=plan.kwargs)
o_chunk = self.optim_chunk_group.open_chunk(chunk_size=plan.chunk_size,
chunk_dtype=torch.float,
process_group=self.process_group,
chunk_config=optim_kwargs)
for name in plan.name_list:
param = self.grad_state_dict[name]
self._lazy_init_check(param)
param_data, optim_data = get_param_optim_data(param.data, self.dtype)
param.data = param_data
p_chunk.append_tensor(param)
o_chunk.append_tensor(optim_data)
self.param_to_optim[param] = optim_data
vis_set.add(param)
self.param_chunk_group.close_chunk(p_chunk)
self.optim_chunk_group.close_chunk(o_chunk)
p_chunk.init_pair(o_chunk)
# sanity check: every parameter that requires a gradient has been initialized
for param in self.module.parameters():
if param.requires_grad:
assert param in vis_set
def __init_chunk_fetcher(self, sr: SearchResult, prefetch: bool):
scheduler = None
if prefetch:
assert sr.param_called_per_step is not None
chunk_called_per_step = list()
for step in sr.param_called_per_step:
step_set = set()
for name in step:
param = self.grad_state_dict[name]
chunk = self.param_chunk_group.ten_to_chunk[param]
step_set.add(chunk)
chunk_called_per_step.append(step_set)
scheduler = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
else:
scheduler = FIFOScheduler()
self.fetcher = ChunkFetcher(scheduler,
self.param_chunk_group,
overlap=prefetch,
reduce_always_fp32=self.reduce_always_fp32)
self.fetcher.reset()
def __init_buffer_storage(self):
buffer_size = 0
for submodule in self.modules():
sum_param_size = 0
for param in submodule.parameters(recurse=False):
if not param.requires_grad or self.fetcher.is_in_fused(param):
continue
assert param.dtype == self.dtype
sum_param_size += param.numel()
buffer_size = max(buffer_size, sum_param_size)
self.buffer = BufferStore(buffer_size, self.dtype)
print('module buffer', self.buffer)
def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter):
# create an empty tensor
fake_grad = self.buffer.empty_like(grad)
with torch._C.DisableTorchFunction():
chunk = self.fetcher.get_one_chunk(param)
assert self.fetcher.group.is_accessed(chunk)
if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError()
self.fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(param, grad)
self.fetcher.reduce_chunk(chunk)
return fake_grad
def _lazy_init_check(self, tensor: torch.Tensor) -> None:
if isinstance(tensor, LazyTensor):
tensor.materialize()
def _set_module_outplace(self, m: nn.Module):
# set inplace to False for all modules
for module in m.modules():
if hasattr(module, 'inplace'):
module.inplace = False
def _deattach_fetcher(self):
self.fetcher.clear()
HookParam.release_fetcher()
HookParam.disable_fused_kernel()
def _release_for_inference(self):
torch.cuda.synchronize()
scheduler = self.fetcher.scheduler
param_group = self.param_chunk_group
while True:
maybe_chunk = scheduler.top()
if maybe_chunk is None:
break
scheduler.remove(maybe_chunk)
param_group.release_chunk(maybe_chunk)
self._deattach_fetcher()
def forward(self, *args, **kwargs):
if torch.is_grad_enabled():
inference_mode = False
else:
inference_mode = True
# reset the fetcher in this step
self.fetcher.reset()
HookParam.attach_fetcher(self.fetcher, self.buffer)
if self.use_fused_kernels:
HookParam.enable_fused_kernel()
def to_outplace_tensor(t):
if isinstance(t, torch.Tensor):
if t.is_floating_point():
t = t.to(self.dtype)
t = OutplaceTensor(t)
return t
args = tree_map(to_outplace_tensor, args)
kwargs = tree_map(to_outplace_tensor, kwargs)
outputs = self.module(*args, **kwargs)
if self.output_fp32:
outputs = outputs.float()
if inference_mode:
self._release_for_inference()
return outputs
def backward(self, loss: torch.Tensor):
loss.backward()
# reset the fetcher for the next step
self._deattach_fetcher()
# reset all attributes
self.module.zero_grad(set_to_none=True)
def state_dict(self,
destination=None,
prefix='',
keep_vars=False,
only_rank_0: bool = False,
from_param: bool = False):
assert keep_vars is False, 'state_dict can not keep variables in ElixirModule'
# make sure that the variables are kept, we shall detach them later
module_state_dict = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=True)
tensor_to_names = defaultdict(list)
for name, tensor in module_state_dict.items():
if isinstance(tensor, nn.Parameter) and tensor.requires_grad:
used_tensor = self.grad_state_dict[name]
if not from_param:
used_tensor = self.param_to_optim.get(used_tensor)
tensor_to_names[used_tensor].append(name)
else:
module_state_dict[name] = tensor.detach()
def update_state_dict(chunks: Iterable[Chunk]):
for c in chunks:
for op, cp in zip(c.get_tensors(), c.get_cpu_copy(only_rank_0)):
for name in tensor_to_names.get(op):
module_state_dict[name] = cp
if from_param:
used_group = self.param_chunk_group
else:
used_group = self.optim_chunk_group
update_state_dict(used_group.fused_chunks)
update_state_dict(used_group.float_chunks)
return module_state_dict
def load_state_dict(self, state_dict: Mapping[str, Any], only_rank_0: bool = False):
load_flag = not only_rank_0 or dist.get_rank() == 0
if not load_flag:
# only rank 0 loads the state dict
assert state_dict is None
if only_rank_0:
# broadcast the length of the state dict
state_length = len(state_dict) if load_flag else None
comm_list = [state_length]
dist.broadcast_object_list(comm_list)
state_length = comm_list[0]
# broadcast the keys of the state dict
state_keys = list(state_dict.keys()) if load_flag else [None] * state_length
dist.broadcast_object_list(state_keys)
# update the state dict
if not load_flag:
state_dict = {k: None for k in state_keys}
# init the mapping from optim tensor to load tensor
optim_to_load = dict()
for name, maybe_tensor in state_dict.items():
if name in self.no_grad_state_dict:
no_grad_param = self.no_grad_state_dict.get(name)
if load_flag:
no_grad_param.copy_(maybe_tensor)
if only_rank_0:
dist.broadcast(no_grad_param, src=0)
elif name in self.grad_state_dict:
grad_param = self.grad_state_dict.get(name)
optim_tensor = self.param_to_optim.get(grad_param)
optim_to_load[optim_tensor] = maybe_tensor
def use_state_dict(chunks: Iterable[Chunk]):
for c in chunks:
load_tensor_list = list()
has_load = False
for chunk_tensor in c.tensors_info.keys():
if chunk_tensor in optim_to_load:
has_load = True
load_tensor_list.append(optim_to_load[chunk_tensor])
else:
load_tensor_list.append(None)
if has_load:
c.load_tensors(load_tensor_list, only_rank_0)
c.paired_chunk.optim_update()
use_state_dict(self.optim_chunk_group.fused_chunks)
use_state_dict(self.optim_chunk_group.float_chunks)
return
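
A sketch of the rank-0 checkpoint round-trip these two methods support; `wrapped_model` is assumed to be an ElixirModule built as in the tests further below, and the checkpoint path is hypothetical.

```python
import torch
import torch.distributed as dist

# `wrapped_model` is assumed to be an ElixirModule built as in the tests below

# gather a de-sharded state dict; with only_rank_0=True the chunk contents are
# collected only on rank 0, which saves memory on every other rank
full_state = wrapped_model.state_dict(only_rank_0=True)
if dist.get_rank() == 0:
    torch.save(full_state, 'ckpt.pt')    # hypothetical checkpoint path

# load a checkpoint that only rank 0 holds; the other ranks pass None and
# receive each tensor through the broadcasts inside load_state_dict
state = torch.load('ckpt.pt') if dist.get_rank() == 0 else None
wrapped_model.load_state_dict(state, only_rank_0=True)
```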

View File

@ -0,0 +1,265 @@
import math
from collections import defaultdict
from enum import Enum
from typing import Dict, Set, Tuple
import torch
import torch.distributed as dist
from torch.nn import Parameter
import colossalai.nn.optimizer as colo_optim
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler, ConstantGradScaler, DynamicGradScaler
from colossalai.elixir.chunk import Chunk
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.hook.storage import BufferStore
from colossalai.logging import get_dist_logger
from .module import ElixirModule
_AVAIL_OPTIM_LIST = {colo_optim.FusedAdam, colo_optim.CPUAdam, colo_optim.HybridAdam}
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
"""A wrapper for optimizers. Users should notice that one specific ElixirOptimizer is strictly
corresponding to one ElixirModule. Currently only a group of optimizers are supported in ElixirOptimizer.
The reason is that ElixirOptimizer only support element-wise optimizers now.
We may enlarge the group of supported optimizers later.
Args:
optim: The torch optimizer instance.
module: The nn.Module instance wrapped as an ElixirModule.
"""
def __init__(self,
module: ElixirModule,
optimizer: torch.optim.Optimizer,
initial_scale: float = 32768,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**24,
max_norm: float = 0.0,
norm_type: float = 2.0,
init_step=False):
super().__init__(optimizer)
assert isinstance(module, ElixirModule)
self.scaled_optimizer = False
if type(optimizer) in _AVAIL_OPTIM_LIST:
self.scaled_optimizer = True
self.module = module
self.param_chunk_group = module.param_chunk_group
self.optim_chunk_group = module.optim_chunk_group
self.optim_state = OptimState.UNSCALED
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_optim_chunk: Dict[Parameter, Chunk] = dict()
self.param_chunk_set: Set[Chunk] = self.param_chunk_group.fused_chunks.union(
self.param_chunk_group.float_chunks)
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
if self.clipping_flag:
assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'
self.__init__optimizer()
# Grad scaler
self.grad_scaler: BaseGradScaler = None
if module.use_amp:
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
else:
self.grad_scaler = ConstantGradScaler(1.0, verbose=False)
self._comm_buffer: torch.Tensor = torch.zeros(1, dtype=torch.float, device=gpu_device())
self._logger = get_dist_logger()
if init_step:
# allocate memory before training
self.__zero_step()
def __zero_step(self):
torch.cuda.empty_cache()
cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
for _, zero_buffer in buffer_dict.items():
zero_buffer.zeros()
for group in self.param_groups:
for fake_param in group['params']:
optim_chunk = self.param_to_optim_chunk[fake_param]
begin, end = self.param_to_range[fake_param]
fake_param.data = buffer_dict.get(optim_chunk.shard_device.type).empty_1d(end - begin)
fake_param.grad = fake_param.data
fake_param.data = optim_chunk.shard[begin:end]
self.optim.step()
self.zero_grad()
self._update_fp16_params(update_flag=False)
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group['params']:
optim_chunk = self.param_to_optim_chunk[fake_param]
begin, end = self.param_to_range[fake_param]
param_chunk = optim_chunk.paired_chunk
fake_param.data = param_chunk.shard[begin:end]
fake_param.grad = fake_param.data
fake_param.data = optim_chunk.shard[begin:end]
def _update_fp16_params(self, update_flag: bool = True):
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group['params']:
assert fake_param.grad is None
fake_param.data = none_tensor.to(fake_param.device)
if update_flag:
for param_chunk in self.param_chunk_set:
param_chunk.optim_update()
def _check_overflow(self) -> bool:
# calculate the overflow counter
overflow_counter = 0
for param_chunk in self.param_chunk_set:
overflow_counter += int(param_chunk.overflow)
return overflow_counter > 0
def _clear_optim_states(self) -> None:
for param_chunk in self.param_chunk_set:
param_chunk.overflow = False
param_chunk.l2_norm = None
def _calc_global_norm(self) -> float:
group_to_norm = defaultdict(float)
for param_chunk in self.param_chunk_set:
assert param_chunk.l2_norm is not None
assert not param_chunk.is_replica
group_to_norm[param_chunk.torch_pg] += param_chunk.l2_norm
norm_sqr = 0.0
for group, part_norm in group_to_norm.items():
self._comm_buffer.fill_(part_norm)
dist.all_reduce(self._comm_buffer, group=group)
norm_sqr += self._comm_buffer.item()
global_norm = math.sqrt(norm_sqr)
return global_norm
def _get_combined_scale(self):
loss_scale = 1
assert self.optim_state == OptimState.SCALED
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * loss_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def zero_grad(self, *args, **kwargs):
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
self._set_grad_ptr()
found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info('Found overflow. Skipping this step.')
self._clear_optim_states() # clear chunk states used for optimizer update
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
self._clear_optim_states()
if not self.scaled_optimizer:
assert combined_scale == -1, 'You should use an optimizer in the available list:\n' \
f'{_AVAIL_OPTIM_LIST}'
ret = self.optim.step(*args, **kwargs)
else:
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self.zero_grad()
self._update_fp16_params()
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.module.backward(loss)
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
# This function is called in every pipeline-parallel stage except the last one.
# It receives the already-scaled grad from the previous rank, so there is no
# need to scale it again; it still has to be unscaled during the optimizer step.
self.optim_state = OptimState.SCALED
self.module.backward_by_grad(tensor, grad)
def __init__optimizer(self):
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param]
begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end
for group in self.param_groups:
fake_params_list = list()
for param in group['params']:
if not param.requires_grad:
continue
param_chunk = self.module.fetcher.get_one_chunk(param)
range_pair = get_range_pair(param_chunk, param)
if range_pair[0] >= range_pair[1]:
continue
grad_device = param_chunk.shard.device
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
self.param_to_range[fake_param] = range_pair
fake_params_list.append(fake_param)
group['params'] = fake_params_list
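
For reference, one training iteration with this wrapper pair might look like the sketch below; `model`, `optimizer`, and `train_loader` are assumed to be an ElixirModule, an ElixirOptimizer wrapping it, and an iterable of dict-style batches, as in the tests further below.

```python
# `model` is an ElixirModule, `optimizer` an ElixirOptimizer wrapping it, and
# `train_loader` a hypothetical iterable of dict-style batches
for batch in train_loader:
    optimizer.zero_grad()
    loss = model(**batch)
    # use the optimizer's backward so the loss is scaled before autograd runs
    optimizer.backward(loss)
    # step() checks for overflow, unscales (and optionally clips) gradients,
    # then refreshes the fp16 parameter chunks from the fp32 optimizer chunks
    optimizer.step()
```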

View File

@ -10,3 +10,5 @@ contexttimer
ninja
torch>=1.11
safetensors
sortedcontainers
einops

View File

@ -16,7 +16,7 @@ from op_builder.utils import (
try:
import torch
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
@ -30,7 +30,11 @@ BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1
IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1
# a variable to store the op builder
ext_modules = []
ext_modules = [
CppExtension(name='colossalai.elixir.simulator',
sources=['colossalai/elixir/simulator.cpp'],
extra_compile_args=['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'])
]
# we do not support windows currently
if sys.platform == 'win32':

View File

View File

@ -0,0 +1,89 @@
import torch
import torch.distributed as dist
import colossalai
from colossalai.elixir import ElixirModule, ElixirOptimizer
from colossalai.elixir.search import minimum_waste_search
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def check_elixir_compatibility(early_stop: bool = True):
"""check gemini plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
"""
passed_models = []
failed_info = {} # (model_name, error) pair
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'):
continue
try:
print(name)
global_size = dist.get_world_size()
global_group = dist.GroupMember.WORLD
model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
for k, v in data.items()
}
sr = minimum_waste_search(
# pre-commit: do not rearrange
m=model,
group_size=global_size,
unified_dtype=torch.float16,
prefetch=False,
verbose=True)
model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16)
optimizer = ElixirOptimizer(model, optimizer, initial_scale=32)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
optimizer.backward(loss)
optimizer.step()
passed_models.append(name)
del model, optimizer, criterion, data, output, loss
except Exception as e:
failed_info[name] = e
if early_stop:
raise e
torch.cuda.empty_cache()
if dist.get_rank() == 0:
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
def run_dist(rank, world_size, port, early_stop: bool = True):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_elixir_compatibility(early_stop=early_stop)
@rerun_if_address_is_in_use()
def exam_compatibility(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
if __name__ == '__main__':
exam_compatibility(early_stop=False)

View File

View File

@ -0,0 +1,72 @@
from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor
# round `a` up to the nearest multiple of `b`
def to_divide(a: int, b: int):
return a + (-a % b)
def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):
empty_grad = torch.empty_like(grad)
empty_grad.storage().resize_(0)
with torch._C.DisableTorchFunction():
chunk = fetcher.get_one_chunk(param)
if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError()
fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(param, grad)
fetcher.reduce_chunk(chunk)
return empty_grad
def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
pg_size = dist.get_world_size(process_group)
private_list = list()
for param in model.parameters():
block_size = to_divide(param.numel(), pg_size)
private_list.append(BlockRequire(block_size, param.dtype))
mp = MemoryPool('cuda')
mp.allocate(private_block_list=private_list)
cg = ChunkGroup(rcache=mp)
# allocate chunk group
fused_config = dict(rcache_fused=True)
for param in model.parameters():
cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config)
# initialize chunk fetcher
scheduler = FIFOScheduler()
fetcher = ChunkFetcher(scheduler, cg)
buffer = BufferStore(0, torch.float32)
# register fetcher and gradient handler
HookParam.attach_fetcher(fetcher, buffer)
for param in model.parameters():
param.register_hook(partial(grad_handler, param=param, fetcher=fetcher))
param.__class__ = HookParam
# set inplace to False for all modules
for module in model.modules():
if hasattr(module, 'inplace'):
module.inplace = False
def transform_input(self_module, inputs):
fetcher.reset()
input_list = list()
for t in inputs:
if isinstance(t, torch.Tensor):
t = OutplaceTensor(t)
input_list.append(t)
return tuple(input_list)
model.register_forward_pre_hook(transform_input)
return model, cg
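
A sketch of how this test helper is driven (see the fetcher test below for the full setup); `model` and `data` are assumed to be a CUDA module and a dict-style batch, and the process group must already be initialized.

```python
import torch.distributed as dist

# assumes a process group is already initialized and that `model` and `data`
# are a CUDA nn.Module and a dict-style batch, as in the fetcher test below
hooked_model, chunk_group = hook_transform(model, dist.GroupMember.WORLD)

loss = hooked_model(**data)    # forward fetches each chunk on demand
loss.backward()                # grad_handler reduces every chunk once it is ready

# gradients now live in the chunk payloads; access a chunk before reading them
for chunk in chunk_group.fused_chunks:
    chunk_group.access_chunk(chunk)
```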

View File

@ -0,0 +1,63 @@
import torch
from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag
@run_on_environment_flag('ELX')
def test_block():
b = PublicBlock(123, torch.float16, 'cuda')
payload_b = b.payload
assert payload_b.numel() == 123
assert payload_b.dtype == torch.float16
assert payload_b.device.type == 'cuda'
assert payload_b.numel() * payload_b.element_size() == b.memo_occ
c = PrivateBlock(77, torch.float, 'cpu')
payload_c = c.payload
assert payload_c.numel() == 77
assert payload_c.dtype == torch.float
assert payload_c.device.type == 'cpu'
assert payload_c.numel() * payload_c.element_size() == c.memo_occ
print('test_block: ok')
@run_on_environment_flag('ELX')
def test_memory_pool():
mp = MemoryPool(device_type='cuda')
private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
mp.allocate(public_block_number=4, private_block_list=private_list)
block0 = mp.get_public_block()
assert block0 in mp.public_used_blocks
assert mp.public_used_cnt == 1
assert mp.public_free_cnt == 3
block1 = mp.get_public_block()
assert block1 in mp.public_used_blocks
assert mp.public_used_cnt == 2
assert mp.public_free_cnt == 2
mp.free_public_block(block0)
mp.free_public_block(block1)
assert block0 in mp.public_free_blocks
assert block1 in mp.public_free_blocks
assert mp.public_used_cnt == 0
assert mp.public_free_cnt == 4
block0 = mp.get_private_block(5, torch.float)
assert block0.numel == 5
assert block0.dtype == torch.float
print('test_memory_pool: ok')
if __name__ == '__main__':
test_block()
test_memory_pool()

View File

@ -0,0 +1,155 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
def exam_chunk_functions(nproc, group):
a = torch.randn(2, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 2, 128, device='cuda')
copy_b = b.clone()
c = torch.randn(128, device='cuda')
copy_c = c.clone()
d = torch.randn(4, 32, device='cuda')
copy_d = d.clone()
mp = MemoryPool('cuda')
mp.allocate(public_block_number=1)
chunk = Chunk(mp, 1024, torch.float, group)
chunk.l2_norm_flag = True
assert chunk.chunk_size == 1024
assert chunk.chunk_dtype == torch.float
assert chunk.shard_size == 1024 // nproc
def check_tensors():
assert torch.equal(a, copy_a)
assert torch.equal(b, copy_b)
assert torch.equal(c, copy_c)
assert torch.equal(d, copy_d)
chunk.append_tensor(a)
chunk.append_tensor(b)
chunk.append_tensor(c)
chunk.append_tensor(d)
check_tensors()
chunk.close_chunk()
assert chunk.is_replica is False
# check function: get_cpu_copy
cpu_copys = chunk.get_cpu_copy()
for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
assert t_cpu.device.type == 'cpu'
assert torch.equal(t_gpu.cpu(), t_cpu)
# check function: access_chunk
block = mp.get_public_block()
chunk.access_chunk(block)
assert chunk.is_replica
assert chunk.scatter_check
check_tensors()
# check function: release_chunk
chunk.optim_sync_flag = False
block = chunk.release_chunk()
assert block in mp.public_used_blocks
assert chunk.is_replica is False
assert chunk.optim_sync_flag is True
# check function: access_chunk after release_chunk
chunk.access_chunk(block)
check_tensors()
# check function: reduce_chunk
norm = block.payload.float().norm(2)**2
chunk.reduce_chunk()
assert chunk.is_replica is False
assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
test_norm = torch.Tensor([chunk.l2_norm]).cuda()
dist.all_reduce(test_norm)
assert torch.allclose(norm, test_norm)
torch.cuda.synchronize()
print('chunk functions are ok')
def exam_chunk_states(nproc, group):
a = torch.randn(2, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 2, 128, device='cuda')
copy_b = b.clone()
c = torch.randn(128, device='cuda')
copy_c = c.clone()
d = torch.randn(4, 32, device='cuda')
copy_d = d.clone()
private = [BlockRequire(1024, torch.float)]
mp = MemoryPool('cuda')
mp.allocate(private_block_list=private)
chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
assert chunk.chunk_size == 1024
assert chunk.chunk_dtype == torch.float
assert chunk.shard_size == 1024 // nproc
def check_tensors():
assert torch.equal(a, copy_a)
assert torch.equal(b, copy_b)
assert torch.equal(c, copy_c)
assert torch.equal(d, copy_d)
chunk.append_tensor(a)
chunk.append_tensor(b)
chunk.append_tensor(c)
chunk.append_tensor(d)
check_tensors()
chunk.close_chunk()
assert chunk.is_replica is False
chunk.access_chunk()
assert chunk.is_replica
check_tensors()
assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
chunk.tensor_trans_state(a, TensorState.COMPUTE)
assert chunk.tensor_state_cnter[TensorState.HOLD] == 3
assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
tensor_list = [a, b, c, d]
for t in tensor_list:
chunk.tensor_trans_state(t, TensorState.COMPUTE)
chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD)
chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE)
assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
assert chunk.reduce_check
torch.cuda.synchronize()
print('chunk states are ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_functions(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_functions(world_size=4)

View File

@ -0,0 +1,71 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.chunk import ChunkGroup
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
from tests.test_elixir.utils import TEST_MODELS, to_cuda
def check_gradient(ddp_model, my_model, cg: ChunkGroup):
for chunk in cg.fused_chunks:
cg.access_chunk(chunk)
for (name, p0), p1 in zip(ddp_model.named_parameters(), my_model.parameters()):
torch.cuda.synchronize()
print(f'checking parameter {name}')
assert_close(p0.grad.data, p1.data)
def exam_chunk_fetcher(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
torch_model = model_fn().cuda()
test_model = copy.deepcopy(torch_model)
rank = dist.get_rank(group)
# get different data
seed_all(1001 + rank)
data = to_cuda(data_fn())
seed_all(1001, cuda_deterministic=True)
ddp_model = DDP(torch_model)
ddp_loss = ddp_model(**data)
ddp_loss.backward()
hook_model, cg = hook_transform(test_model, group)
my_loss = hook_model(**data)
my_loss.backward()
assert_close(ddp_loss, my_loss)
check_gradient(ddp_model, hook_model, cg)
print('private chunk fetcher is ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_fetcher(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_fetcher(world_size=2)

View File

@ -0,0 +1,98 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
def exam_chunk_group_functions(nproc, group):
a = torch.randn(3, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 32, device='cuda')
copy_b = b.clone()
c = torch.randn(256, device='cuda')
copy_c = c.clone()
d = torch.randn(2, 2, 64, device='cuda')
copy_d = d.clone()
e = torch.randn(2, 33, device='cuda')
copy_e = e.clone()
mp = MemoryPool('cuda')
mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
cg = ChunkGroup(rcache=mp)
c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
c1 = cg.allocate_chunk([c], 256, torch.float, group)
c2 = cg.allocate_chunk([d], 256, torch.float, group)
fused_config = dict(rcache_fused=True)
c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config)
def check_chunk_0():
assert torch.equal(a, copy_a)
assert torch.equal(b, copy_b)
def check_chunk_1():
assert torch.equal(c, copy_c)
def check_chunk_2():
assert torch.equal(d, copy_d)
def check_chunk_3():
assert torch.equal(e, copy_e)
# check tensors_to_chunks
chunks = cg.tensors_to_chunks([e, a])
assert chunks[0] == c0
assert chunks[1] == c3
# check access_chunk for unfused chunks
cg.access_chunk(c0)
cg.access_chunk(c1)
check_chunk_0()
check_chunk_1()
assert not cg.rcache_enough_check(c2)
assert cg.rcache_enough_check(c3)
# check access_chunk for fused chunks
cg.access_chunk(c3)
check_chunk_3()
# check release_chunk for unfused chunks
cg.release_chunk(c1)
assert cg.rcache_enough_check(c2)
# check access_chunk
cg.access_chunk(c2)
check_chunk_2()
cg.tensor_trans_state(e, TensorState.COMPUTE)
cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD)
cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE)
cg.reduce_chunk(c3)
assert not c3.is_replica
torch.cuda.synchronize()
print('chunk group functions are ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_group(world_size=2)

View File

@ -0,0 +1,130 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import Chunk, MemoryPool
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
def exam_fifo(nproc, group):
mp = MemoryPool('cuda')
mp.allocate(public_block_number=1)
c0 = Chunk(mp, 1024, torch.float, group)
c1 = Chunk(mp, 1024, torch.float, group)
c2 = Chunk(mp, 1024, torch.float, group)
sdl = FIFOScheduler()
sdl.reset()
sdl.add(c0)
sdl.add(c1)
sdl.add(c2)
sdl.add(c0) # nothing happens here
assert sdl.top() == c0
sdl.remove(c0)
assert sdl.top() == c1, f'{sdl.top()}'
sdl.remove(c0)
assert sdl.top() == c1, f'{sdl.top()}'
sdl.add(c0)
assert sdl.top() == c1
sdl.remove(c1)
assert sdl.top() == c2
sdl.remove(c2)
assert sdl.top() == c0
def exam_prefetch(nproc, group):
mp = MemoryPool('cuda')
mp.allocate()
c0 = Chunk(mp, 1024, torch.float, group)
c1 = Chunk(mp, 1024, torch.float, group)
c2 = Chunk(mp, 1024, torch.float, group)
chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]]
sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
print(sdl.next_step_dict)
sdl.reset()
sdl.step()
sdl.add(c0)
assert sdl.top() == c0
sdl.step()
sdl.add(c1)
assert sdl.top() == c1
sdl.step()
sdl.add(c2)
assert sdl.top() == c2
sdl.remove(c0)
sdl.step()
sdl.add(c0)
assert sdl.top() == c2
sdl.remove(c0)
sdl.step()
sdl.add(c0)
assert sdl.top() == c0
sdl.remove(c0) # notice here
sdl.remove(c1)
sdl.step()
sdl.add(c1)
assert sdl.top() == c1
sdl.remove(c2)
sdl.step()
sdl.add(c2)
assert sdl.top() == c1
sdl.remove(c2)
sdl.step()
sdl.add(c2)
assert sdl.top() == c2
sdl.remove(c2) # notice here
sdl.add(c0) # notice here
sdl.remove(c1)
sdl.step()
sdl.add(c1)
assert sdl.top() == c1
sdl.remove(c1) # notice here
sdl.remove(c0)
sdl.step()
sdl.add(c0)
assert sdl.top() == c0
sdl.remove(c0)
sdl.clear()
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@run_on_environment_flag('ELX')
def test_chunk_scheduler(world_size=1):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_chunk_scheduler()

View File

@ -0,0 +1,22 @@
from colossalai.elixir.ctx import MetaContext
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
@run_on_environment_flag('ELX')
def test_meta_context():
builder, *_ = TEST_MODELS.get('resnet')
with MetaContext():
model = builder()
for name, param in model.named_parameters():
assert param.device.type == 'meta'
print(name, param)
for name, buffer in model.named_buffers():
assert buffer.device.type == 'meta'
print(name, buffer)
if __name__ == '__main__':
test_meta_context()

View File

@ -0,0 +1,56 @@
from copy import deepcopy
import torch
import torch.nn as nn
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import FakeTensor
def test_hook():
x = nn.Parameter(torch.randn(4, 4))
ori_numel = x.numel()
ori_size = x.size()
ori_stride = x.stride()
ori_offset = x.storage_offset()
fake_data = FakeTensor(x.data)
x.data = fake_data
x.__class__ = HookParam
assert x.numel() == ori_numel
assert x.size() == ori_size
assert x.stride() == ori_stride
assert x.storage_offset() == ori_offset
def test_store():
buffer = BufferStore(1024, torch.float16)
print(buffer)
x = torch.randn(4, 128, dtype=torch.float16, device='cuda')
original_ptr_x = x.data_ptr()
copy_x = deepcopy(x)
y = torch.randn(512, dtype=torch.float16, device='cuda')
original_ptr_y = y.data_ptr()
copy_y = deepcopy(y)
offset = 0
offset = buffer.insert(x, offset)
assert offset == x.numel()
assert torch.equal(x, copy_x)
offset = buffer.insert(y, offset)
assert offset == 1024
assert torch.equal(y, copy_y)
buffer.erase(x)
buffer.erase(y)
assert x.data_ptr() == original_ptr_x
assert y.data_ptr() == original_ptr_y
if __name__ == '__main__':
test_store()

View File

@ -0,0 +1,35 @@
from copy import deepcopy
import pytest
from torch.testing import assert_close
from tests.test_elixir.utils import TEST_MODELS, to_cuda
def exam_one_model(model_fn, data_fn):
from colossalai.elixir.kernels.attn_wrapper import wrap_attention
torch_model = model_fn().cuda()
test_model = deepcopy(torch_model)
test_model = wrap_attention(test_model)
data = to_cuda(data_fn())
torch_out = torch_model(**data)
torch_out.backward()
test_out = test_model(**data)
test_out.backward()
assert_close(torch_out, test_out)
for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()):
assert_close(p_torch.grad, p_test.grad)
@pytest.mark.skip(reason="Need to install xformers")
def test_gpt_atten_kernel():
exam_one_model(*TEST_MODELS.get('gpt2_micro'))
exam_one_model(*TEST_MODELS.get('opt_micro'))
if __name__ == '__main__':
test_gpt_atten_kernel()

View File

@ -0,0 +1,58 @@
import os
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed
from colossalai.elixir.wrapper import ElixirModule
def exam_fused_layernorm(nproc, group):
torch_model = nn.LayerNorm(2048)
fused_model = deepcopy(torch_model)
torch_model = torch_model.cuda()
sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True)
fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True)
data = torch.randn(2, 2048, device='cuda')
torch_loss = torch_model(data).sum()
torch_loss.backward()
fused_loss = fused_model(data).sum()
fused_model.backward(fused_loss)
assert_close(torch_loss, fused_loss)
grad_state = fused_model.state_dict(from_param=True)
for name, param in torch_model.named_parameters():
assert_close(param.grad.cpu(), grad_state[name])
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1])
@pytest.mark.skip(reason='need to install apex')
def test_fused_layernorm(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_fused_layernorm(world_size=1)

View File

@ -0,0 +1,38 @@
from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import minimum_waste_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
@run_on_environment_flag('ELX')
def test_mini_waste_search():
model_fn, data_fn = TEST_MODELS.get('gpt2_small')
model = model_fn()
data = data_fn()
sr = minimum_waste_search(model,
1,
unified_dtype=torch.float16,
cpu_offload=True,
prefetch=True,
verbose=True,
inp=data,
step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
for plan in chunk_plans:
assert plan.chunk_dtype == torch.float16
assert plan.kwargs.get('shard_device') == torch.device('cpu')
assert plan.kwargs.get('cpu_pin_memory') == True
if __name__ == '__main__':
test_mini_waste_search()

View File

@ -0,0 +1,30 @@
from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import optimal_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
@run_on_environment_flag('ELX')
def test_optimal_search():
model_fn, data_fn = TEST_MODELS.get('gpt2_small')
model = model_fn()
data = data_fn()
sr = optimal_search(model, 1, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
for plan in chunk_plans:
assert plan.chunk_dtype == torch.float16
assert plan.kwargs.get('shard_device') == gpu_device()
if __name__ == '__main__':
test_optimal_search()

View File

@ -0,0 +1,48 @@
from copy import deepcopy
import torch
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
def step_fn(model, inp):
model(**inp).backward()
@run_on_environment_flag('ELX')
def test_simple_search():
model_fn, data_fn = TEST_MODELS.get('small')
model = model_fn()
data = data_fn()
sr = simple_search(model,
1,
split_number=5,
shard_device=gpu_device(),
prefetch=True,
verbose=True,
inp=data,
step_fn=step_fn)
chunk_plans = deepcopy(sr.param_chunk_plans)
private_plan = chunk_plans.pop(0)
assert private_plan.name_list == ['embed.weight']
assert private_plan.chunk_size == 320
assert private_plan.kwargs.get('shard_device') == gpu_device()
assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias']
assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias']
assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias']
assert chunk_plans[3].name_list == ['norm2.weight']
assert chunk_plans[4].name_list == ['norm2.bias']
for plan in chunk_plans:
assert plan.chunk_size == 1088
assert plan.kwargs.get('shard_device') == gpu_device()
if __name__ == '__main__':
test_simple_search()

View File

@ -0,0 +1,13 @@
from colossalai.elixir.simulator import move_count
from colossalai.testing import run_on_environment_flag
@run_on_environment_flag('ELX')
def test_move_count():
steps = [[0], [1, 2], [3], [3], [1, 2], [0]]
size = 2
assert move_count(steps, size) == 12
if __name__ == '__main__':
test_move_count()

View File

@ -0,0 +1,23 @@
import pytest
import torch
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import to_cuda
@run_on_environment_flag('ELX')
def test_registry():
from tests.test_elixir.utils.registry import TEST_MODELS
for name, model_tuple in TEST_MODELS:
torch.cuda.synchronize()
print(f'model `{name}` is in testing')
model_fn, data_fn = model_tuple
model = model_fn().cuda()
data = to_cuda(data_fn())
loss = model(**data)
loss.backward()
if __name__ == '__main__':
test_registry()

View File

@ -0,0 +1,46 @@
import pytest
import torch
from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda
def one_step(model, inp):
loss = model(**inp)
loss.backward()
return loss
def try_one_model(model_fn, data_fn):
model = model_fn().cuda()
data = to_cuda(data_fn())
one_step(model, data) # generate gradients
pre_cuda_alc = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
one_step(model, data)
aft_cuda_alc = torch.cuda.max_memory_allocated()
torch_activation_occ = aft_cuda_alc - pre_cuda_alc
model.zero_grad(set_to_none=True)
print('normal', torch_activation_occ)
before = torch.cuda.memory_allocated()
profiling_dict = cuda_memory_profiling(model, data, one_step)
after = torch.cuda.memory_allocated()
print('profiling', profiling_dict)
assert before == after
assert torch_activation_occ == profiling_dict['activation_occ']
print('Check is ok.')
@run_on_environment_flag('ELX')
def test_cuda_profiler():
model_list = ['resnet', 'gpt2_micro']
for name in model_list:
model_fn, data_fn = TEST_MODELS.get(name)
try_one_model(model_fn, data_fn)
if __name__ == '__main__':
test_cuda_profiler()

View File

@ -0,0 +1,145 @@
import pytest
import torch
from colossalai.elixir.tracer.memory_tracer import MTensor
from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
from colossalai.testing import run_on_environment_flag
def op_mm(x, y):
u = torch.matmul(x, y)
return u.shape
def op_addmm(x, y, z):
u = torch.addmm(x, y, z)
return u.shape
def op_bmm(x, y):
u = torch.bmm(x, y)
return u.shape
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_mm(dtype, size0=(4, 256), size1=(256, 1024)):
torch.cuda.reset_peak_memory_stats()
assert get_cuda_allocated() == 0
x = torch.randn(size0, dtype=dtype, device='cuda')
y = torch.randn(size1, dtype=dtype, device='cuda')
torch_pre_alc = get_cuda_allocated()
torch_z_size = op_mm(x, y)
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
del x
del y
assert get_cuda_allocated() == 0
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
op1_pre_alc = get_cuda_allocated()
MTensor.reset_peak_memory()
op1_z_size = op_mm(x, y)
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op1_z_size
assert torch_pre_alc == op1_pre_alc
assert torch_temp_alc == op1_temp_alc
assert len(mm_cache.temp_memory) > 0
MTensor.reset_peak_memory()
op2_z_size = op_mm(x, y)
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op2_z_size
assert torch_temp_alc == op2_temp_alc
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_addmm(dtype, size0=(4, 16), size1=(16, 64)):
torch.cuda.reset_peak_memory_stats()
assert get_cuda_allocated() == 0
x = torch.randn(size0, dtype=dtype, device='cuda')
y = torch.randn(size1, dtype=dtype, device='cuda')
u = torch.randn(size1[-1], dtype=dtype, device='cuda')
torch_pre_alc = get_cuda_allocated()
torch_z_size = op_addmm(u, x, y)
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
del x
del y
del u
assert get_cuda_allocated() == 0
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda'))
op1_pre_alc = get_cuda_allocated()
MTensor.reset_peak_memory()
op1_z_size = op_addmm(u, x, y)
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op1_z_size
assert torch_pre_alc == op1_pre_alc
assert torch_temp_alc == op1_temp_alc
assert len(addmm_cache.temp_memory) > 0
MTensor.reset_peak_memory()
op2_z_size = op_addmm(u, x, y)
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op2_z_size
assert torch_temp_alc == op2_temp_alc
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)):
torch.cuda.reset_peak_memory_stats()
assert get_cuda_allocated() == 0
x = torch.randn(size0, dtype=dtype, device='cuda')
y = torch.randn(size1, dtype=dtype, device='cuda')
torch_pre_alc = get_cuda_allocated()
torch_z_size = op_bmm(x, y)
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
del x
del y
assert get_cuda_allocated() == 0
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
op1_pre_alc = get_cuda_allocated()
MTensor.reset_peak_memory()
op1_z_size = op_bmm(x, y)
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op1_z_size
assert torch_pre_alc == op1_pre_alc
assert torch_temp_alc == op1_temp_alc
assert len(bmm_cache.temp_memory) > 0
bmm_cache.print()
MTensor.reset_peak_memory()
op2_z_size = op_bmm(x, y)
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
assert torch_z_size == op2_z_size
assert torch_temp_alc == op2_temp_alc
if __name__ == '__main__':
test_addmm(dtype=torch.float)

View File

@ -0,0 +1,37 @@
from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
@run_on_environment_flag('ELX')
def test_tf_forward_backward():
model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
model = model_fn()
data = data_fn()
def forward_backward_fn(local_model, local_input):
local_model(**local_input).backward()
# model.gradient_checkpointing_enable()
tf_order = generate_tf_order(model, data, forward_backward_fn)
params_per_step = tf_order['params_per_step']
assert len(params_per_step) == 32
model.gradient_checkpointing_enable()
tf_order = generate_tf_order(model, data, forward_backward_fn)
params_per_step = tf_order['params_per_step']
checkpoint_info = tf_order['checkpoint_info']
for i, step in enumerate(params_per_step):
print(f'step {i}: {step}')
for c in checkpoint_info:
print(f'checkpoint info: {c}')
assert len(params_per_step) == 44
assert data['input_ids'].device.type == 'cpu'
assert data['attention_mask'].device.type == 'cpu'
for param in model.parameters():
assert param.device.type == 'cpu'
if __name__ == '__main__':
test_tf_forward_backward()

View File

@ -0,0 +1,94 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda
def amp_check_model_states(ddp_optim, test_model):
test_states = test_model.state_dict()
for (name, _), p in zip(test_model.module.named_parameters(), amp.master_params(ddp_optim)):
test_p = test_states[name]
copy_p = p.to(test_p.device)
print(f'checking parameter `{name}`: {test_p.dtype} {copy_p.dtype}')
assert_close(test_p.data, copy_p.data)
def exam_amp_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
# important here, since apex has a lazy fp32 init after the first optimizer step
test_model = test_model.half()
ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
ddp_model, ddp_optim = amp.initialize(ddp_model,
ddp_optim,
opt_level='O2',
loss_scale=1.0,
keep_batchnorm_fp32=False)
ddp_model = DDP(ddp_model, message_size=0, allreduce_always_fp32=True)
print("ok")
exit(0)
test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
sr = simple_search(test_model, nproc, shard_device=gpu_device(), unified_dtype=torch.float16, verbose=True)
test_model = ElixirModule(test_model, sr, group, dtype=torch.float16, reduce_always_fp32=True, output_fp32=True)
test_optim = ElixirOptimizer(test_model, test_optim, initial_scale=1.0)
# get different data
seed_all(exam_seed + dist.get_rank(group), cuda_deterministic=True)
for _ in range(2):
data = to_cuda(data_fn())
ddp_optim.zero_grad()
ddp_loss = ddp_model(**data)
with amp.scale_loss(ddp_loss, ddp_optim) as scaled_loss:
scaled_loss.backward()
ddp_optim.step()
test_optim.zero_grad()
test_loss = test_model(**data)
test_optim.backward(test_loss)
test_optim.step()
assert_close(ddp_loss, test_loss)
amp_check_model_states(ddp_optim, test_model)
def exam_amp_in_models(nproc, group):
model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
exam_amp_one_model(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_amp_in_models(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_amp(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_elixir_amp(world_size=2)

View File

@ -0,0 +1,95 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, assert_dict_values, to_cuda
def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
grad_state = test_model.state_dict(from_param=True)
for name, param in ddp_model.named_parameters():
assert_close(param.grad.cpu(), grad_state[name])
def exam_module_init(nproc, group, grad_flag):
model_fn, data_fn = TEST_MODELS.get('resnet')
torch_model = model_fn().cuda()
test_model = model_fn().cuda()
for p1, p2 in zip(torch_model.parameters(), test_model.parameters()):
p1.requires_grad = p2.requires_grad = grad_flag
sr = simple_search(test_model, nproc)
model = ElixirModule(test_model, sr, group)
# check function: ElixirModule.load_state_dict after ElixirModule.__init__
torch_st = torch_model.state_dict()
if dist.get_rank() != 0:
torch_st = None
    model.load_state_dict(torch_st, only_rank_0=True)
# check function: ElixirModule.state_dict after ElixirModule.__init__
torch_st = torch_model.state_dict()
test_st = model.state_dict()
assert_dict_values(torch_st, test_st, fn=torch.equal)
def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2261):
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
sr = simple_search(test_model, nproc, allocate_factor=0.6)
test_model = ElixirModule(test_model, sr, group)
    # seed differently on each rank so every process gets different data
seed_all(exam_seed + dist.get_rank(group))
data = data_fn()
data = to_cuda(data)
seed_all(exam_seed, cuda_deterministic=True)
ddp_model = DDP(ddp_model)
ddp_loss = ddp_model(**data)
ddp_loss.backward()
test_loss = test_model(**data)
test_model.backward(test_loss)
assert_close(ddp_loss, test_loss)
check_gradient(ddp_model.module, test_model)
def exam_modules_fwd_bwd(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=False)
exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=True)
exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_module(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_elixir_module(world_size=2)

View File

@ -0,0 +1,77 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, allclose, assert_dict_values, to_cuda
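# Run one optimizer step with ElixirOptimizer and with HybridAdam on a DDP baseline,
# then check that the losses and updated parameter states still match.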
def exam_optimizer_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
ddp_model = DDP(ddp_model)
ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
sr = simple_search(test_model, nproc, shard_device=gpu_device())
test_model = ElixirModule(test_model, sr, group)
test_optim = ElixirOptimizer(test_model, test_optim)
    # seed differently on each rank so every process gets different data
seed_all(exam_seed + dist.get_rank(group))
data = to_cuda(data_fn())
seed_all(exam_seed, cuda_deterministic=True)
ddp_optim.zero_grad()
ddp_loss = ddp_model(**data)
ddp_loss.backward()
ddp_optim.step()
test_optim.zero_grad()
test_loss = test_model(**data)
test_optim.backward(test_loss)
test_optim.step()
assert_close(ddp_loss, test_loss)
torch_st = ddp_model.module.state_dict()
test_st = test_model.state_dict()
assert_dict_values(torch_st, test_st, fn=partial(allclose, rtol=2e-6, atol=2e-5))
def exam_optimizer_in_models(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
exam_optimizer_one_model(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_optimizer_in_models(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_optimizer(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_elixir_optimizer(world_size=4)

View File

@ -0,0 +1,89 @@
import copy
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda
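# Same forward/backward comparison against DDP, but with chunk prefetching enabled
# in both the search phase and ElixirModule.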
def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
grad_state = test_model.state_dict(from_param=True)
for name, param in ddp_model.named_parameters():
assert_close(param.grad.cpu(), grad_state[name])
def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2263):
def one_step(local_model, local_input):
loss = local_model(**local_input)
loss.backward()
return loss
ddp_model = model_fn().cuda()
test_model = copy.deepcopy(ddp_model)
    # seed differently on each rank so every process gets different data
seed_all(exam_seed + dist.get_rank(group))
data = to_cuda(data_fn())
# wrap as DDP model
ddp_model = DDP(ddp_model)
# search how to initialize chunks
sr = simple_search(test_model,
nproc,
shard_device=gpu_device(),
prefetch=True,
verbose=True,
inp=data,
step_fn=one_step)
test_model = ElixirModule(test_model, sr, group, prefetch=True)
seed_all(exam_seed, cuda_deterministic=True)
ddp_loss = one_step(ddp_model, data)
with torch.no_grad():
test_loss = test_model(**data)
assert_close(ddp_loss, test_loss)
test_loss = test_model(**data)
test_model.backward(test_loss)
assert_close(ddp_loss, test_loss)
check_gradient(ddp_model.module, test_model)
def exam_modules_fwd_bwd(nproc, group):
model_fn, data_fn = TEST_MODELS.get('resnet')
exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_module_prefetch(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_module_prefetch(world_size=2)

View File

@ -0,0 +1,41 @@
import torch
from torch.testing import assert_close
from torch.utils._pytree import tree_map
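# importing these modules registers the test models in TEST_MODELS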
from . import gpt, mlp, opt, resnet, small
from .registry import TEST_MODELS
def to_cuda(input_dict):
def local_fn(t):
if isinstance(t, torch.Tensor):
t = t.cuda()
return t
ret = tree_map(local_fn, input_dict)
return ret
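# assert_close raises on mismatch, so returning True lets this be used as the
# comparison fn passed to assert_dict_values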
def allclose(ta, tb, **kwargs):
assert_close(ta, tb, **kwargs)
return True
def assert_dict_keys(test_dict, keys):
assert len(test_dict) == len(keys)
for k in keys:
assert k in test_dict
def assert_dict_values(da, db, fn):
assert len(da) == len(db)
for k, v in da.items():
assert k in db
if not torch.is_tensor(v):
continue
u = db.get(k)
if u.device != v.device:
v = v.to(u.device)
# print(f"checking key {k}: {u.shape} vs {v.shape}")
assert fn(u.data, v.data), f'max diff {torch.max(torch.abs(u.data - v.data))}'

View File

@ -0,0 +1,79 @@
from functools import partial
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from tests.test_elixir.utils.registry import TEST_MODELS
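# VS: vocab size, BS: batch size, SL: sequence length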
MICRO_VS = 128
MICRO_BS = 4
MICRO_SL = 64
MACRO_VS = 50257
MACRO_BS = 2
MACRO_SL = 1024
def micro_data_fn():
input_ids = torch.randint(low=0, high=MICRO_VS, size=(MICRO_BS, MICRO_SL))
attn_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attn_mask)
def small_data_fn():
input_ids = torch.randint(low=0, high=MACRO_VS, size=(MACRO_BS, MACRO_SL))
attn_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attn_mask)
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
class GPTLMModel(nn.Module):
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
super().__init__()
self.enable_gc = False
self.config = GPT2Config(
# pre-commit: do not rearrange
n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0)
self.module = GPT2LMHeadModel(config=self.config)
self.criterion = GPTLMLoss()
def gradient_checkpointing_enable(self):
self.module.gradient_checkpointing_enable()
self.enable_gc = True
def forward(self, input_ids, attention_mask):
# Only return lm_logits
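        # use_cache must be disabled when gradient checkpointing is enabled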
output = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=(not self.enable_gc))[0]
loss = self.criterion(output, input_ids)
return loss
gpt2_micro = partial(GPTLMModel, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
gpt2_small = GPTLMModel
gpt2_base = partial(GPTLMModel, hidden_size=1024, num_layers=24, num_attention_heads=16)
TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn)
TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn)
TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn)

View File

@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from tests.test_elixir.utils.registry import TEST_MODELS
def mlp_data_fn():
return dict(x=torch.randn(4, 16))
class MlpModule(nn.Module):
def __init__(self, hidden_dim: int = 16) -> None:
super().__init__()
self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim)
self.act = nn.GELU()
self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim)
def forward(self, x):
return x + (self.proj2(self.act(self.proj1(x))))
class MlpModel(nn.Module):
def __init__(self, hidden_dim: int = 16) -> None:
super().__init__()
self.mlp = MlpModule(hidden_dim)
def forward(self, x):
output = self.mlp(x)
return output.sum()
TEST_MODELS.register('mlp', MlpModel, mlp_data_fn)

View File

@ -0,0 +1,46 @@
import torch.nn as nn
from transformers import OPTConfig, OPTForCausalLM
from tests.test_elixir.utils.registry import TEST_MODELS
from .gpt import micro_data_fn
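# Wrap OPTForCausalLM as a causal LM that returns its loss; the micro config keeps the test small.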
class OPTLMModel(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.config = config
self.module = OPTForCausalLM(config=config)
self.enable_gc = False
def gradient_checkpointing_enable(self):
self.module.gradient_checkpointing_enable()
self.enable_gc = True
def forward(self, input_ids, attention_mask):
loss = self.module(
# pre-commit: do not rearrange
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
use_cache=(not self.enable_gc))['loss']
return loss
def opt_micro():
opt_config = OPTConfig(
# pre-commit: do not rearrange
vocab_size=128,
activation_dropout=0.0,
dropout=0,
hidden_size=32,
num_hidden_layers=4,
ffn_dim=128,
num_attention_heads=4,
word_embed_proj_dim=32,
output_projection=True)
return OPTLMModel(opt_config)
TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn)

View File

@ -0,0 +1,26 @@
from collections import OrderedDict
from typing import Callable
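# Maps a registered name to its (model_fn, data_fn) pair for the tests.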
class Registry(object):
def __init__(self) -> None:
super().__init__()
self._registry_dict = OrderedDict()
def register(self, name: str, model_fn: Callable, data_fn: Callable):
assert name not in self._registry_dict
model_tuple = (model_fn, data_fn)
self._registry_dict[name] = model_tuple
def get(self, name: str):
return self._registry_dict[name]
def __iter__(self):
return iter(self._registry_dict.items())
TEST_MODELS = Registry()
__all__ = ['TEST_MODELS']

View File

@ -0,0 +1,23 @@
import torch
import torch.nn as nn
from torchvision.models import resnet18
from tests.test_elixir.utils.registry import TEST_MODELS
def resnet_data_fn():
return dict(x=torch.randn(4, 3, 32, 32))
class ResNetModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.r = resnet18()
def forward(self, x):
output = self.r(x)
return output.sum()
TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn)

View File

@ -0,0 +1,31 @@
import torch
import torch.nn as nn
from tests.test_elixir.utils.mlp import MlpModule
from tests.test_elixir.utils.registry import TEST_MODELS
def small_data_fn():
return dict(x=torch.randint(low=0, high=20, size=(4, 8)))
class SmallModel(nn.Module):
def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None:
super().__init__()
self.embed = nn.Embedding(num_embeddings, hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.mlp = MlpModule(hidden_dim=hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)
self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False)
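        # tie the output projection to the embedding weights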
self.proj.weight = self.embed.weight
def forward(self, x):
x = self.embed(x)
x = x + self.norm1(self.mlp(x))
x = self.proj(self.norm2(x))
x = x.mean(dim=-2)
return x.sum()
TEST_MODELS.register('small', SmallModel, small_data_fn)