mirror of https://github.com/hpcaitech/ColossalAI
[elixir] add elixir and its unit tests (#3835)
* [elixir] add elixir * [elixir] add unit tests * remove useless code * fix python 3.8 issue * fix typo * add test skip * add docstrings * add docstrings * add readme * fix typo
parent
34966378e8
commit
206280408a
|
@ -0,0 +1,71 @@
# Elixir (Gemini2.0)

Elixir, also known as Gemini, is a technology designed to facilitate the training of large models on a small GPU cluster.
Its goal is to eliminate redundant copies of model data and to leverage CPU memory to accommodate very large models.
In addition, Elixir automatically profiles each training step before execution and selects the optimal redundancy ratio and placement device for each parameter.
This repository is used to benchmark the performance of Elixir.
Elixir will be integrated into ColossalAI to improve usability.

## Environment

This version is a beta release, so the running environment is somewhat restrictive.
We only describe our own running environment here; compatibility with other setups has not been tested yet.
We use CUDA `11.6` and PyTorch `1.13.1` built for CUDA 11.6 (`1.13.1+cu116`).

## Examples

Here is a simple example that wraps your model and optimizer for [fine-tuning](https://github.com/hpcaitech/Elixir/tree/main/example/fine-tune).

```python
import torch
from transformers import BertForSequenceClassification

from elixir.search import minimum_waste_search
from elixir.wrapper import ElixirModule, ElixirOptimizer

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-8)

sr = minimum_waste_search(model, world_size)
model = ElixirModule(model, sr, world_group)
optimizer = ElixirOptimizer(model, optimizer)
```
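
In the snippet above, `world_size` and `world_group` are assumed to come from an already-initialized `torch.distributed` process group, as in the advanced example below. The sketch here shows one way to obtain them and run a single step; it assumes the wrapped module and optimizer keep the standard PyTorch training interface, so check the linked fine-tuning example for the exact calls, and note that `batch` is a hypothetical dictionary of BERT inputs.

```python
import torch.distributed as dist

# assumes torch.distributed has already been initialized (e.g. via torchrun)
world_group = dist.GroupMember.WORLD
world_size = dist.get_world_size()

# one hedged training step with the wrapped model and optimizer
outputs = model(**batch)    # `batch`: dict of BERT inputs, including labels
loss = outputs.loss
loss.backward()             # if ElixirOptimizer provides its own backward helper, use that instead
optimizer.step()
optimizer.zero_grad()
```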

Here is an advanced example for performance, which is used in our [benchmark](https://github.com/hpcaitech/Elixir/blob/main/example/common/elx.py).

```python
import torch
import torch.distributed as dist
from colossalai.nn.optimizer import HybridAdam

# optimal_search is assumed to be exported from elixir.search, alongside minimum_waste_search
from elixir.search import optimal_search
from elixir.wrapper import ElixirModule, ElixirOptimizer

# get the world communication group
global_group = dist.GroupMember.WORLD
# get the communication world size
global_size = dist.get_world_size()

# initialize the model on the CPU
# `get_model`, `model_name`, `data` and `train_step` are defined in the benchmark script linked above
model = get_model(model_name)
# HybridAdam allows some parameters to be updated on the CPU and others on the GPU
optimizer = HybridAdam(model.parameters(), lr=1e-3)

sr = optimal_search(
    model,
    global_size,
    unified_dtype=torch.float16,    # enable for FP16 training
    overlap=True,                   # enable for overlapping communications
    verbose=True,                   # print detailed processing information
    inp=data,                       # provide an example input in dictionary format
    step_fn=train_step              # provide an example step function
)
model = ElixirModule(
    model,
    sr,
    global_group,
    prefetch=True,                  # prefetch chunks to overlap communications
    dtype=torch.float16,            # use AMP
    use_fused_kernels=True          # enable fused kernels in Apex
)
optimizer = ElixirOptimizer(
    model,
    optimizer,
    initial_scale=64,               # loss scale used in AMP
    init_step=True                  # enable for training stability
)
```
@ -0,0 +1 @@
from .wrapper import ElixirModule, ElixirOptimizer
@ -0,0 +1,2 @@
from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
from .fetcher import ChunkFetcher
@ -0,0 +1,4 @@
from .chunk import Chunk
from .group import ChunkGroup
from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState
@ -0,0 +1,567 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.tensor import FakeTensor
|
||||
|
||||
from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
|
||||
from .states import TensorState, ts_update_sanity_check
|
||||
|
||||
|
||||
class ChunkFullError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
state: TensorState
|
||||
fake_data: FakeTensor
|
||||
offset: int
|
||||
end: int
|
||||
|
||||
|
||||
class Chunk:
|
||||
"""Chunk is a type of data structure to store tensors.
|
||||
It allows us to store a sequence of tensors into one continuous memory block.
|
||||
Moreover, Chunk manages the storage of tensors in a distributed way.
|
||||
Normally, a chunk is scattered across its process group.
|
||||
When a tensor in this chunk should be used later, the chunk can be gathered by access_chunk.
|
||||
When the training is done, the chunk can be scattered by reduce_chunk.
|
||||
|
||||
args:
|
||||
rcache: the memory pool to store replicated chunks
|
||||
chunk_size: the size of the chunk
|
||||
chunk_dtype: the dtype of the chunk
|
||||
process_group: the torch communication group of the chunk
|
||||
temp_device: the device to store the temporary chunk when initializing
|
||||
shard_device: the device to store the shard of the scattered chunk
|
||||
rcache_fused: whether this chunk is fused in rcache without eviction
|
||||
cpu_pin_memory: whether this chunk uses CPU pinned memory for its shard
|
||||
"""
|
||||
total_count = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rcache: MemoryPool,
|
||||
chunk_size: int,
|
||||
chunk_dtype: torch.dtype,
|
||||
process_group: ProcessGroup,
|
||||
temp_device: Optional[torch.device] = None,
|
||||
shard_device: Optional[torch.device] = None,
|
||||
rcache_fused: bool = False, # whether this chunk is used in ZeRO2
|
||||
cpu_pin_memory: bool = False # whether this chunk has a permanent copy in cpu
|
||||
) -> None:
|
||||
|
||||
self.chunk_id: int = Chunk.total_count
|
||||
Chunk.total_count += 1
|
||||
# set replicated cache pool
|
||||
self.rcache: MemoryPool = rcache
|
||||
|
||||
self.chunk_size: int = chunk_size
|
||||
self.chunk_dtype: torch.dtype = chunk_dtype
|
||||
self.utilized_size: int = 0
|
||||
|
||||
self.torch_pg: ProcessGroup = process_group
|
||||
self.pg_size: int = dist.get_world_size(self.torch_pg)
|
||||
self.pg_rank: int = dist.get_rank(self.torch_pg)
|
||||
|
||||
# the chunk size should be divisible by the dp degree
|
||||
assert chunk_size % self.pg_size == 0
|
||||
|
||||
self.shard_size: int = chunk_size // self.pg_size
|
||||
self.shard_begin: int = self.shard_size * self.pg_rank
|
||||
self.shard_end: int = self.shard_begin + self.shard_size
|
||||
self.valid_end: int = self.shard_size + 1 # set to an illegal number
|
||||
|
||||
# notice: release blocks reserved by PyTorch
|
||||
torch.cuda.empty_cache()
|
||||
# rcache block, the global replicated chunk in R cache
|
||||
self.rcb: Optional[TensorBlock] = None
|
||||
self.rcache_fused: bool = rcache_fused
|
||||
self._my_block = None
|
||||
self.is_replica: bool = True
|
||||
# allocate a private block for fused chunks
|
||||
if self.rcache_fused:
|
||||
self._my_block = rcache.get_private_block(chunk_size, chunk_dtype)
|
||||
|
||||
temp_device: torch.device = temp_device or gpu_device()
|
||||
# chunk_temp is a global chunk, which only exists during building the chunks.
|
||||
# keep all elements to zero
|
||||
self.chunk_temp: Optional[torch.Tensor] = None
|
||||
if rcache_fused:
|
||||
self.chunk_temp = self._my_block.payload
|
||||
torch.zero_(self.chunk_temp)
|
||||
else:
|
||||
self.chunk_temp = torch.zeros(chunk_size, dtype=chunk_dtype, device=temp_device)
|
||||
|
||||
# configure the init device of the shard
|
||||
# no-offload default: fp16, fp32 -> CUDA
|
||||
# offload default: fp16, fp32 -> CPU
|
||||
shard_device: torch.device = shard_device or torch.device('cpu')
|
||||
pin_flag: bool = cpu_pin_memory and shard_device.type == 'cpu'
|
||||
# chunk.shard is a local chunk
|
||||
# it is designed to exist permanently
|
||||
self.shard: torch.Tensor = torch.empty(self.shard_size,
|
||||
dtype=chunk_dtype,
|
||||
device=shard_device,
|
||||
pin_memory=pin_flag)
|
||||
|
||||
# calculate the memory occupation of the chunk and the shard
|
||||
self.chunk_memo: int = self.chunk_size * self.chunk_temp.element_size()
|
||||
self.shard_memo: int = self.chunk_memo // self.pg_size
|
||||
|
||||
# each tensor is associated with a TensorInfo to track its meta info
|
||||
# (state, shape, offset, end)
|
||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||
# the total number of tensors in the chunk
|
||||
self.num_tensors: int = 0
|
||||
|
||||
# Record the number of tensors in different states
|
||||
self.tensor_state_cnter: Dict[TensorState, int] = dict()
|
||||
for state in TensorState:
|
||||
self.tensor_state_cnter[state] = 0
|
||||
|
||||
# we introduce the paired chunk here
|
||||
# it refers to another chunk having the same parameters
|
||||
# but with a different dtype (e.g. fp16_chunk.paired_chunk -> fp32_chunk)
|
||||
self.paired_chunk = None
|
||||
# if this chunk is synchronized with the optimizer, the flag is True
|
||||
self.optim_sync_flag = True
|
||||
|
||||
# whether to record l2 norm for the gradient clipping calculation
|
||||
self.l2_norm_flag = False
|
||||
self.l2_norm = None
|
||||
# whether it overflows after the reduction
|
||||
self.overflow = False
|
||||
|
||||
@property
|
||||
def prepared_block(self):
|
||||
return self._my_block
|
||||
|
||||
@property
|
||||
def is_init(self):
|
||||
return self.chunk_temp is not None
|
||||
|
||||
@property
|
||||
def in_rcache(self):
|
||||
return self.rcb is not None
|
||||
|
||||
@property
|
||||
def shard_device(self):
|
||||
return self.shard.device
|
||||
|
||||
@property
|
||||
def memory_usage(self) -> Dict[str, int]:
|
||||
cuda_memory = 0
|
||||
cpu_memory = 0
|
||||
|
||||
# this chunk is not closed
|
||||
if self.is_init:
|
||||
if self.chunk_temp.device.type == 'cuda':
|
||||
cuda_memory += self.chunk_memo
|
||||
else:
|
||||
cpu_memory += self.chunk_memo
|
||||
|
||||
# this chunk is on the rcache
|
||||
if self.in_rcache:
|
||||
cuda_memory += self.rcb.memo_occ
|
||||
|
||||
# calculate the occupation of the chunk shard
|
||||
if self.shard_device.type == 'cuda':
|
||||
cuda_memory += self.shard_memo
|
||||
elif self.shard_device.type == 'cpu':
|
||||
cpu_memory += self.shard_memo
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return dict(cuda=cuda_memory, cpu=cpu_memory)
|
||||
|
||||
@property
|
||||
def payload(self) -> torch.Tensor:
|
||||
if self.is_init:
|
||||
return self.chunk_temp
|
||||
|
||||
if self.in_rcache:
|
||||
return self.rcb.payload
|
||||
else:
|
||||
return self.shard
|
||||
|
||||
@property
|
||||
def shard_move_check(self) -> bool:
|
||||
return not self.in_rcache
|
||||
|
||||
def _not_compute_number(self):
|
||||
total = 0
|
||||
state_list = [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE]
|
||||
for state in state_list:
|
||||
total += self.tensor_state_cnter[state]
|
||||
return total
|
||||
|
||||
@property
|
||||
def scatter_check(self) -> bool:
|
||||
if self.rcache_fused:
|
||||
return False
|
||||
return self._not_compute_number() == self.num_tensors
|
||||
|
||||
@property
|
||||
def reduce_check(self):
|
||||
return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors
|
||||
|
||||
def set_overflow_flag(self, valid_tensor: torch.Tensor) -> None:
|
||||
assert not self.overflow
|
||||
self.overflow = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
|
||||
|
||||
def set_l2_norm(self, valid_tensor: torch.Tensor) -> None:
|
||||
assert self.l2_norm is None, 'you are calculating the l2 norm twice'
|
||||
chunk_l2_norm = valid_tensor.data.float().norm(2)
|
||||
self.l2_norm = chunk_l2_norm.item()**2
|
||||
|
||||
def append_tensor(self, tensor: torch.Tensor):
|
||||
# sanity check
|
||||
assert self.is_init
|
||||
assert tensor.dtype == self.chunk_dtype
|
||||
|
||||
new_utilized_size = self.utilized_size + tensor.numel()
|
||||
# raise exception when the chunk size is exceeded
|
||||
if new_utilized_size > self.chunk_size:
|
||||
raise ChunkFullError
|
||||
|
||||
self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
|
||||
tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||
fake_data = FakeTensor(tensor.data)
|
||||
|
||||
# record all the information about the tensor
|
||||
self.num_tensors += 1
|
||||
tensor_state = TensorState.HOLD
|
||||
self.tensor_state_cnter[tensor_state] += 1
|
||||
|
||||
self.tensors_info[tensor] = TensorInfo(state=tensor_state,
|
||||
fake_data=fake_data,
|
||||
offset=self.utilized_size,
|
||||
end=new_utilized_size)
|
||||
|
||||
self.utilized_size = new_utilized_size
|
||||
|
||||
def close_chunk(self):
|
||||
# sanity check
|
||||
assert self.is_init
|
||||
|
||||
# calculate the valid end for each shard
|
||||
if self.utilized_size <= self.shard_begin:
|
||||
self.valid_end = 0
|
||||
elif self.utilized_size < self.shard_end:
|
||||
self.valid_end = self.utilized_size - self.shard_begin
|
||||
|
||||
self.__remove_tensors_ptr()
|
||||
self.__update_shard(self.chunk_temp, self.shard)
|
||||
self.is_replica = False
|
||||
self.chunk_temp = None
|
||||
|
||||
def replicate(self):
|
||||
assert not self.is_replica
|
||||
|
||||
self.is_replica = True
|
||||
this_shard = self.shard if self.optim_sync_flag else self.__paired_shard()
|
||||
self.__update_replica(self.rcb.payload, this_shard)
|
||||
self.__update_tensors_ptr()
|
||||
|
||||
def scatter(self):
|
||||
assert not self.rcache_fused
|
||||
assert self.is_replica
|
||||
|
||||
self.__remove_tensors_ptr()
|
||||
if not self.optim_sync_flag:
|
||||
self.__update_shard(self.rcb.payload, self.shard)
|
||||
self.optim_sync_flag = True
|
||||
self.is_replica = False
|
||||
|
||||
def reduce(self, always_fp32: bool = False):
|
||||
assert self.is_replica
|
||||
|
||||
self.__remove_tensors_ptr()
|
||||
|
||||
if self.pg_size > 1:
|
||||
cast_to_fp32 = False
|
||||
if always_fp32 and self.chunk_dtype != torch.float:
|
||||
cast_to_fp32 = True
|
||||
# cast the payload to fp32
|
||||
reduce_buffer = self.rcb.payload.to(dtype=torch.float)
|
||||
else:
|
||||
# otherwise, use the same payload
|
||||
reduce_buffer = self.rcb.payload
|
||||
|
||||
# divide the reduce buffer by the size of the process group
|
||||
reduce_buffer /= self.pg_size
|
||||
# try to use inplace reduce scatter
|
||||
# notice: PyTorch does not allow a true in-place reduce-scatter
# because it allocates a contiguous memory space for collective communications
|
||||
shard_buffer = reduce_buffer[self.shard_begin:self.shard_end]
|
||||
dist.reduce_scatter_tensor(shard_buffer, reduce_buffer, group=self.torch_pg)
|
||||
|
||||
# the result should be moved to payload for norm calculating
|
||||
if cast_to_fp32:
|
||||
calc_buffer = self.rcb.payload[self.shard_begin:self.shard_end]
|
||||
calc_buffer.copy_(shard_buffer)
|
||||
else:
|
||||
# if process group size equals to 1, do not communicate
|
||||
reduce_buffer = self.rcb.payload
|
||||
|
||||
self.__update_shard(reduce_buffer, self.shard)
|
||||
|
||||
self.is_replica = False
|
||||
|
||||
def access_chunk(self, block: Optional[TensorBlock] = None):
|
||||
# sanity check
|
||||
assert not self.is_init
|
||||
assert not self.is_replica
|
||||
|
||||
if self.rcache_fused:
|
||||
assert block is None
|
||||
self.rcb = self._my_block
|
||||
else:
|
||||
assert block in self.rcache.public_used_blocks
|
||||
assert self.rcb is None
|
||||
self.rcb = block
|
||||
|
||||
self.replicate()
|
||||
|
||||
def release_chunk(self) -> TensorBlock:
|
||||
# sanity check
|
||||
assert not self.is_init
|
||||
assert self.is_replica
|
||||
|
||||
if self.rcache_fused:
|
||||
raise RuntimeError
|
||||
|
||||
self.scatter()
|
||||
block = self.rcb
|
||||
self.rcb = None
|
||||
return block
|
||||
|
||||
def update_extra_reduce_info(self, block: Optional[TensorBlock]):
|
||||
if self.rcache_fused:
|
||||
assert block is None
|
||||
block = self._my_block
|
||||
else:
|
||||
assert block is not None
|
||||
|
||||
buffer = block.payload[self.shard_begin:self.shard_end]
|
||||
valid_tensor = buffer[:self.valid_end]
|
||||
self.set_overflow_flag(valid_tensor)
|
||||
if self.l2_norm_flag:
|
||||
self.set_l2_norm(valid_tensor)
|
||||
|
||||
def reduce_chunk(self, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
|
||||
"""Reduce scatter all the gradients. It's an operation done in CUDA.
|
||||
"""
|
||||
# sanity check
|
||||
assert not self.is_init
|
||||
assert self.is_replica
|
||||
|
||||
self.reduce(always_fp32=always_fp32)
|
||||
self.__update_tensors_state(TensorState.HOLD)
|
||||
|
||||
# reset the rcb pointer
|
||||
block = self.rcb
|
||||
self.rcb = None
|
||||
if self.rcache_fused:
|
||||
block = None
|
||||
|
||||
if sync:
|
||||
self.update_extra_reduce_info(block)
|
||||
|
||||
return block
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||
prev_state = self.tensors_info[tensor].state
|
||||
if prev_state == tensor_state:
|
||||
return
|
||||
if ts_update_sanity_check(prev_state, tensor_state):
|
||||
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
|
||||
|
||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||
# sanity check
|
||||
assert self.is_replica
|
||||
|
||||
info = self.tensors_info[tensor]
|
||||
payload = self.rcb.payload
|
||||
payload[info.offset:info.end].copy_(data_slice.data.flatten())
|
||||
tensor.data = payload[info.offset:info.end].view(tensor.shape)
|
||||
|
||||
def init_pair(self, friend_chunk: 'Chunk') -> None:
|
||||
if self.paired_chunk is None and friend_chunk.paired_chunk is None:
|
||||
self.paired_chunk = friend_chunk
|
||||
friend_chunk.paired_chunk = self
|
||||
else:
|
||||
assert self.paired_chunk is friend_chunk
|
||||
assert friend_chunk.paired_chunk is self
|
||||
|
||||
def optim_update(self) -> None:
|
||||
"""Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
|
||||
"""
|
||||
# sanity check
|
||||
assert self.paired_chunk is not None
|
||||
friend_chunk: Chunk = self.paired_chunk
|
||||
assert not friend_chunk.is_replica
|
||||
# gradient and optimizer should be on the same device
|
||||
assert self.shard_device.type == friend_chunk.shard_device.type
|
||||
|
||||
if self.shard_device.type == 'cuda':
|
||||
self.shard.copy_(friend_chunk.shard)
|
||||
self.optim_sync_flag = True
|
||||
elif self.shard_device.type == 'cpu':
|
||||
# optim_sync_flag is set to False
|
||||
# see shard_move function for more details
|
||||
self.optim_sync_flag = False
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
def get_cpu_copy(self, only_rank_0: bool = False) -> List[torch.Tensor]:
|
||||
assert not self.is_init
|
||||
|
||||
if self.is_replica:
|
||||
# use the payload directly when being replica
|
||||
temp_buffer = self.rcb.payload
|
||||
else:
|
||||
# otherwise, create a temporary buffer
|
||||
temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
|
||||
# cheat the assertion in __update_replica
|
||||
self.is_replica = True
|
||||
self.__update_replica(temp_buffer, self.shard)
|
||||
self.is_replica = False
|
||||
|
||||
cpu_copys = [None] * self.num_tensors
|
||||
if not only_rank_0 or self.pg_rank == 0:
|
||||
for i, (t, info) in enumerate(self.tensors_info.items()):
|
||||
t_copy = temp_buffer[info.offset:info.end].view(t.shape).cpu()
|
||||
cpu_copys[i] = t_copy
|
||||
# synchronize
|
||||
dist.barrier()
|
||||
return cpu_copys
|
||||
|
||||
def load_tensors(self, tensor_list: List[Optional[torch.Tensor]], only_rank_0: bool = False) -> bool:
|
||||
assert not self.is_replica
|
||||
assert not self.is_init
|
||||
temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
|
||||
# cheat the assertion in __update_replica
|
||||
self.is_replica = True
|
||||
self.__update_replica(temp_buffer, self.shard)
|
||||
self.is_replica = False
|
||||
|
||||
if not only_rank_0 or self.pg_rank == 0:
|
||||
for (_, c_info), load_tensor in zip(self.tensors_info.items(), tensor_list):
|
||||
if load_tensor is None:
|
||||
continue
|
||||
temp_buffer[c_info.offset:c_info.end].copy_(load_tensor.data.flatten())
|
||||
|
||||
# synchronize
|
||||
dist.barrier()
|
||||
|
||||
if only_rank_0:
|
||||
dist.broadcast(temp_buffer, src=0, group=self.torch_pg)
|
||||
|
||||
# cheat the assertion in __update_shard
|
||||
self.is_replica = True
|
||||
self.__update_shard(temp_buffer, self.shard)
|
||||
self.is_replica = False
|
||||
|
||||
def __update_replica(self, replica: torch.Tensor, shard: torch.Tensor):
|
||||
assert self.is_replica
|
||||
assert replica.numel() == self.chunk_size
|
||||
assert shard.numel() == self.shard_size
|
||||
|
||||
buffer = replica[self.shard_begin:self.shard_end]
|
||||
buffer.copy_(shard)
|
||||
dist.all_gather_into_tensor(replica, buffer, group=self.torch_pg)
|
||||
|
||||
def __update_shard(self, replica: torch.Tensor, shard: torch.Tensor):
|
||||
assert self.is_replica
|
||||
assert replica.numel() == self.chunk_size
|
||||
assert shard.numel() == self.shard_size
|
||||
|
||||
shard.copy_(replica[self.shard_begin:self.shard_end])
|
||||
|
||||
def __paired_shard(self):
|
||||
assert self.paired_chunk is not None, 'chunks should be paired before training'
|
||||
optim_chunk: Chunk = self.paired_chunk
|
||||
assert self.chunk_size == optim_chunk.chunk_size
|
||||
|
||||
# only be called when optimizer state is in CPU memory
|
||||
# the grad and param should be in the same device
|
||||
assert self.shard_device.type == 'cpu'
|
||||
return optim_chunk.shard.to(gpu_device())
|
||||
|
||||
def __remove_tensors_ptr(self) -> None:
|
||||
# sanity check
|
||||
# each tensor should point to its fake data before scatter
|
||||
assert self.is_replica
|
||||
for tensor, info in self.tensors_info.items():
|
||||
tensor.data = info.fake_data
|
||||
|
||||
def __update_tensors_ptr(self) -> None:
|
||||
# sanity check
|
||||
# the chunk should be replicated to get the correct pointer
|
||||
assert self.is_replica
|
||||
payload = self.rcb.payload
|
||||
for tensor, info in self.tensors_info.items():
|
||||
tensor.data = payload[info.offset:info.end].view(tensor.shape)
|
||||
|
||||
def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
|
||||
self.tensor_state_cnter[tensor_info.state] -= 1
|
||||
tensor_info.state = next_state
|
||||
self.tensor_state_cnter[tensor_info.state] += 1
|
||||
|
||||
def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||
for tensor_info in self.tensors_info.values():
|
||||
if prev_state is None or tensor_info.state == prev_state:
|
||||
self.__update_one_tensor_info(tensor_info, next_state)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.chunk_id
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
return self.chunk_id < other.chunk_id
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self.chunk_id == other.chunk_id
|
||||
|
||||
def __repr__(self, detailed: bool = True):
|
||||
if self.is_init:
|
||||
state = 'initialization'
|
||||
elif self.in_rcache:
|
||||
state = 'replicated'
|
||||
else:
|
||||
state = 'scattered'
|
||||
|
||||
output = [
|
||||
f'Chunk {self.chunk_id} details: state -> {state}\n',
|
||||
f' length: {self.chunk_size}, dtype: {self.chunk_dtype}, group_size: {self.pg_size}, tensors: {self.num_tensors}\n'
|
||||
f' utilized size: {self.utilized_size}, utilized percentage: {100 * (self.utilized_size / self.chunk_size):.0f}%\n'
|
||||
]
|
||||
|
||||
memory_info = self.memory_usage
|
||||
output.append(' memory usage: (cuda -> {}, cpu -> {})\n'.format(memory_info['cuda'], memory_info['cpu']))
|
||||
|
||||
def print_tensor(name, tensor, prefix=''):
|
||||
output.append(f'{prefix}{name}: (shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device})\n')
|
||||
|
||||
if self.is_init:
|
||||
print_tensor(name='temp', tensor=self.chunk_temp, prefix=' ')
|
||||
if self.in_rcache:
|
||||
print_tensor(name='block', tensor=self.rcb.payload, prefix=' ')
|
||||
if self.shard is not None:
|
||||
print_tensor(name='shard', tensor=self.shard, prefix=' ')
|
||||
|
||||
if detailed:
|
||||
output.append(' tensor state monitor:\n')
|
||||
for st in TensorState:
|
||||
output.append(' # of {}: {}\n'.format(st, self.tensor_state_cnter[st]))
|
||||
|
||||
return ''.join(output)
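
# Illustrative lifecycle sketch (comments only, not part of the original diff): a Chunk is
# filled while it is a temporary global buffer, closed into its local shard, and later
# gathered/reduced through an rCache block. `mp`, `params` and `pg` are hypothetical placeholders.
#
#   chunk = Chunk(rcache=mp, chunk_size=1024, chunk_dtype=torch.float16, process_group=pg)
#   for p in params:
#       chunk.append_tensor(p)                      # copy p into chunk_temp and repoint p.data
#   chunk.close_chunk()                             # keep only the local shard
#   chunk.access_chunk(mp.get_public_block())       # gather the replica into the rCache block
#   ...forward/backward...
#   chunk.reduce_chunk()                            # reduce-scatter gradients back to the shard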
|
|
@ -0,0 +1,180 @@
|
|||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .chunk import Chunk
|
||||
from .memory_pool import MemoryPool, TensorBlock
|
||||
from .states import TensorState
|
||||
|
||||
|
||||
class ChunkGroup(object):
|
||||
"""ChunkGroup manages a group of chunks and their memory pool.
|
||||
Commonly, one model has one chunk group.
|
||||
It supports chunk allocation, chunk access, and chunk release.
|
||||
Note that callers are responsible for rCache memory management (see `rcache_enough_check`) before invoking its APIs.
|
||||
|
||||
args:
|
||||
rcache: A memory pool to instantiate chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, rcache: MemoryPool) -> None:
|
||||
super().__init__()
|
||||
self.rcache = rcache
|
||||
self.fused_chunks: Set[Chunk] = set()
|
||||
self.float_chunks: Set[Chunk] = set()
|
||||
self.ten_to_chunk: Dict[torch.Tensor, Chunk] = dict()
|
||||
|
||||
self.accessed_fused_chunks: Set[Chunk] = set()
|
||||
self.accessed_float_chunks: Set[Chunk] = set()
|
||||
|
||||
def __add_to_accset(self, chunk: Chunk):
|
||||
if chunk.rcache_fused:
|
||||
self.accessed_fused_chunks.add(chunk)
|
||||
else:
|
||||
self.accessed_float_chunks.add(chunk)
|
||||
|
||||
def __remove_from_accset(self, chunk: Chunk):
|
||||
if chunk.rcache_fused:
|
||||
self.accessed_fused_chunks.remove(chunk)
|
||||
else:
|
||||
self.accessed_float_chunks.remove(chunk)
|
||||
|
||||
def __check_new_float_chunk(self, size: int, dtype: torch.dtype):
|
||||
# if the public space is 0, there is no access operations
|
||||
if self.rcache.public_space == 0:
|
||||
return
|
||||
# otherwise, check its size and dtype
|
||||
assert size == self.rcache.public_block_size
|
||||
assert dtype == self.rcache.public_dtype
|
||||
|
||||
def inside_check(self, chunk: Chunk) -> None:
|
||||
"""Check whether the chunk is in this ChunkGroup"""
|
||||
if chunk.rcache_fused:
|
||||
assert chunk in self.fused_chunks
|
||||
else:
|
||||
assert chunk in self.float_chunks
|
||||
|
||||
def is_accessed(self, chunk: Chunk) -> bool:
|
||||
"""Chech whether the chunk is accessed."""
|
||||
# sanity check
|
||||
self.inside_check(chunk)
|
||||
|
||||
if chunk.rcache_fused:
|
||||
return (chunk in self.accessed_fused_chunks)
|
||||
else:
|
||||
return (chunk in self.accessed_float_chunks)
|
||||
|
||||
def open_chunk(self,
|
||||
chunk_size: int,
|
||||
chunk_dtype: torch.dtype,
|
||||
process_group: ProcessGroup,
|
||||
chunk_config: Optional[Dict] = None) -> Chunk:
|
||||
"""Open a chunk to store parameters."""
|
||||
if chunk_config is None:
|
||||
chunk_config = {}
|
||||
|
||||
chunk = Chunk(rcache=self.rcache,
|
||||
chunk_size=chunk_size,
|
||||
chunk_dtype=chunk_dtype,
|
||||
process_group=process_group,
|
||||
**chunk_config)
|
||||
# sanity check
|
||||
if not chunk.rcache_fused:
|
||||
self.__check_new_float_chunk(chunk_size, chunk_dtype)
|
||||
|
||||
return chunk
|
||||
|
||||
def close_chunk(self, chunk: Chunk) -> bool:
|
||||
"""Close the chunk during the allocation."""
|
||||
chunk.close_chunk()
|
||||
# add the new chunk to the set of allocated chunks
|
||||
if chunk.rcache_fused:
|
||||
self.fused_chunks.add(chunk)
|
||||
else:
|
||||
self.float_chunks.add(chunk)
|
||||
# add the new chunk to the mapping
|
||||
for t in chunk.get_tensors():
|
||||
assert t not in self.ten_to_chunk
|
||||
self.ten_to_chunk[t] = chunk
|
||||
return True
|
||||
|
||||
def allocate_chunk(self,
|
||||
tensor_list: List[torch.Tensor],
|
||||
chunk_size: int,
|
||||
chunk_dtype: torch.dtype,
|
||||
process_group: ProcessGroup,
|
||||
chunk_config: Optional[Dict] = None) -> Chunk:
|
||||
"""Allocate a chunk for a list of parameters."""
|
||||
chunk = self.open_chunk(chunk_size, chunk_dtype, process_group, chunk_config)
|
||||
# append tensors
|
||||
for t in tensor_list:
|
||||
chunk.append_tensor(t)
|
||||
self.close_chunk(chunk)
|
||||
|
||||
return chunk
|
||||
|
||||
def tensors_to_chunks(self, tensor_list: List[torch.Tensor]) -> List[Chunk]:
|
||||
"""Get the chunks of a gevien list of tensors."""
|
||||
chunk_list = list()
|
||||
for tensor in tensor_list:
|
||||
chunk = self.ten_to_chunk.get(tensor)
|
||||
if chunk not in chunk_list:
|
||||
chunk_list.append(chunk)
|
||||
chunk_list.sort(key=lambda c: c.chunk_id)
|
||||
return chunk_list
|
||||
|
||||
def rcache_enough_check(self, chunk: Chunk) -> bool:
|
||||
"""Check whether the rcache has enough blocks to store the gathered chunk."""
|
||||
if chunk.rcache_fused:
|
||||
return True
|
||||
return self.rcache.public_free_cnt > 0
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> bool:
|
||||
"""Access a chunk into rCache."""
|
||||
self.inside_check(chunk)
|
||||
# if this chunk is accessed already, return False
|
||||
if self.is_accessed(chunk):
|
||||
return False
|
||||
|
||||
if chunk.rcache_fused:
|
||||
block = None
|
||||
else:
|
||||
block = self.rcache.get_public_block()
|
||||
chunk.access_chunk(block)
|
||||
self.__add_to_accset(chunk)
|
||||
return True
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> bool:
|
||||
"""Release a chunk from rCache."""
|
||||
self.inside_check(chunk)
|
||||
assert self.is_accessed(chunk)
|
||||
assert chunk.scatter_check
|
||||
|
||||
block = chunk.release_chunk()
|
||||
if block:
|
||||
self.rcache.free_public_block(block)
|
||||
self.__remove_from_accset(chunk)
|
||||
return True
|
||||
|
||||
def reduce_chunk(self, chunk: Chunk, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
|
||||
"""Reduce and scatter a gradient chunk from rCache."""
|
||||
self.inside_check(chunk)
|
||||
assert self.is_accessed(chunk)
|
||||
assert chunk.reduce_check
|
||||
|
||||
block = chunk.reduce_chunk(always_fp32=always_fp32, sync=sync)
|
||||
if block and sync:
|
||||
# if synchronized, free the block into rcache
|
||||
self.rcache.free_public_block(block)
|
||||
block = None
|
||||
|
||||
self.__remove_from_accset(chunk)
|
||||
|
||||
return block
|
||||
|
||||
def tensor_trans_state(self, tensor: torch.Tensor, state: TensorState):
|
||||
"""Transform the state of a tensor."""
|
||||
chunk = self.ten_to_chunk.get(tensor)
|
||||
chunk.tensor_trans_state(tensor, state)
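
# Illustrative usage sketch (comments only): a ChunkGroup pairs a MemoryPool with the
# chunks allocated from it. `mp`, `params` and `pg` are hypothetical placeholders.
#
#   cg = ChunkGroup(rcache=mp)
#   chunk = cg.allocate_chunk(params, chunk_size=1024, chunk_dtype=torch.float16, process_group=pg)
#   if cg.rcache_enough_check(chunk):
#       cg.access_chunk(chunk)       # gather the chunk into a free public block
#   ...compute...
#   cg.release_chunk(chunk)          # scatter it back and return the block to the pool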
|
|
@ -0,0 +1,172 @@
|
|||
from abc import ABC
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Iterable, List, NamedTuple
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler_util import _format_memory
|
||||
|
||||
|
||||
class BlockRequire(NamedTuple):
|
||||
numel: int
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
class TensorBlock(ABC):
|
||||
"""TensorBlock is the memory unit of memory pool.
|
||||
It is a continuous memory block used to store tensors.
|
||||
Each chunk needs a corresponding TensorBlock to store its data during training.
|
||||
|
||||
args:
|
||||
numel: the number of elements in the block
|
||||
dtype: the data type of the block
|
||||
device_type: the device type of the block
|
||||
"""
|
||||
total_count: int = 0
|
||||
|
||||
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
|
||||
self.block_id = TensorBlock.total_count
|
||||
TensorBlock.total_count += 1
|
||||
|
||||
self.device_type = device_type
|
||||
self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
|
||||
self.memo_occ: int = self.payload.numel() * self.payload.element_size()
|
||||
|
||||
@property
|
||||
def numel(self):
|
||||
return self.payload.numel()
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.payload.dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.payload.device
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.block_id
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self.block_id == other.block_id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
|
||||
|
||||
|
||||
class PublicBlock(TensorBlock):
|
||||
"""Public blocks have the same length.
|
||||
Chunks of the same length can share the same public block.
|
||||
"""
|
||||
|
||||
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
|
||||
super().__init__(numel, dtype, device_type)
|
||||
self.block_type = 'public'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'PublicBlock{super().__repr__()}'
|
||||
|
||||
|
||||
class PrivateBlock(TensorBlock):
|
||||
"""Private blocks may have different lengths.
|
||||
Each private chunk should use its own private block.
|
||||
"""
|
||||
|
||||
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
|
||||
super().__init__(numel, dtype, device_type)
|
||||
self.block_type = 'private'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'PrivateBlock{super().__repr__()}'
|
||||
|
||||
|
||||
class MemoryPool(object):
|
||||
"""A memory pool consists of public blocks and private blocks.
|
||||
rCache uses a memory pool to manage its memory blocks.
Users should allocate memory blocks before using the pool.
|
||||
|
||||
args:
|
||||
device_type: the device type of the memory pool
|
||||
"""
|
||||
|
||||
def __init__(self, device_type: str) -> None:
|
||||
self.device_type: str = device_type
|
||||
|
||||
self.public_space: int = 0
|
||||
self.public_block_size: int = 0
|
||||
self.public_dtype: torch.dtype = None
|
||||
|
||||
self.public_free_blocks: list = None
|
||||
self.public_used_blocks: set = None
|
||||
|
||||
self.public_free_cnt: int = 0
|
||||
self.public_used_cnt: int = 0
|
||||
|
||||
self.private_space: int = 0
|
||||
self.private_blocks: list = None
|
||||
self.private_lookup_dict: Dict[BlockRequire, List] = None
|
||||
|
||||
self.__allocate_flag = False
|
||||
|
||||
def allocate(self,
|
||||
public_dtype: torch.dtype = torch.float,
|
||||
public_block_size: int = 1024,
|
||||
public_block_number: int = 0,
|
||||
private_block_list: Iterable[BlockRequire] = ()):
|
||||
assert self.__allocate_flag is False
|
||||
assert public_block_number >= 0
|
||||
|
||||
self.public_free_blocks = list()
|
||||
self.public_used_blocks = set()
|
||||
for _ in range(public_block_number):
|
||||
block = PublicBlock(public_block_size, public_dtype, self.device_type)
|
||||
self.public_free_blocks.append(block)
|
||||
|
||||
if public_block_number <= 0:
|
||||
self.public_space = 0
|
||||
else:
|
||||
self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
|
||||
self.public_block_size = public_block_size
|
||||
self.public_dtype = public_dtype
|
||||
|
||||
self.public_free_cnt = public_block_number
|
||||
self.public_used_cnt = 0
|
||||
|
||||
self.private_space = 0
|
||||
self.private_blocks = list()
|
||||
self.private_lookup_dict = defaultdict(list)
|
||||
|
||||
for require in private_block_list:
|
||||
block = PrivateBlock(require.numel, require.dtype, self.device_type)
|
||||
self.private_space += block.memo_occ
|
||||
self.private_blocks.append(block)
|
||||
self.private_lookup_dict[require].append(block)
|
||||
|
||||
self.__allocate_flag = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
|
||||
|
||||
def get_private_block(self, numel: int, dtype: torch.dtype):
|
||||
block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
|
||||
return block_list.pop()
|
||||
|
||||
def get_public_block(self):
|
||||
self.public_free_cnt -= 1
|
||||
self.public_used_cnt += 1
|
||||
|
||||
block = self.public_free_blocks.pop()
|
||||
self.public_used_blocks.add(block)
|
||||
|
||||
return block
|
||||
|
||||
def free_public_block(self, block: TensorBlock):
|
||||
assert isinstance(block, PublicBlock)
|
||||
assert block in self.public_used_blocks
|
||||
|
||||
self.public_free_cnt += 1
|
||||
self.public_used_cnt -= 1
|
||||
|
||||
self.public_used_blocks.remove(block)
|
||||
self.public_free_blocks.append(block)
|
||||
|
||||
return block
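
# Illustrative usage sketch (comments only): the pool is allocated once, then hands out
# public blocks for regular chunks and private blocks for fused chunks.
#
#   mp = MemoryPool(device_type='cuda')
#   mp.allocate(public_dtype=torch.float16,
#               public_block_size=1024,
#               public_block_number=2,
#               private_block_list=[BlockRequire(numel=512, dtype=torch.float16)])
#   block = mp.get_public_block()     # borrow a free public block
#   mp.free_public_block(block)       # return it once the chunk is scattered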
|
|
@ -0,0 +1,25 @@
from enum import Enum


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


# expected: free -> hold -> compute -> hold ->
# -> compute -> hold_after_bwd -> ready_for_reduce
legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
                        (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
                        (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
                        (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
                        (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
                        (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]


def ts_update_sanity_check(old_state, new_state) -> bool:
    if (old_state, new_state) not in legal_ts_update_list:
        raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}')
    return True
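

# Illustrative check (comments only): a parameter typically cycles
# HOLD -> COMPUTE (forward) -> HOLD -> COMPUTE (backward) -> HOLD_AFTER_BWD -> READY_FOR_REDUCE.
#   ts_update_sanity_check(TensorState.HOLD, TensorState.COMPUTE)            # returns True
#   ts_update_sanity_check(TensorState.FREE, TensorState.READY_FOR_REDUCE)   # raises RuntimeError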
@ -0,0 +1,210 @@
|
|||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .core import Chunk, ChunkGroup, TensorBlock, TensorState
|
||||
from .scheduler import ChunkScheduler
|
||||
|
||||
|
||||
class ChunkFetcher(object):
|
||||
"""ChunkFetcher is responsible for fetching and reducing chunks during training.
|
||||
Any operations on chunks should be done through ChunkFetcher.
|
||||
|
||||
args:
|
||||
scheduler: A ChunkScheduler to schedule evictable chunks.
|
||||
group: A ChunkGroup to manage chunks.
|
||||
overlap: Whether to overlap communications.
|
||||
reduce_always_fp32: Whether to reduce gradients in FP32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
scheduler: ChunkScheduler,
|
||||
group: ChunkGroup,
|
||||
overlap: bool = False,
|
||||
reduce_always_fp32: bool = False) -> None:
|
||||
|
||||
self.scheduler: ChunkScheduler = scheduler
|
||||
self.group: ChunkGroup = group
|
||||
self.reduce_always_fp32 = reduce_always_fp32
|
||||
self.current_step = -1
|
||||
|
||||
self.overlap_flag = overlap
|
||||
self.main_stream = torch.cuda.current_stream()
|
||||
|
||||
self.predict_next_chunk: Optional[Chunk] = None
|
||||
self.is_fetching: bool = False
|
||||
self.prefetch_stream = torch.cuda.Stream()
|
||||
|
||||
self.reduced_chunk: Optional[Chunk] = None
|
||||
self.reduced_block: Optional[TensorBlock] = None
|
||||
self.reduce_stream = torch.cuda.Stream()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the fetcher to the initial state.
|
||||
Users should call this function before training."""
|
||||
self.scheduler.reset()
|
||||
self.current_step = -1
|
||||
|
||||
def clear(self):
|
||||
"""Clear the fetcher.
|
||||
Users should call this function after training."""
|
||||
if self.overlap_flag:
|
||||
torch.cuda.synchronize()
|
||||
self.predict_next_chunk = None
|
||||
self.is_fetching = False
|
||||
if self.reduced_chunk is not None:
|
||||
self.reduce_call_back()
|
||||
self.reduced_chunk = None
|
||||
self.reduced_block = None
|
||||
|
||||
self.scheduler.clear()
|
||||
|
||||
def trans_to_compute(self, tensors: List[torch.Tensor]):
|
||||
"""Transform tensors to COMPUTE state.
|
||||
This function should be called before the compute operators."""
|
||||
# update tensor states
|
||||
for t in tensors:
|
||||
self.group.tensor_trans_state(t, TensorState.COMPUTE)
|
||||
# chunk operations
|
||||
chunks = self.group.tensors_to_chunks(tensors)
|
||||
for chunk in chunks:
|
||||
self.scheduler.remove(chunk)
|
||||
return chunks
|
||||
|
||||
def trans_to_hold(self, tensors: List[torch.Tensor], phase: str):
|
||||
"""Transform tensors to HOLD state.
|
||||
This function should be called after the compute operators."""
|
||||
assert phase in ('f', 'b')
|
||||
next_state = TensorState.HOLD if phase == 'f' else TensorState.HOLD_AFTER_BWD
|
||||
# update tensor states
|
||||
for t in tensors:
|
||||
self.group.tensor_trans_state(t, next_state)
|
||||
# chunk operations
|
||||
chunks = self.group.tensors_to_chunks(tensors)
|
||||
for chunk in chunks:
|
||||
if chunk.scatter_check:
|
||||
self.scheduler.add(chunk)
|
||||
|
||||
def get_one_chunk(self, tensor: torch.Tensor) -> Chunk:
|
||||
"""Get the chunk of the given tensor."""
|
||||
return self.group.ten_to_chunk.get(tensor)
|
||||
|
||||
def get_chunks(self, tensors: List[torch.Tensor]) -> List[Chunk]:
|
||||
"""Get the chunks of the given tensors."""
|
||||
return self.group.tensors_to_chunks(tensors)
|
||||
|
||||
def is_in_fused(self, tensor: torch.Tensor):
|
||||
"""Check whether the given tensor is in a fused chunk."""
|
||||
chunk = self.get_one_chunk(tensor)
|
||||
return chunk.rcache_fused
|
||||
|
||||
def filter_chunks(self, chunks: List[Chunk]):
|
||||
"""Filter the accessed chunks, since they are already on the rCache."""
|
||||
return list(filter(lambda c: not self.group.is_accessed(c), chunks))
|
||||
|
||||
def fetch_chunks(self, chunks: List[Chunk]):
|
||||
"""Fetch chunks needed for this compute operator.
|
||||
The chunks should be in the COMPUTE state first."""
|
||||
# make step + 1
|
||||
self.step()
|
||||
|
||||
predict_hit = False
|
||||
# try to prefetch the next chunk
|
||||
if self.predict_next_chunk is not None and self.predict_next_chunk in chunks:
|
||||
if self.is_fetching:
|
||||
# prefetch hit, wait async prefetch
|
||||
self.main_stream.wait_stream(self.prefetch_stream)
|
||||
# clear prefetch information
|
||||
self.predict_next_chunk = None
|
||||
self.is_fetching = False
|
||||
|
||||
predict_hit = True
|
||||
# filter accessed chunks
|
||||
scattered = self.filter_chunks(chunks)
|
||||
# sanity check: upload should wait for prefetch
|
||||
if self.predict_next_chunk is not None:
|
||||
assert len(scattered) == 0
|
||||
# all chunks are on the rcache
|
||||
if len(scattered) == 0:
|
||||
# prefetch if there is a hit above
|
||||
if predict_hit:
|
||||
self.prefetch(chunks)
|
||||
return
|
||||
|
||||
for chunk in scattered:
|
||||
# if the rcache is not enough, just release a chunk
|
||||
if not self.group.rcache_enough_check(chunk):
|
||||
maybe_chunk = self.scheduler.top()
|
||||
# print(f'Evicting {chunk.chunk_id} -> {maybe_chunk.chunk_id}')
|
||||
if maybe_chunk is None:
|
||||
raise RuntimeError('R cache is not enough. Try to allocate more.')
|
||||
self.scheduler.remove(maybe_chunk)
|
||||
self.group.release_chunk(maybe_chunk)
|
||||
|
||||
# print('Accessing', chunk.chunk_id)
|
||||
self.group.access_chunk(chunk)
|
||||
|
||||
if self.overlap_flag:
|
||||
assert self.predict_next_chunk is None
|
||||
self.prefetch(chunks)
|
||||
|
||||
def reduce_call_back(self):
|
||||
self.reduced_chunk.update_extra_reduce_info(self.reduced_block)
|
||||
if self.reduced_block is not None:
|
||||
self.group.rcache.free_public_block(self.reduced_block)
|
||||
|
||||
def reduce_chunk(self, chunk: Chunk):
|
||||
"""Reduce and scatter the given gradient chunk."""
|
||||
if not chunk.reduce_check:
|
||||
return False
|
||||
|
||||
self.scheduler.remove(chunk)
|
||||
|
||||
if not self.overlap_flag:
|
||||
# reduce the chunk if not overlapped
|
||||
self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=True)
|
||||
else:
|
||||
# wait main stream for its computation
|
||||
self.reduce_stream.wait_stream(self.main_stream)
|
||||
# main stream should wait reduce stream
|
||||
# if there is a block recycle
|
||||
if self.reduced_chunk is not None:
|
||||
self.main_stream.wait_stream(self.reduce_stream)
|
||||
self.reduce_call_back()
|
||||
|
||||
with torch.cuda.stream(self.reduce_stream):
|
||||
self.reduced_chunk = chunk
|
||||
self.reduced_block = self.group.reduce_chunk(chunk, always_fp32=self.reduce_always_fp32, sync=False)
|
||||
|
||||
def prefetch(self, chunks: List[Chunk]):
|
||||
"""Prefetch the next used chunk."""
|
||||
next_chunk = self.scheduler.get_next_chunk(chunks)
|
||||
self.predict_next_chunk = next_chunk
|
||||
|
||||
# return if there is no next scattered chunk
|
||||
if next_chunk is None or self.group.is_accessed(next_chunk):
|
||||
return
|
||||
|
||||
evict_chunk = None
|
||||
if not self.group.rcache_enough_check(next_chunk):
|
||||
maybe_chunk = self.scheduler.top()
|
||||
# if no chunk can be evicted, just return
|
||||
if maybe_chunk is None:
|
||||
return
|
||||
# otherwise, release this chunk
|
||||
self.scheduler.remove(maybe_chunk)
|
||||
evict_chunk = maybe_chunk
|
||||
|
||||
with torch.cuda.stream(self.prefetch_stream):
|
||||
# wait main stream
|
||||
self.prefetch_stream.wait_stream(self.main_stream)
|
||||
self.is_fetching = True
|
||||
|
||||
if evict_chunk is not None:
|
||||
self.group.release_chunk(evict_chunk)
|
||||
self.group.access_chunk(next_chunk)
|
||||
|
||||
def step(self):
|
||||
"""Update the scheduler."""
|
||||
self.scheduler.step()
|
||||
self.current_step += 1
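
# Illustrative call order (comments only), mirroring the autograd hook functions later in
# this commit; `op_params` is a hypothetical list of parameters used by one operator.
#
#   fetcher.reset()                                   # before training
#   chunks = fetcher.trans_to_compute(op_params)      # before an operator
#   fetcher.fetch_chunks(chunks)                      # make sure the chunks live in rCache
#   ...run the operator...
#   fetcher.trans_to_hold(op_params, phase='f')       # after the forward ('b' after the backward)
#   fetcher.clear()                                   # after training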
|
|
@ -0,0 +1,3 @@
from .base import ChunkScheduler
from .fifo import FIFOScheduler
from .prefetch import PrefetchScheduler
@ -0,0 +1,50 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from colossalai.elixir.chunk.core import Chunk
|
||||
|
||||
|
||||
class ChunkScheduler(ABC):
|
||||
"""The base class of all chunk schedulers.
|
||||
A chunk scheduler stores all releasable chunks.
It provides APIs to add, remove, and query releasable chunks.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.releasable_set: Optional[set] = None
|
||||
self.current_step = -1
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
self.releasable_set = set()
|
||||
self.current_step = -1
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
# ensure the set is empty now
|
||||
assert not bool(self.releasable_set)
|
||||
|
||||
@abstractmethod
|
||||
def top(self) -> Optional[Chunk]:
|
||||
# return None if the releasable set is empty
|
||||
if not self.releasable_set:
|
||||
return False
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def add(self, chunk: Chunk) -> bool:
|
||||
if chunk in self.releasable_set:
|
||||
return False
|
||||
self.releasable_set.add(chunk)
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def remove(self, chunk: Chunk) -> bool:
|
||||
if chunk not in self.releasable_set:
|
||||
return False
|
||||
self.releasable_set.remove(chunk)
|
||||
return True
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self.current_step += 1
|
|
@ -0,0 +1,40 @@
|
|||
from typing import Optional
|
||||
|
||||
from .base import Chunk, ChunkScheduler
|
||||
|
||||
|
||||
class FIFOScheduler(ChunkScheduler):
|
||||
"""The FIFO chunk scheduler.
|
||||
It stores all releasable chunks in a FIFO queue.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.fifo_dict: Optional[dict] = None
|
||||
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self.fifo_dict = dict()
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.fifo_dict = None
|
||||
|
||||
def top(self) -> Optional[Chunk]:
|
||||
if not super().top():
|
||||
return None
|
||||
dict_iter = iter(self.fifo_dict)
|
||||
ret = next(dict_iter)
|
||||
return ret
|
||||
|
||||
def add(self, chunk: Chunk) -> bool:
|
||||
if not super().add(chunk):
|
||||
return False
|
||||
self.fifo_dict[chunk] = True
|
||||
return True
|
||||
|
||||
def remove(self, chunk: Chunk) -> bool:
|
||||
if not super().remove(chunk):
|
||||
return False
|
||||
self.fifo_dict.pop(chunk)
|
||||
return True
|
|
@ -0,0 +1,85 @@
|
|||
from collections import defaultdict
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import torch
|
||||
from sortedcontainers import SortedSet
|
||||
|
||||
from .base import Chunk, ChunkScheduler
|
||||
|
||||
|
||||
class PrefetchScheduler(ChunkScheduler):
|
||||
"""The prefetch chunk scheduler.
|
||||
Its top function returns the chunk whose next use is furthest in the future.
|
||||
"""
|
||||
|
||||
def __init__(self, chunk_called_per_step: List[Iterable[Chunk]]) -> None:
|
||||
super().__init__()
|
||||
self.chunk_mapping = None
|
||||
self.evict_set = None
|
||||
self.search_step = -1
|
||||
|
||||
self.chunks_per_step = chunk_called_per_step
|
||||
self.total_steps = len(chunk_called_per_step)
|
||||
self.next_step_dict = defaultdict(list)
|
||||
# initialize the next_step dictionary
|
||||
for i, c_list in enumerate(chunk_called_per_step):
|
||||
for c in c_list:
|
||||
self.next_step_dict[c].append(i)
|
||||
|
||||
def _get_next_step(self, chunk: Chunk):
|
||||
step_list = self.next_step_dict[chunk]
|
||||
for i in step_list:
|
||||
if i > self.current_step:
|
||||
return i
|
||||
return self.total_steps
|
||||
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self.chunk_mapping = dict()
|
||||
self.evict_set = SortedSet()
|
||||
self.search_step = -1
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
if torch.is_grad_enabled():
|
||||
assert self.current_step == self.total_steps - 1
|
||||
self.chunk_mapping = None
|
||||
self.evict_set = None
|
||||
self.search_step = -1
|
||||
|
||||
def top(self) -> Optional[Chunk]:
|
||||
if not super().top():
|
||||
return None
|
||||
next_step, chunk = self.evict_set[-1]
|
||||
return chunk
|
||||
|
||||
def add(self, chunk: Chunk) -> bool:
|
||||
if not super().add(chunk):
|
||||
return False
|
||||
value = (self._get_next_step(chunk), chunk)
|
||||
self.chunk_mapping[chunk] = value
|
||||
self.evict_set.add(value)
|
||||
return True
|
||||
|
||||
def remove(self, chunk: Chunk) -> bool:
|
||||
if not super().remove(chunk):
|
||||
return False
|
||||
value = self.chunk_mapping[chunk]
|
||||
self.evict_set.remove(value)
|
||||
self.chunk_mapping.pop(chunk)
|
||||
return True
|
||||
|
||||
def step(self, *args, **kwargs):
super().step(*args, **kwargs)
|
||||
if self.current_step >= self.total_steps:
|
||||
raise RuntimeError('exceeded the number of simulated steps, please check your profiling `step_fn`')
|
||||
|
||||
def get_next_chunk(self, chunks: List[Chunk]):
|
||||
self.search_step = max(self.search_step, self.current_step + 1)
|
||||
while self.search_step < self.total_steps:
|
||||
c_list = self.chunks_per_step[self.search_step]
|
||||
for c in c_list:
|
||||
if c not in chunks:
|
||||
return c
|
||||
self.search_step += 1
|
||||
return None
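
# Illustrative input (comments only): `chunk_called_per_step` is the profiled schedule,
# one iterable of chunks per operator step; c0..c2 below are hypothetical chunks.
#
#   scheduler = PrefetchScheduler(chunk_called_per_step=[[c0], [c0, c1], [c2], [c1]])
#   scheduler.reset()
#   # after step(), top() returns the releasable chunk whose next use is furthest away,
#   # and get_next_chunk(current_chunks) returns the next scheduled chunk not already in use.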
|
|
@ -0,0 +1,32 @@
|
|||
import torch
|
||||
|
||||
tensor_creation_methods = dict(tensor=torch.tensor,
|
||||
sparse_coo_tensor=torch.sparse_coo_tensor,
|
||||
asarray=torch.asarray,
|
||||
as_tensor=torch.as_tensor,
|
||||
as_strided=torch.as_strided,
|
||||
from_numpy=torch.from_numpy,
|
||||
from_dlpack=torch.from_dlpack,
|
||||
frombuffer=torch.frombuffer,
|
||||
zeros=torch.zeros,
|
||||
zeros_like=torch.zeros_like,
|
||||
ones=torch.ones,
|
||||
ones_like=torch.ones_like,
|
||||
arange=torch.arange,
|
||||
range=torch.range,
|
||||
linspace=torch.linspace,
|
||||
logspace=torch.logspace,
|
||||
eye=torch.eye,
|
||||
empty=torch.empty,
|
||||
empty_like=torch.empty_like,
|
||||
empty_strided=torch.empty_strided,
|
||||
full=torch.full,
|
||||
full_like=torch.full_like,
|
||||
quantize_per_tensor=torch.quantize_per_tensor,
|
||||
quantize_per_channel=torch.quantize_per_channel,
|
||||
dequantize=torch.dequantize,
|
||||
complex=torch.complex,
|
||||
polar=torch.polar,
|
||||
heaviside=torch.heaviside)
|
||||
|
||||
from .meta_ctx import MetaContext
|
|
@ -0,0 +1,34 @@
|
|||
import torch
|
||||
|
||||
from colossalai.elixir.ctx import tensor_creation_methods
|
||||
|
||||
|
||||
class MetaContext(object):
|
||||
"""A context manager that wraps all tensor creation methods in torch.
|
||||
By default, all tensors will be created on the meta device.
|
||||
|
||||
args:
|
||||
device_type: The device type of the tensors to be created.
|
||||
"""
|
||||
|
||||
def __init__(self, device_type: str = 'meta') -> None:
|
||||
super().__init__()
|
||||
self.device_type = device_type
|
||||
return None
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
def meta_wrap(func):
|
||||
|
||||
def wrapped_func(*args, **kwargs):
|
||||
kwargs['device'] = self.device_type
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
for name, method in tensor_creation_methods.items():
|
||||
setattr(torch, name, meta_wrap(method))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for name, method in tensor_creation_methods.items():
|
||||
setattr(torch, name, method)
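
# Illustrative usage (comments only): inside the context every patched torch factory call
# is redirected to the meta device, so tensors are materialized without real memory.
#
#   with MetaContext():
#       t = torch.ones(4)
#   assert t.is_meta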
|
|
@ -0,0 +1,28 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
from torch.cuda._utils import _get_device_index
|
||||
|
||||
elixir_cuda_fraction = dict()
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def gpu_device():
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
|
||||
def set_memory_fraction(fraction, device=None):
|
||||
torch.cuda.set_per_process_memory_fraction(fraction, device)
|
||||
if device is None:
|
||||
device = torch.cuda.current_device()
|
||||
device = _get_device_index(device)
|
||||
elixir_cuda_fraction[device] = fraction
|
||||
|
||||
|
||||
def get_allowed_memory(device=None):
|
||||
total_memory = torch.cuda.get_device_properties(device).total_memory
|
||||
if device is None:
|
||||
device = torch.cuda.current_device()
|
||||
device = _get_device_index(device)
|
||||
fraction = elixir_cuda_fraction.get(device, 1.0)
|
||||
return int(fraction * total_memory)
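
# Illustrative usage (comments only):
#   set_memory_fraction(0.8)          # cap this process at 80% of the current GPU
#   budget = get_allowed_memory()     # == int(0.8 * total_memory) for that device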
|
|
@ -0,0 +1,2 @@
from .parameter import HookParam
from .storage import BufferStore
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
|
||||
from colossalai.elixir.chunk import ChunkFetcher
|
||||
|
||||
from .storage import BufferStore
|
||||
|
||||
|
||||
def prefwd_postbwd_function(fetcher: ChunkFetcher, store: BufferStore):
|
||||
|
||||
class PreFwdPostBwd(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, params, *args):
|
||||
with torch._C.DisableTorchFunction():
|
||||
ctx.params = params
|
||||
chunks = fetcher.trans_to_compute(params)
|
||||
fetcher.fetch_chunks(chunks)
|
||||
|
||||
offset = 0
|
||||
for p in ctx.params:
|
||||
if not fetcher.is_in_fused(p):
|
||||
# we should add parameters to buffer
|
||||
# because their blocks may be changed
|
||||
offset = store.insert(p, offset)
|
||||
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
with torch._C.DisableTorchFunction():
|
||||
fetcher.trans_to_hold(ctx.params, phase='b')
|
||||
|
||||
for p in ctx.params:
|
||||
if not fetcher.is_in_fused(p):
|
||||
store.erase(p)
|
||||
|
||||
return (None, *grads)
|
||||
|
||||
return PreFwdPostBwd.apply
|
||||
|
||||
|
||||
def postfwd_prebwd_function(fetcher: ChunkFetcher, store: BufferStore):
|
||||
|
||||
class PostFwdPreBwd(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, params, *args):
|
||||
with torch._C.DisableTorchFunction():
|
||||
ctx.params = params
|
||||
|
||||
fetcher.trans_to_hold(ctx.params, phase='f')
|
||||
for p in ctx.params:
|
||||
if not fetcher.is_in_fused(p):
|
||||
store.erase(p)
|
||||
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
with torch._C.DisableTorchFunction():
|
||||
chunks = fetcher.trans_to_compute(ctx.params)
|
||||
fetcher.fetch_chunks(chunks)
|
||||
|
||||
offset = 0
|
||||
for p in ctx.params:
|
||||
if not fetcher.is_in_fused(p):
|
||||
# we should add parameters to buffer
|
||||
# because their blocks may be changed
|
||||
offset = store.insert(p, offset)
|
||||
|
||||
return (None, *grads)
|
||||
|
||||
return PostFwdPreBwd.apply
|
|
@ -0,0 +1,102 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.chunk import ChunkFetcher
|
||||
from colossalai.elixir.kernels import fused_torch_functions
|
||||
from colossalai.elixir.tensor import OutplaceTensor, is_no_hook_op, to_outplace_tensor
|
||||
|
||||
from .functions import postfwd_prebwd_function, prefwd_postbwd_function
|
||||
from .storage import BufferStore
|
||||
|
||||
|
||||
class HookParam(OutplaceTensor, nn.Parameter):
|
||||
"""HookParam is a special type of tensor that is used to triggered hooks on parameters.
|
||||
HookParam adds chunk fetching before torch functions.
|
||||
"""
|
||||
pre_fwd_func = None
|
||||
post_fwd_func = None
|
||||
use_fused_kernel = False
|
||||
|
||||
@staticmethod
|
||||
def attach_fetcher(fetcher: ChunkFetcher, store: BufferStore):
|
||||
HookParam.pre_fwd_func = prefwd_postbwd_function(fetcher, store)
|
||||
HookParam.post_fwd_func = postfwd_prebwd_function(fetcher, store)
|
||||
|
||||
@staticmethod
|
||||
def release_fetcher():
|
||||
HookParam.pre_fwd_func = None
|
||||
HookParam.post_fwd_func = None
|
||||
|
||||
@staticmethod
|
||||
def enable_fused_kernel():
|
||||
HookParam.use_fused_kernel = True
|
||||
|
||||
@staticmethod
|
||||
def disable_fused_kernel():
|
||||
HookParam.use_fused_kernel = False
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if is_no_hook_op(func):
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
params_to_index = OrderedDict()
|
||||
params_index = 0
|
||||
|
||||
def append_param(x):
|
||||
nonlocal params_index
|
||||
if isinstance(x, HookParam):
|
||||
params_to_index[x] = params_index
|
||||
params_index += 1
|
||||
|
||||
tree_map(append_param, args)
|
||||
tree_map(append_param, kwargs)
|
||||
|
||||
params = tuple(params_to_index.keys())
|
||||
new_params = HookParam.pre_fwd_func(params, *params)
|
||||
|
||||
def replace_param(x):
|
||||
if isinstance(x, HookParam):
|
||||
return new_params[params_to_index[x]]
|
||||
return x
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
if HookParam.use_fused_kernel and func in fused_torch_functions:
|
||||
func = fused_torch_functions.get(func)
|
||||
ret = func(*tree_map(replace_param, args), **tree_map(replace_param, kwargs))
|
||||
if not isinstance(ret, tuple):
|
||||
ret = (ret,)
|
||||
|
||||
ptr_set = set()
|
||||
for p in new_params:
|
||||
ptr_set.add(p.data_ptr())
|
||||
|
||||
def clone_inplace_tensor(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
start_point = x.data_ptr() - x.element_size() * x.storage_offset()
|
||||
if start_point in ptr_set:
|
||||
return x.clone()
|
||||
return x
|
||||
|
||||
ret = tree_map(clone_inplace_tensor, ret)
|
||||
ret = HookParam.post_fwd_func(params, *ret)
|
||||
|
||||
def convert(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
t = to_outplace_tensor(t)
|
||||
return t
|
||||
|
||||
ret = tree_map(convert, ret)
|
||||
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
else:
|
||||
return ret
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
from torch.autograd.profiler_util import _format_memory
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
|
||||
|
||||
class BufferStore(object):
|
||||
"""A place to store parameters temporarily when computing.
|
||||
Parameters should be stored in the buffer before computations.
|
||||
|
||||
args:
|
||||
buffer_size: The size of the buffer.
|
||||
buffer_dtype: The dtype of the buffer.
|
||||
device_str: The device to store the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size: int, buffer_dtype: torch.dtype, device_str: str = 'cuda') -> None:
|
||||
super().__init__()
|
||||
self.buffer_size = buffer_size
|
||||
self.buffer_dtype = buffer_dtype
|
||||
self.buffer: torch.Tensor = torch.empty(buffer_size, dtype=buffer_dtype, device=device_str)
|
||||
self.buffer_occ = buffer_size * self.buffer.element_size()
|
||||
self.record_dict = dict()
|
||||
|
||||
def zeros(self):
|
||||
torch.zero_(self.buffer)
|
||||
|
||||
def insert(self, t: torch.Tensor, offset: int) -> int:
|
||||
assert t not in self.record_dict
|
||||
end = offset + t.numel()
|
||||
assert end <= self.buffer_size, f'buffer size is {self.buffer_size} but needs {end}'
|
||||
|
||||
new_data = self.buffer[offset:end].view(t.shape)
|
||||
new_data.copy_(t.data)
|
||||
|
||||
self.record_dict[t] = t.data
|
||||
t.data = new_data
|
||||
|
||||
return end
|
||||
|
||||
def erase(self, t: torch.Tensor):
|
||||
assert t in self.record_dict
|
||||
|
||||
new_data = self.record_dict.pop(t)
|
||||
t.data = new_data
|
||||
|
||||
return
|
||||
|
||||
def empty_like(self, t: torch.Tensor):
|
||||
return self.buffer[:t.numel()].view(t.shape)
|
||||
|
||||
def empty_1d(self, size: int):
|
||||
return self.buffer[:size]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Buffer(size={self.buffer_size}, dtype={self.buffer_dtype}, memo_occ={_format_memory(self.buffer_occ)})'
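# Round-trip sketch (illustrative, not part of the original file): insert temporarily
# moves a parameter's data into the shared buffer, erase restores the original storage.
# A CPU buffer is used here only to keep the example device-independent.
example_store = BufferStore(buffer_size=1024, buffer_dtype=torch.float32, device_str='cpu')
example_param = torch.nn.Parameter(torch.randn(4, 8))
next_offset = example_store.insert(example_param, 0)    # data now lives inside the buffer
assert next_offset == example_param.numel()
example_store.erase(example_param)    # original storage is handed back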
|
|
@ -0,0 +1,17 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
fused_torch_functions = {F.layer_norm: F.layer_norm}
|
||||
|
||||
|
||||
def register_fused_layer_norm():
|
||||
try:
|
||||
from .layernorm import ln_func
|
||||
fused_torch_functions[F.layer_norm] = ln_func
|
||||
print('Registered fused layer norm from apex successfully.')
|
||||
except ImportError:
|
||||
print('Cannot import fused layer norm, please install apex from source.')
|
||||
pass
|
||||
|
||||
|
||||
register_fused_layer_norm()
|
|
@ -0,0 +1,43 @@
|
|||
import torch
|
||||
import xformers.ops as xops
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.tracer.memory_tracer.memory_tensor import MTensor
|
||||
|
||||
|
||||
def lower_triangular_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, p: float = 0.0):
|
||||
|
||||
args = (query, key, value)
|
||||
meta_flag = False
|
||||
|
||||
for x in args:
|
||||
if x.device.type == 'meta':
|
||||
meta_flag = True
|
||||
break
|
||||
|
||||
if meta_flag:
|
||||
atten = query @ key.transpose(-2, -1)
|
||||
output = atten @ value
|
||||
return output
|
||||
|
||||
profile_flag = False
|
||||
|
||||
def to_torch_tensor(x):
|
||||
if isinstance(x, MTensor):
|
||||
nonlocal profile_flag
|
||||
profile_flag = True
|
||||
return x.elem
|
||||
return x
|
||||
|
||||
args = tree_map(to_torch_tensor, args)
|
||||
query, key, value = args
|
||||
output = xops.memory_efficient_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
p=p,
|
||||
attn_bias=xops.LowerTriangularMask())
|
||||
|
||||
if profile_flag:
|
||||
output = MTensor(output)
|
||||
|
||||
return output
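# Call sketch (illustrative, not part of the original file; shapes are made up and a
# CUDA device plus xformers are assumed): tensors follow the (batch, seq, heads, dim)
# layout expected by the memory-efficient attention wrapper above.
q = torch.randn(2, 128, 12, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 128, 12, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 128, 12, 64, device='cuda', dtype=torch.float16)
attn_out = lower_triangular_attention(q, k, v, p=0.1)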
|
|
@ -0,0 +1,20 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
|
||||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder
|
||||
|
||||
from .gpt_attention import XGPT2Attention, XGPT2Model
|
||||
from .opt_attention import XOPTAttention, XOPTDecoder
|
||||
|
||||
|
||||
def wrap_attention(model: nn.Module):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, GPT2Model):
|
||||
module.__class__ = XGPT2Model
|
||||
elif isinstance(module, GPT2Attention):
|
||||
module.__class__ = XGPT2Attention
|
||||
elif isinstance(module, OPTAttention):
|
||||
module.__class__ = XOPTAttention
|
||||
elif isinstance(module, OPTDecoder):
|
||||
module.__class__ = XOPTDecoder
|
||||
return model
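# Usage sketch (illustrative, not part of the original file): patch the attention
# classes of a Hugging Face GPT-2 model in place; GPT2LMHeadModel is just an example
# model that contains GPT2Model and GPT2Attention submodules.
from transformers import GPT2LMHeadModel
example_model = GPT2LMHeadModel.from_pretrained('gpt2')
example_model = wrap_attention(example_model)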
|
|
@ -0,0 +1,45 @@
|
|||
import warnings

import einops
|
||||
import torch
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
|
||||
|
||||
from .attention import lower_triangular_attention
|
||||
|
||||
|
||||
class XGPT2Attention(GPT2Attention):
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
assert self.scale_attn_weights
|
||||
assert not self.is_cross_attention
|
||||
assert not self.scale_attn_by_inverse_layer_idx
|
||||
assert not self.reorder_and_upcast_attn
|
||||
|
||||
b_size, h_size, m_size, k_size = query.size()
|
||||
|
||||
assert self.bias.size(-1) == m_size
|
||||
query = einops.rearrange(query, 'b h m k -> b m h k')
|
||||
key = einops.rearrange(key, 'b h m k -> b m h k')
|
||||
value = einops.rearrange(value, 'b h m k -> b m h k')
|
||||
|
||||
drop_rate = self.attn_dropout.p
|
||||
output = lower_triangular_attention(query, key, value, p=drop_rate)
|
||||
|
||||
ret = einops.rearrange(output, 'b m h k -> b h m k')
|
||||
|
||||
return ret, None
|
||||
|
||||
|
||||
class XGPT2Model(GPT2Model):
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg'
|
||||
attn_mask = kwargs.get('attention_mask')
|
||||
# assert torch.all(attn_mask == 1), 'only accept no padding mask'
|
||||
|
||||
head_mask = kwargs.get('head_mask', None)
|
||||
assert head_mask is None, 'head mask should be None'
|
||||
|
||||
output_attn = kwargs.get('output_attentions', False)
|
||||
if output_attn:
|
||||
warnings.warn('output_attentions is not supported for XGPT2Model')
|
||||
|
||||
return super().forward(*args, **kwargs)
|
|
@ -0,0 +1,10 @@
|
|||
from apex.normalization.fused_layer_norm import fused_layer_norm, fused_layer_norm_affine
|
||||
|
||||
|
||||
def ln_func(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if weight is None:
|
||||
assert bias is None
|
||||
return fused_layer_norm(input, normalized_shape, eps)
|
||||
else:
|
||||
assert weight is not None and bias is not None
|
||||
return fused_layer_norm_affine(input, weight, bias, normalized_shape, eps)
|
|
@ -0,0 +1,86 @@
|
|||
import warnings
from typing import Optional, Tuple
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder
|
||||
|
||||
from .attention import lower_triangular_attention
|
||||
|
||||
|
||||
class XOPTAttention(OPTAttention):
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
assert is_cross_attention is False
|
||||
assert past_key_value is None
|
||||
assert layer_head_mask is None
|
||||
# assert output_attentions is False
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
# get key, value proj
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
assert tgt_len == src_len
|
||||
|
||||
attn_output = lower_triangular_attention(query=query_states, key=key_states, value=value_states, p=self.dropout)
|
||||
|
||||
if attn_output.size() != (bsz, tgt_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(f'`attn_output` should be of size {(bsz, tgt_len, self.num_heads, self.head_dim)}, but is'
|
||||
f' {attn_output.size()}')
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
class XOPTDecoder(OPTDecoder):
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
assert 'attention_mask' in kwargs, 'please pass attention_mask as a kwarg'
|
||||
attn_mask = kwargs.get('attention_mask')
|
||||
# assert torch.all(attn_mask == 1), 'only accept no padding mask'
|
||||
|
||||
head_mask = kwargs.get('head_mask', None)
|
||||
assert head_mask is None, 'head mask should be None'
|
||||
|
||||
output_attn = kwargs.get('output_attentions', False)
|
||||
if output_attn:
|
||||
warnings.warn('output_attentions is not supported for XOPTDecoder')
|
||||
|
||||
return super().forward(*args, **kwargs)
|
|
@ -0,0 +1,4 @@
|
|||
from .mini_waste import minimum_waste_search
|
||||
from .optimal import optimal_search
|
||||
from .result import ChunkPlan, SearchResult
|
||||
from .simple import simple_search
|
|
@ -0,0 +1,135 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool
|
||||
from colossalai.elixir.tracer.param_tracer import generate_tf_order
|
||||
from colossalai.elixir.tracer.utils import meta_copy
|
||||
from colossalai.elixir.utils import print_rank_0
|
||||
|
||||
from .result import ChunkPlan
|
||||
from .utils import to_meta_tensor
|
||||
|
||||
|
||||
class SearchBase(ABC):
|
||||
"""A basic class for search algorithms.
|
||||
|
||||
args:
|
||||
module: the model to be searched
|
||||
dtype: the unified dtype of all parameters
|
||||
prefetch: whether to prefetch chunks during training
|
||||
verbose: whether to print search details
|
||||
inp: a dictionary, the example input of the model
|
||||
step_fn: the example step function of the model
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
dtype: torch.dtype = torch.float,
|
||||
prefetch: bool = False,
|
||||
verbose: bool = False,
|
||||
inp=None,
|
||||
step_fn=None) -> None:
|
||||
|
||||
self.unified_dtype = dtype
|
||||
self.meta_module = meta_copy(module, partial(to_meta_tensor, dtype=self.unified_dtype))
|
||||
self.prefetch_flag = prefetch
|
||||
self.verbose = verbose
|
||||
self.param_to_name = {param: name for name, param in self.meta_module.named_parameters()}
|
||||
|
||||
self.public_block_size = 1024
|
||||
self.public_block_number = 0
|
||||
|
||||
self.param_per_step = None
|
||||
self.max_checkpoint_size = 0
|
||||
if self.prefetch_flag:
|
||||
assert inp is not None and step_fn is not None
|
||||
tf_running_info = generate_tf_order(self.meta_module, inp, step_fn, dtype)
|
||||
self.param_per_step = tf_running_info.get('params_per_step')
|
||||
if self.verbose:
|
||||
print_rank_0('Prefetch enabled: the called order of parameters')
|
||||
for i, step in enumerate(self.param_per_step):
|
||||
print_rank_0(f'step {i}: {step}')
|
||||
|
||||
name_to_param = {name: param for name, param in self.meta_module.named_parameters()}
|
||||
for checkpoint in tf_running_info.get('checkpoint_info'):
|
||||
sum_numel = 0
|
||||
for i in range(*checkpoint):
|
||||
for name in self.param_per_step[i]:
|
||||
param = name_to_param[name]
|
||||
sum_numel += param.numel()
|
||||
self.max_checkpoint_size = max(self.max_checkpoint_size, sum_numel)
|
||||
if self.verbose:
|
||||
print_rank_0(f'checkpoint information: from-to -> {checkpoint}, numel -> {sum_numel}')
|
||||
|
||||
@abstractmethod
|
||||
def private_truncate(self, param: nn.Parameter) -> int:
|
||||
"""A function used to truncate the length of a private chunk,
|
||||
which only contains one parameter.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def public_trucate(self, length: int) -> int:
|
||||
"""A function used to trucate the length of all publick chunks
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs) -> Tuple:
|
||||
"""The core search function. It returns a tuple of a private group and public groups.
|
||||
"""
|
||||
pass
|
||||
|
||||
def generate_chunk_plans(self, private_group, public_groups) -> List[ChunkPlan]:
|
||||
plans = list()
|
||||
for param in private_group:
|
||||
chunk_size = self.private_truncate(param)
|
||||
chunk_dtype = param.dtype
|
||||
chunk_kwargs = dict(rcache_fused=True)
|
||||
chunk_plan = ChunkPlan(name_list=[self.param_to_name[param]],
|
||||
chunk_size=chunk_size,
|
||||
chunk_dtype=chunk_dtype,
|
||||
kwargs=chunk_kwargs)
|
||||
plans.append(chunk_plan)
|
||||
|
||||
self.public_block_size = self.public_trucate(self.public_block_size)
|
||||
public_chunk_size = self.public_block_size
|
||||
public_chunk_dtype = self.unified_dtype
|
||||
for group in public_groups:
|
||||
chunk_kwargs = {}
|
||||
chunk_plan = ChunkPlan(name_list=[self.param_to_name[p] for p in group],
|
||||
chunk_size=public_chunk_size,
|
||||
chunk_dtype=public_chunk_dtype,
|
||||
kwargs=chunk_kwargs)
|
||||
plans.append(chunk_plan)
|
||||
|
||||
if self.verbose:
|
||||
print_rank_0(f'Chunk plans: total {len(plans)} chunks')
|
||||
for i, plan in enumerate(plans):
|
||||
print_rank_0(f'plan {i}: {plan}')
|
||||
|
||||
return plans
|
||||
|
||||
def allocate_chunk_group(self, chunk_plans: List[ChunkPlan]) -> ChunkGroup:
|
||||
block_require_list = list()
|
||||
for plan in chunk_plans:
|
||||
kwargs = plan.kwargs
|
||||
if kwargs.get('rcache_fused', False):
|
||||
block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype))
|
||||
|
||||
mp = MemoryPool('cuda')
|
||||
mp.allocate(public_dtype=self.unified_dtype,
|
||||
public_block_size=self.public_block_size,
|
||||
public_block_number=self.public_block_number,
|
||||
private_block_list=block_require_list)
|
||||
|
||||
if self.verbose:
|
||||
print_rank_0(
|
||||
f'Memory pool (rcache): {mp}\n\tblock size -> {mp.public_block_size}, block number -> {mp.public_free_cnt}'
|
||||
)
|
||||
|
||||
return ChunkGroup(mp)
|
|
@ -0,0 +1,167 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.utils import print_rank_0
|
||||
|
||||
from .base import SearchBase
|
||||
from .result import SearchResult
|
||||
from .utils import find_minimum_waste_size, find_search_range, get_multi_used_params, to_divide
|
||||
|
||||
dtype_to_es = {torch.float16: 2, torch.float32: 4, torch.float64: 8}
|
||||
|
||||
|
||||
class SearchMiniWaste(SearchBase):
|
||||
"""Search the best chunk size to minimize the waste of memory.
|
||||
|
||||
args:
|
||||
module: the module to be searched
|
||||
default_group_size: the default group size of communications
|
||||
dtype: the data type of the parameters
|
||||
prefetch: whether to prefetch the parameters
|
||||
verbose: whether to print the search details
|
||||
inp: a dictionary, the example input of the model
|
||||
step_fn: the example step function of training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
default_group_size: int,
|
||||
dtype: torch.dtype = torch.float,
|
||||
prefetch: bool = False,
|
||||
verbose: bool = False,
|
||||
inp=None,
|
||||
step_fn=None) -> None:
|
||||
|
||||
super().__init__(module, dtype, prefetch, verbose, inp, step_fn)
|
||||
self.default_group_size = default_group_size
|
||||
|
||||
def private_truncate(self, param: nn.Parameter) -> int:
|
||||
return to_divide(param.numel(), self.default_group_size)
|
||||
|
||||
def public_trucate(self, length: int) -> int:
|
||||
return to_divide(length, self.default_group_size)
|
||||
|
||||
def search(self) -> Tuple:
|
||||
min_chunk_size, max_chunk_size, search_interval = find_search_range(self.meta_module)
|
||||
# get multi-used parameters
|
||||
private_params = get_multi_used_params(self.meta_module)
|
||||
# get parameters used only one time
|
||||
public_params = [p for p in self.meta_module.parameters() if p not in private_params]
|
||||
# collect the number of elements of each parameter
|
||||
public_numels = [p.numel() for p in public_params]
|
||||
# calculate the total number of elements of all public parameters
|
||||
total_size = sum(public_numels)
|
||||
|
||||
if total_size <= min_chunk_size:
|
||||
public_block_size = total_size
|
||||
waste_size = 0
|
||||
else:
|
||||
public_block_size, waste_size = find_minimum_waste_size(
|
||||
# pre-commit: do not rearrange
|
||||
numel_group_list=[public_numels],
|
||||
min_range=min_chunk_size,
|
||||
max_range=max_chunk_size,
|
||||
interval=search_interval)
|
||||
|
||||
if self.verbose:
|
||||
if total_size == 0:
|
||||
waste_percentage = 0
|
||||
else:
|
||||
waste_percentage = 100 * waste_size / total_size
|
||||
print_rank_0(
|
||||
f'Minimum waste search result: chunk size = {public_block_size}, waste percentage = {waste_percentage: .1f} %'
|
||||
)
|
||||
|
||||
# initialize the mapping from parameters to chunks
|
||||
param_to_chunk_id = dict()
|
||||
chunk_id = 0
|
||||
# deal with private parameters
|
||||
for p in private_params:
|
||||
param_to_chunk_id[p] = chunk_id
|
||||
chunk_id += 1
|
||||
# record the upper bound
|
||||
private_id_upperbound = chunk_id
|
||||
# deal with public parameters
|
||||
last_left = 0
|
||||
for p in public_params:
|
||||
p_size = p.numel()
|
||||
|
||||
if last_left < p_size:
|
||||
last_left = public_block_size
|
||||
chunk_id += 1
|
||||
|
||||
assert last_left >= p_size
|
||||
|
||||
last_left -= p_size
|
||||
param_to_chunk_id[p] = chunk_id
|
||||
|
||||
# initialize public groups
|
||||
public_number_chunks = chunk_id - private_id_upperbound
|
||||
public_groups = [[] for _ in range(public_number_chunks)]
|
||||
for p in public_params:
|
||||
public_chunk_id = param_to_chunk_id[p] - private_id_upperbound - 1
|
||||
public_groups[public_chunk_id].append(p)
|
||||
|
||||
# calculate the number of minimum chunks allocated in R cache
|
||||
max_lived_chunks = 0
|
||||
for module in self.meta_module.modules():
|
||||
param_set = set()
|
||||
for param in module.parameters(recurse=False):
|
||||
param_set.add(param_to_chunk_id[param])
|
||||
max_lived_chunks = max(max_lived_chunks, len(param_set))
|
||||
# allocate more chunks for prefetch
|
||||
if self.prefetch_flag:
|
||||
max_lived_chunks = min(max_lived_chunks + 4, public_number_chunks)
|
||||
|
||||
if total_size == 0:
|
||||
max_lived_chunks = 0
|
||||
|
||||
self.public_block_size = public_block_size
|
||||
self.public_block_number = max_lived_chunks
|
||||
|
||||
return (private_params, public_groups)
|
||||
|
||||
|
||||
def minimum_waste_search(m: nn.Module,
|
||||
group_size: int,
|
||||
unified_dtype: torch.dtype = torch.float,
|
||||
cpu_offload: bool = False,
|
||||
prefetch: bool = False,
|
||||
verbose: bool = False,
|
||||
pin_memory: bool = True,
|
||||
inp=None,
|
||||
step_fn=None) -> SearchResult:
|
||||
|
||||
search_class = SearchMiniWaste(
|
||||
# pre-commit: do not rearrange
|
||||
module=m,
|
||||
default_group_size=group_size,
|
||||
dtype=unified_dtype,
|
||||
prefetch=prefetch,
|
||||
verbose=verbose,
|
||||
inp=inp,
|
||||
step_fn=step_fn)
|
||||
|
||||
private_group, public_groups = search_class.search()
|
||||
chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
|
||||
|
||||
# assign shard device
|
||||
if cpu_offload:
|
||||
shard_device = torch.device('cpu')
|
||||
else:
|
||||
shard_device = gpu_device()
|
||||
|
||||
for plan in chunk_plans:
|
||||
plan.kwargs['shard_device'] = shard_device
|
||||
if cpu_offload:
|
||||
plan.kwargs['cpu_pin_memory'] = pin_memory
|
||||
|
||||
chunk_group = search_class.allocate_chunk_group(chunk_plans)
|
||||
|
||||
return SearchResult(chunk_group=chunk_group,
|
||||
chunk_plans=chunk_plans,
|
||||
param_called_per_step=search_class.param_per_step)
|
|
@ -0,0 +1,284 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd.profiler_util import _format_memory
|
||||
|
||||
from colossalai.elixir.cuda import get_allowed_memory, gpu_device
|
||||
from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling
|
||||
from colossalai.elixir.utils import calc_buffer_size, print_rank_0
|
||||
|
||||
from .base import SearchBase
|
||||
from .result import SearchResult
|
||||
from .simulator import find_optimal_chunk_size, rcache_prioirity_check
|
||||
from .utils import find_search_range, get_multi_used_params, to_divide
|
||||
|
||||
dtype_to_es = {torch.float16: 2, torch.float32: 4, torch.float64: 8}
|
||||
|
||||
|
||||
class SearchOptimal(SearchBase):
|
||||
"""Search the best chunk size to maximize the training throughput.
|
||||
Users should provide the example input data and step function of training.
|
||||
|
||||
args:
|
||||
module: the module to be searched
|
||||
default_group_size: the default group size of communications
|
||||
activation_fragment_factor: the factor to estimate the total activation memory usage
|
||||
allocation_fragment_factor: the factor to estimate the effective ratio of the memory usage can be used by Elixir
|
||||
driver_usage: the memory usage of the cuda driver
|
||||
dtype: the data type of the parameters
|
||||
verbose: whether to print the search details
|
||||
overlap: whether to overlap the communication and computation
|
||||
inp: a dictionary, the example input of the model
|
||||
step_fn: the example step function of training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
default_group_size: int,
|
||||
activation_fragment_factor: float = 1.25,
|
||||
allocation_fragment_factor: float = 0.95,
|
||||
driver_usage: float = 2 * 1024**3,
|
||||
dtype: torch.dtype = torch.float,
|
||||
verbose: bool = False,
|
||||
overlap: bool = False,
|
||||
inp=None,
|
||||
step_fn=None) -> None:
|
||||
# as for optimal search, we must profile the model first
|
||||
super().__init__(module, dtype, True, verbose, inp, step_fn)
|
||||
# profile cuda memory usage
|
||||
memo_usage = cuda_memory_profiling(model=self.meta_module, inp=inp, step_fn=step_fn, dtype=dtype)
|
||||
torch.cuda.empty_cache()
|
||||
buffer_occ = memo_usage['buffer_occ']
|
||||
# get the maximum memory usage of activation
|
||||
predict_activation = memo_usage['activation_occ']
|
||||
# calculate the total capacity of the current device
|
||||
gpu_memory = get_allowed_memory()
|
||||
# allowed capacity = allocation_fragment_factor * (total capacity - driver usage - buffer size - activation_fragment_factor * activation)
|
||||
self.cuda_capacity = int(
|
||||
allocation_fragment_factor *
|
||||
(gpu_memory - driver_usage - buffer_occ - activation_fragment_factor * predict_activation))
|
||||
hook_buffer_store_size = calc_buffer_size(m=self.meta_module, test_dtype=self.unified_dtype)
|
||||
self.cuda_elements = self.cuda_capacity // dtype_to_es.get(dtype) - hook_buffer_store_size
|
||||
|
||||
if self.cuda_elements < 0:
|
||||
raise RuntimeError('optimal search: activation is too large, please reduce batch size')
|
||||
|
||||
if self.verbose:
|
||||
print_rank_0('Predict memory usage:')
|
||||
for k, v in memo_usage.items():
|
||||
print_rank_0(f'{k}: {_format_memory(v)}')
|
||||
print_rank_0(f'allowed allocation space: {_format_memory(self.cuda_capacity)}')
|
||||
print_rank_0(f'hook buffer store size: {hook_buffer_store_size}')
|
||||
print_rank_0(f'allowed {dtype} elements: {self.cuda_elements}')
|
||||
|
||||
self.default_group_size = default_group_size
|
||||
self.comm_overlap = overlap
|
||||
|
||||
def private_truncate(self, param: nn.Parameter) -> int:
|
||||
return to_divide(param.numel(), self.default_group_size)
|
||||
|
||||
def public_trucate(self, length: int) -> int:
|
||||
return to_divide(length, self.default_group_size)
|
||||
|
||||
def search(self) -> Tuple:
|
||||
min_chunk_size, max_chunk_size, search_interval = find_search_range(self.meta_module)
|
||||
# get multi-used parameters
|
||||
private_params = get_multi_used_params(self.meta_module)
|
||||
# subtract the footprint of fused parameters
|
||||
for param in private_params:
|
||||
self.cuda_elements -= param.numel()
|
||||
if self.cuda_elements < 0:
|
||||
raise RuntimeError('optimal search: not enough space for fused parameters')
|
||||
|
||||
# initialize public params in the called order
|
||||
public_params = list()
|
||||
public_param_set = set()
|
||||
name_to_param = {name: param for name, param in self.meta_module.named_parameters()}
|
||||
for name_set in self.param_per_step:
|
||||
for name in name_set:
|
||||
param = name_to_param.get(name)
|
||||
if param in private_params or param in public_param_set:
|
||||
continue
|
||||
public_params.append(param)
|
||||
public_param_set.add(param)
|
||||
del name_to_param
|
||||
del public_param_set
|
||||
|
||||
# collect the number of elements of each parameter
|
||||
public_numels = [p.numel() for p in public_params]
|
||||
# calculate the sumary of all parameters
|
||||
total_size = sum(public_numels)
|
||||
# collect the name of each public parameter
|
||||
public_param_names = [self.param_to_name[p] for p in public_params]
|
||||
|
||||
if total_size <= min_chunk_size:
|
||||
public_block_size = total_size
|
||||
n_blocks = 1
|
||||
waste_size = 0
|
||||
else:
|
||||
public_block_size, n_blocks, waste_size = find_optimal_chunk_size(
|
||||
# pre-commit: do not rearrange
|
||||
param_per_step=self.param_per_step,
|
||||
param_names=public_param_names,
|
||||
param_numels=public_numels,
|
||||
cuda_elements=self.cuda_elements,
|
||||
overlap=self.comm_overlap,
|
||||
min_range=min_chunk_size,
|
||||
max_range=max_chunk_size,
|
||||
interval=search_interval)
|
||||
# truncate the size of public blocks
|
||||
public_block_size = self.public_trucate(public_block_size)
|
||||
if self.cuda_elements < n_blocks * public_block_size:
|
||||
raise RuntimeError('not enough space for unfused parameters')
|
||||
|
||||
if self.verbose:
|
||||
if total_size == 0:
|
||||
waste_percentage = 0
|
||||
else:
|
||||
waste_percentage = 100 * waste_size / total_size
|
||||
print_rank_0(
|
||||
f'Optimal search result: chunk size = {public_block_size}, waste percentage = {waste_percentage: .1f} %'
|
||||
)
|
||||
|
||||
# initialize the mapping from parameters to chunks
|
||||
param_to_chunk_id = dict()
|
||||
chunk_id = 0
|
||||
# deal with private parameters
|
||||
for p in private_params:
|
||||
param_to_chunk_id[p] = chunk_id
|
||||
chunk_id += 1
|
||||
# record the upper bound
|
||||
private_id_upperbound = chunk_id
|
||||
# deal with public parameters
|
||||
last_left = 0
|
||||
for p in public_params:
|
||||
p_size = p.numel()
|
||||
|
||||
if last_left < p_size:
|
||||
last_left = public_block_size
|
||||
chunk_id += 1
|
||||
|
||||
assert last_left >= p_size
|
||||
|
||||
last_left -= p_size
|
||||
param_to_chunk_id[p] = chunk_id
|
||||
|
||||
# initialize public groups
|
||||
public_number_chunks = chunk_id - private_id_upperbound
|
||||
public_groups = [[] for _ in range(public_number_chunks)]
|
||||
for p in public_params:
|
||||
public_chunk_id = param_to_chunk_id[p] - private_id_upperbound - 1
|
||||
public_groups[public_chunk_id].append(p)
|
||||
|
||||
if total_size == 0:
|
||||
n_blocks = 0
|
||||
|
||||
self.public_block_size = public_block_size
|
||||
self.public_block_number = n_blocks
|
||||
|
||||
return (private_params, public_groups)
|
||||
|
||||
def configure_rcache_size(self, chunk_plans: list, os_factor: int):
|
||||
element_os = 4
|
||||
if self.unified_dtype == torch.float16:
|
||||
element_pa = 2
|
||||
elif self.unified_dtype == torch.float:
|
||||
element_pa = 4
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
priority = rcache_prioirity_check(n=self.default_group_size, r_os=os_factor, e_p=element_pa, e_o=element_os)
|
||||
if self.verbose:
|
||||
print_rank_0(f'rCache Priority Check: {priority}')
|
||||
|
||||
if not priority:
|
||||
n_cache_blocks = max(4, math.ceil(self.max_checkpoint_size / self.public_block_size) + 1)
|
||||
if self.comm_overlap:
|
||||
n_cache_blocks += 2
|
||||
n_cache_blocks = min(n_cache_blocks, self.public_block_number)
|
||||
self.cuda_elements -= n_cache_blocks * self.public_block_size
|
||||
if self.verbose:
|
||||
print_rank_0(f'n_cache_block is set to {n_cache_blocks}')
|
||||
else:
|
||||
self.cuda_elements -= self.public_block_number * self.public_block_size
|
||||
|
||||
def try_move_chunk_to_cuda(fused: bool):
|
||||
for (i, plan) in enumerate(chunk_plans):
|
||||
rcache_fused = plan.kwargs.get('rcache_fused', False)
|
||||
if not fused and rcache_fused:
|
||||
continue
|
||||
elif fused and not rcache_fused:
|
||||
break
|
||||
param_os_size = os_factor * plan.chunk_size // self.default_group_size
|
||||
if self.cuda_elements >= param_os_size:
|
||||
plan.kwargs['shard_device'] = gpu_device()
|
||||
self.cuda_elements -= param_os_size
|
||||
else:
|
||||
plan.kwargs['shard_device'] = torch.device('cpu')
|
||||
plan.kwargs['cpu_pin_memory'] = True
|
||||
if self.verbose:
|
||||
print_rank_0(f"chunk {i}: shard device -> {plan.kwargs['shard_device']}")
|
||||
|
||||
# check chunks that are not fused on rCache
|
||||
try_move_chunk_to_cuda(False)
|
||||
# check chunks that are fused on rCache
|
||||
try_move_chunk_to_cuda(True)
|
||||
|
||||
if not priority:
|
||||
extra_blocks = math.floor(self.cuda_elements / self.public_block_size)
|
||||
extra_blocks = min(extra_blocks, self.public_block_number - n_cache_blocks)
|
||||
self.cuda_elements -= extra_blocks * self.public_block_size
|
||||
self.public_block_number = n_cache_blocks + extra_blocks
|
||||
if self.verbose:
|
||||
print_rank_0(f'n_extra_blocks is set to {extra_blocks}')
|
||||
|
||||
return chunk_plans
|
||||
|
||||
|
||||
def optimal_search(
|
||||
# pre-commit: do not rearrange
|
||||
m: nn.Module,
|
||||
group_size: int,
|
||||
unified_dtype: torch.dtype = torch.float,
|
||||
optimizer_type: str = 'Adam',
|
||||
overlap: bool = False,
|
||||
verbose: bool = False,
|
||||
inp=None,
|
||||
step_fn=None) -> SearchResult:
|
||||
|
||||
search_class = SearchOptimal(
|
||||
# pre-commit: do not rearrange
|
||||
module=m,
|
||||
default_group_size=group_size,
|
||||
dtype=unified_dtype,
|
||||
verbose=verbose,
|
||||
overlap=overlap,
|
||||
inp=inp,
|
||||
step_fn=step_fn)
|
||||
|
||||
private_group, public_groups = search_class.search()
|
||||
chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
|
||||
|
||||
if unified_dtype == torch.float16:
|
||||
master_weight_factor = 2
|
||||
elif unified_dtype == torch.float:
|
||||
master_weight_factor = 1
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if optimizer_type == 'SGD':
|
||||
extra_store_factor = 1
|
||||
elif optimizer_type == 'Adam':
|
||||
extra_store_factor = 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
os_factor = 1 + (1 + extra_store_factor) * master_weight_factor
|
||||
chunk_plans = search_class.configure_rcache_size(chunk_plans, os_factor)
|
||||
chunk_group = search_class.allocate_chunk_group(chunk_plans)
|
||||
|
||||
return SearchResult(chunk_group=chunk_group,
|
||||
chunk_plans=chunk_plans,
|
||||
param_called_per_step=search_class.param_per_step)
|
|
@ -0,0 +1,32 @@
|
|||
from typing import Dict, List, NamedTuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.chunk import ChunkGroup
|
||||
|
||||
|
||||
class ChunkPlan(NamedTuple):
|
||||
"""ChunkPlan is a type of configuration used to instruct the initialization of a chunk.
|
||||
|
||||
args:
|
||||
name_list: contains the names of parameters that should be pushed into this chunk
|
||||
chunk_size: the size of this chunk
|
||||
chunk_dtype: the dtype of this chunk
|
||||
kwargs: a dictionary used in __init__ function of Chunk
|
||||
"""
|
||||
name_list: List[str]
|
||||
chunk_size: int
|
||||
chunk_dtype: torch.dtype
|
||||
kwargs: Dict
|
||||
|
||||
|
||||
class SearchResult(object):
|
||||
|
||||
def __init__(self,
|
||||
chunk_group: ChunkGroup,
|
||||
chunk_plans: List[ChunkPlan],
|
||||
param_called_per_step: List[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.chunk_group = chunk_group
|
||||
self.param_chunk_plans = chunk_plans
|
||||
self.param_called_per_step = param_called_per_step
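# Construction sketch (illustrative, not part of the original file): a plan for one
# multi-used parameter that gets its own fused chunk; the name and sizes are made up.
example_plan = ChunkPlan(name_list=['embed.weight'],
chunk_size=1024,
chunk_dtype=torch.float16,
kwargs=dict(rcache_fused=True))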
|
|
@ -0,0 +1,120 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .base import SearchBase
|
||||
from .result import SearchResult
|
||||
from .utils import get_multi_used_params, to_divide
|
||||
|
||||
|
||||
class SearchSimple(SearchBase):
|
||||
"""The simple search algorithm used for unit tests.
|
||||
Developers can specify the number of chunks used.
|
||||
|
||||
args:
|
||||
module: the module to be searched
|
||||
default_group_size: the default group size of communications
|
||||
dtype: the data type of the parameters
|
||||
prefetch: whether to prefetch the chunks
|
||||
verbose: whether to print the search process
|
||||
inp: the example input of the model
|
||||
step_fn: the example step function of training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
default_group_size: int,
|
||||
dtype: torch.dtype = torch.float,
|
||||
prefetch: bool = False,
|
||||
verbose: bool = False,
|
||||
inp=None,
|
||||
step_fn=None) -> None:
|
||||
|
||||
super().__init__(module, dtype, prefetch, verbose, inp, step_fn)
|
||||
self.default_group_size = default_group_size
|
||||
|
||||
def private_truncate(self, param: nn.Parameter) -> int:
|
||||
return to_divide(param.numel(), self.default_group_size)
|
||||
|
||||
def public_trucate(self, length: int) -> int:
|
||||
return to_divide(length, self.default_group_size)
|
||||
|
||||
def search(self, split_number: int, allocate_factor: float) -> Tuple:
|
||||
# get multi-used parameters
|
||||
private_params = get_multi_used_params(self.meta_module)
|
||||
# get parameters used only one time
|
||||
public_params = [p for p in self.meta_module.parameters() if p not in private_params]
|
||||
|
||||
# calculate the size of each group
|
||||
len_public = len(public_params)
|
||||
split_number = min(len_public, split_number)
|
||||
# allocate a list for groups
|
||||
public_groups = list()
|
||||
if split_number > 0:
|
||||
average_size = len_public // split_number
|
||||
left_size = len_public % split_number
|
||||
|
||||
# set the size of each segment
|
||||
pack_size_list = [average_size] * split_number
|
||||
for i in range(split_number):
|
||||
if left_size > 0:
|
||||
pack_size_list[i] += 1
|
||||
left_size -= 1
|
||||
|
||||
# split public parameters
|
||||
for i in range(split_number):
|
||||
p_list = list()
|
||||
for _ in range(pack_size_list[i]):
|
||||
p = public_params.pop(0)
|
||||
p_list.append(p)
|
||||
public_groups.append(p_list)
|
||||
assert len(public_params) == 0
|
||||
|
||||
# calculate the maximum total size among the groups
|
||||
max_sum_size = 0
|
||||
for p_list in public_groups:
|
||||
sum_size = sum([p.numel() for p in p_list])
|
||||
max_sum_size = max(max_sum_size, sum_size)
|
||||
else:
|
||||
max_sum_size = 0
|
||||
|
||||
self.public_block_size = max_sum_size
|
||||
self.public_block_number = math.ceil(split_number * allocate_factor)
|
||||
|
||||
return (private_params, public_groups)
|
||||
|
||||
|
||||
def simple_search(m: nn.Module,
|
||||
group_size: int,
|
||||
split_number: int = 10,
|
||||
allocate_factor: float = 0.6,
|
||||
unified_dtype: torch.dtype = torch.float,
|
||||
shard_device: torch.device = torch.device('cpu'),
|
||||
prefetch: bool = False,
|
||||
verbose: bool = False,
|
||||
inp=None,
|
||||
step_fn=None) -> SearchResult:
|
||||
|
||||
search_class = SearchSimple(
|
||||
# pre-commit: do not rearrange
|
||||
module=m,
|
||||
default_group_size=group_size,
|
||||
dtype=unified_dtype,
|
||||
prefetch=prefetch,
|
||||
verbose=verbose,
|
||||
inp=inp,
|
||||
step_fn=step_fn)
|
||||
|
||||
private_group, public_groups = search_class.search(split_number, allocate_factor)
|
||||
chunk_plans = search_class.generate_chunk_plans(private_group, public_groups)
|
||||
# assign shard device
|
||||
for plan in chunk_plans:
|
||||
plan.kwargs['shard_device'] = shard_device
|
||||
|
||||
chunk_group = search_class.allocate_chunk_group(chunk_plans)
|
||||
|
||||
return SearchResult(chunk_group=chunk_group,
|
||||
chunk_plans=chunk_plans,
|
||||
param_called_per_step=search_class.param_per_step)
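# Usage sketch (illustrative, not part of the original file): run the simple search on
# a toy model; the chunk allocation below assumes a CUDA device is available.
toy_model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
toy_result = simple_search(toy_model, group_size=1, split_number=2, verbose=True)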
|
|
@ -0,0 +1,112 @@
|
|||
import math
|
||||
|
||||
from .utils import to_divide
|
||||
|
||||
|
||||
def calc_move_times(param_per_step: list, param_to_chunk: dict, n_blocks: int):
|
||||
from colossalai.elixir.simulator import move_count
|
||||
chunk_per_step = list()
|
||||
|
||||
for param_set in param_per_step:
|
||||
id_set = set()
|
||||
for name in param_set:
|
||||
# continue if the parameter is ignored
|
||||
if name not in param_to_chunk:
|
||||
continue
|
||||
id_set.add(param_to_chunk[name])
|
||||
if len(id_set) > 0:
|
||||
chunk_per_step.append(list(id_set))
|
||||
|
||||
return move_count(chunk_per_step, n_blocks)
|
||||
|
||||
|
||||
def find_optimal_chunk_size(
|
||||
# pre-commit: do not rearrange
|
||||
param_per_step: list,
|
||||
param_names: list,
|
||||
param_numels: list,
|
||||
cuda_elements: int,
|
||||
overlap: bool,
|
||||
min_range: int,
|
||||
max_range: int,
|
||||
interval: int):
|
||||
|
||||
max_numel = 0
|
||||
for numel in param_numels:
|
||||
max_numel = max(max_numel, numel)
|
||||
test_size = to_divide(max(max_numel, min_range), interval)
|
||||
# floor rounding
|
||||
cuda_elements = to_divide(cuda_elements - interval + 1, interval)
|
||||
max_range = min(max_range, cuda_elements)
|
||||
|
||||
min_move_elements = float('+inf')
|
||||
best_size = test_size
|
||||
best_number_blocks = 0
|
||||
best_waste = 0
|
||||
|
||||
def dispatch_chunks(param_to_chunk: dict, block_size: int) -> tuple:
|
||||
chunk_id = 0
|
||||
acc = 0
|
||||
left = 0
|
||||
for (name, numel) in zip(param_names, param_numels):
|
||||
if numel > left:
|
||||
acc += left
|
||||
chunk_id += 1
|
||||
left = block_size
|
||||
left -= numel
|
||||
param_to_chunk[name] = chunk_id
|
||||
return (chunk_id, left + acc)
|
||||
|
||||
assert test_size <= max_range, 'max_numel or min_range is larger than max_range or cuda capacity'
|
||||
while test_size <= max_range:
|
||||
# calculate the number of blocks
|
||||
number_blocks = int(cuda_elements // test_size)
|
||||
# if communication overlap is enabled, pretend that two blocks are reserved for prefetching
|
||||
if overlap:
|
||||
number_blocks -= 2
|
||||
if number_blocks <= 0:
|
||||
break    # a larger chunk size cannot fit either, so stop searching (avoids an endless loop)
|
||||
# initialize the chunk id for each parameter
|
||||
param_to_chunk = dict()
|
||||
number_chunks, current_waste = dispatch_chunks(param_to_chunk, test_size)
|
||||
number_blocks = min(number_blocks, number_chunks)
|
||||
# calculate the minimum number of movements
|
||||
move_times = calc_move_times(param_per_step, param_to_chunk, number_blocks)
|
||||
|
||||
current_move_elements = move_times * test_size
|
||||
# print("test", test_size, current_move_elements)
|
||||
if current_move_elements < min_move_elements:
|
||||
min_move_elements = current_move_elements
|
||||
best_size = test_size
|
||||
best_number_blocks = number_blocks
|
||||
best_waste = current_waste
|
||||
|
||||
test_size += interval
|
||||
|
||||
if min_move_elements == float('inf'):
|
||||
raise RuntimeError('optimal search: can not find a valid solution')
|
||||
|
||||
return best_size, best_number_blocks, best_waste
|
||||
|
||||
|
||||
def bandwidth_c2g(n: int):
|
||||
return 16.3 * n + 8.7
|
||||
|
||||
|
||||
def bandwidth_g2c(n: int):
|
||||
return 15.8 * n + 2.3
|
||||
|
||||
|
||||
def velocity_gpu(n: int):
|
||||
return 50 * n
|
||||
|
||||
|
||||
def velocity_cpu(n: int):
|
||||
return 1.66 * math.log(n) + 5.15
|
||||
|
||||
|
||||
def rcache_prioirity_check(n: int, r_os: int, e_p: int, e_o: int):
|
||||
In = e_p / bandwidth_c2g(n) + e_p / bandwidth_g2c(n)
|
||||
Jn = (n / r_os) * (e_o / bandwidth_c2g(n) + In + e_p / bandwidth_g2c(n) + 1.0 / velocity_cpu(n) -
|
||||
1.0 / velocity_gpu(n))
|
||||
return In > Jn
|
|
@ -0,0 +1,108 @@
|
|||
from typing import List, Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def to_divide(a: int, b: int):
|
||||
return a + (-a % b)
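# Example (illustrative, not part of the original file): to_divide rounds `a` up to the
# next multiple of `b`, e.g. to_divide(10, 4) == 12 and to_divide(12, 4) == 12.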
|
||||
|
||||
|
||||
def to_meta_tensor(t: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
|
||||
# only float tensors need dtype change
|
||||
if t.is_floating_point() and dtype is not None:
|
||||
meta_dtype = dtype
|
||||
else:
|
||||
meta_dtype = t.dtype
|
||||
# we shall not use t.data.to here, since t might be a fake tensor
|
||||
meta_t = torch.empty(t.size(), dtype=meta_dtype, device='meta')
|
||||
# pack it if t is a parameter
|
||||
# we should filter parameters with no grad
|
||||
if isinstance(t, nn.Parameter) and t.requires_grad:
|
||||
meta_t = nn.Parameter(meta_t)
|
||||
return meta_t
|
||||
|
||||
|
||||
def get_multi_used_params(m: nn.Module) -> Set[torch.Tensor]:
|
||||
multi_used_set = set()
|
||||
visit = dict()
|
||||
for module in m.modules():
|
||||
for param in module.parameters(recurse=False):
|
||||
if param not in visit:
|
||||
visit[param] = True
|
||||
else:
|
||||
multi_used_set.add(param)
|
||||
return multi_used_set
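# Example (illustrative, not part of the original file): weight tying is the typical
# multi-used case, e.g. a language model whose input embedding and output projection
# share one Parameter; that shared tensor is returned here and later placed in its own
# fused (private) chunk.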
|
||||
|
||||
|
||||
def find_minimum_waste_size(numel_group_list: List[List[int]], min_range: int, max_range: int, interval: int):
|
||||
|
||||
max_per_group = list()
|
||||
for n_list in numel_group_list:
|
||||
max_per_group.append(max(n_list))
|
||||
max_numel = max(max_per_group)
|
||||
|
||||
test_size = to_divide(max(max_numel, min_range), interval)
|
||||
best_size = test_size
|
||||
min_waste = float('+inf')
|
||||
|
||||
def calc_waste(numel_list: List[int], block_size: int):
|
||||
acc = 0
|
||||
left = 0
|
||||
for s in numel_list:
|
||||
if s > left:
|
||||
acc += left
|
||||
left = block_size
|
||||
left -= s
|
||||
return left + acc
|
||||
|
||||
assert test_size <= max_range, 'max_numel or min_range is larger than max_range'
|
||||
while test_size <= max_range:
|
||||
current_waste = 0
|
||||
for n_list in numel_group_list:
|
||||
current_waste += calc_waste(n_list, test_size)
|
||||
if current_waste < min_waste:
|
||||
best_size = test_size
|
||||
min_waste = current_waste
|
||||
test_size += interval
|
||||
|
||||
return best_size, min_waste
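# Call sketch (illustrative, not part of the original file): search a block size for two
# parameters of 6 and 10 elements over candidate sizes between 10 and 20 in steps of 2;
# the numbers are made up purely for demonstration.
best_block, wasted = find_minimum_waste_size(numel_group_list=[[6, 10]],
min_range=10,
max_range=20,
interval=2)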
|
||||
|
||||
|
||||
def find_search_range(m: nn.Module):
|
||||
|
||||
ele_size = 0
|
||||
for param in m.parameters():
|
||||
if ele_size == 0:
|
||||
ele_size = param.element_size()
|
||||
else:
|
||||
assert param.element_size() == ele_size
|
||||
|
||||
def next_2_pow(x: int):
|
||||
y = 1
|
||||
while y < x:
|
||||
y <<= 1
|
||||
return y
|
||||
|
||||
private_params = get_multi_used_params(m)
|
||||
params = [p for p in m.parameters() if p not in private_params]
|
||||
memo_list = [p.numel() * p.element_size() for p in params]
|
||||
max_memo = max(memo_list)
|
||||
# minimum chunk memory is 32 MiB
|
||||
default_min = 32 * 1024**2
|
||||
while default_min < max_memo:
|
||||
default_min <<= 1
|
||||
default_max = int(3 * default_min)
|
||||
# * 2 for forward and backward
|
||||
length = 2 * next_2_pow(len(params))
|
||||
default_iter_times = 16 * 1024**2
|
||||
default_search_times = default_iter_times // length
|
||||
|
||||
gap = default_max - default_min
|
||||
# minimum search interval is 1024
|
||||
if default_search_times > (gap // 1024):
|
||||
interval = 1024
|
||||
else:
|
||||
interval = gap // default_search_times
|
||||
|
||||
return (default_min // ele_size, default_max // ele_size, interval // ele_size)
|
|
@ -0,0 +1,61 @@
|
|||
#include <Python.h>
|
||||
#include <bits/stdc++.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
int move_count_impl(std::vector<std::vector<int>> &steps, int n_blocks) {
|
||||
int n_steps = steps.size();
|
||||
std::unordered_map<int, int> my_map;
|
||||
std::map<std::pair<int, int>, int> next_map;
|
||||
|
||||
for (auto i = n_steps - 1; ~i; --i) {
|
||||
auto ids = steps.at(i);
|
||||
for (auto c_id : ids) {
|
||||
auto iter = my_map.find(c_id);
|
||||
auto nxt = n_steps;
|
||||
if (iter != my_map.end()) nxt = iter->second;
|
||||
next_map.emplace(std::make_pair(i, c_id), nxt);
|
||||
my_map[c_id] = i;
|
||||
}
|
||||
}
|
||||
// reuse this map
|
||||
for (auto iter : my_map) my_map[iter.first] = 0;
|
||||
|
||||
int cache_size = 0, count = 0;
|
||||
std::priority_queue<std::pair<int, int>> cache;
|
||||
for (auto i = 0; i < n_steps; ++i) {
|
||||
auto ids = steps.at(i);
|
||||
assert(n_blocks >= ids.size());
|
||||
|
||||
int not_in = 0;
|
||||
for (auto c_id : ids)
|
||||
if (my_map[c_id] == 0) ++not_in;
|
||||
|
||||
while (cache_size + not_in > n_blocks) {
|
||||
std::pair<int, int> q_top = cache.top();
|
||||
cache.pop();
|
||||
assert(q_top.first > i);
|
||||
assert(my_map[q_top.second] == 1);
|
||||
my_map[q_top.second] = 0;
|
||||
--cache_size;
|
||||
++count;
|
||||
}
|
||||
|
||||
for (auto c_id : ids) {
|
||||
auto iter = next_map.find(std::make_pair(i, c_id));
|
||||
cache.push(std::make_pair(iter->second, c_id));
|
||||
if (my_map[c_id] == 0) {
|
||||
my_map[c_id] = 1;
|
||||
++cache_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
return (count + cache_size) << 1;
|
||||
}
|
||||
|
||||
int move_count(std::vector<std::vector<int>> &steps, int n_blocks) {
|
||||
return move_count_impl(steps, n_blocks);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("move_count", &move_count, "Count the number of moves.");
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
debug_flag = False
|
||||
|
||||
white_list = {torch.Tensor.__getitem__}
|
||||
|
||||
fake_allowed = {
|
||||
# pre-commit: don't move
|
||||
torch.Tensor.numel,
|
||||
torch.Tensor.size,
|
||||
torch.Tensor.stride,
|
||||
torch.Tensor.storage_offset,
|
||||
torch.Tensor.is_floating_point
|
||||
}
|
||||
|
||||
inpalce_mapping = {
|
||||
torch.Tensor.add_: torch.Tensor.add,
|
||||
torch.Tensor.sub_: torch.Tensor.sub,
|
||||
torch.Tensor.mul_: torch.Tensor.mul,
|
||||
torch.Tensor.div_: torch.Tensor.div
|
||||
}
|
||||
|
||||
|
||||
def is_no_hook_op(func) -> bool:
|
||||
if func.__name__.startswith('__') and func not in white_list:
|
||||
return True
|
||||
if func in fake_allowed:
|
||||
return True
|
||||
return False
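# Example (illustrative, not part of the original file): dunder calls such as
# torch.Tensor.__repr__ and the metadata queries in `fake_allowed` bypass the parameter
# hook, while torch.Tensor.__getitem__ is white-listed so indexing still triggers
# chunk fetching.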
|
||||
|
||||
|
||||
class FakeTensor(torch.Tensor):
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
r = torch.Tensor._make_wrapper_subclass(cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
requires_grad=elem.requires_grad)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def to_outplace_tensor(t):
|
||||
if isinstance(t, OutplaceTensor):
|
||||
return t
|
||||
assert type(t) is torch.Tensor, f'type: {type(t)}'
|
||||
t.__class__ = OutplaceTensor
|
||||
return t
|
||||
|
||||
|
||||
class OutplaceTensor(torch.Tensor):
|
||||
# TODO: rename this class
|
||||
def __new__(cls, tensor):
|
||||
rt = tensor.as_subclass(cls)
|
||||
return rt
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# in order to trigger pre-op hook in the forward of checkpoint module
|
||||
# we have to capture the `backward` function
|
||||
# and make sure that it is not called inside the `torch._C.DisableTorchFunction()` context
|
||||
if func is torch.Tensor.backward:
|
||||
assert len(args) == 1    # only has 1 parameter
|
||||
backward_tensor = torch.Tensor(args[0])
|
||||
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
|
||||
return backward_tensor.backward(**tensor_kwargs)
|
||||
# return a tensor if the output needs to be a torch.Tensor (such as Tensor.data.__get__)
|
||||
if is_no_hook_op(func):
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
# debug inplace operations
|
||||
if debug_flag:
|
||||
if func.__name__.endswith('_'):
|
||||
print(f'found inplace operation {func.__name__}')
|
||||
|
||||
# replace the in-place function
|
||||
if func in inpalce_mapping:
|
||||
func = inpalce_mapping[func]
|
||||
# set the 'inplace' kwargs to False
|
||||
if 'inplace' in kwargs:
|
||||
kwargs['inplace'] = False
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
if not isinstance(ret, tuple):
|
||||
ret = (ret,)
|
||||
|
||||
def convert(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
t = to_outplace_tensor(t)
|
||||
return t
|
||||
|
||||
ret = tree_map(convert, ret)
|
||||
|
||||
if len(ret) == 1:
|
||||
ret = ret[0]
|
||||
|
||||
return ret
|
|
@ -0,0 +1,2 @@
|
|||
from .cuda_profiler import cuda_memory_profiling
|
||||
from .memory_tensor import MTensor
|
|
@ -0,0 +1,77 @@
|
|||
import gc
|
||||
from typing import Callable, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.tracer.utils import get_cuda_allocated, meta_copy, model_memory_figure
|
||||
from colossalai.elixir.utils import print_rank_0
|
||||
|
||||
from .memory_tensor import MTensor
|
||||
|
||||
|
||||
def grad_cleaner(grad):
|
||||
empty_grad = torch.empty_like(grad.elem)
|
||||
grad.elem = None
|
||||
empty_grad.storage().resize_(0)
|
||||
return empty_grad
|
||||
|
||||
|
||||
def cuda_memory_profiling(model: nn.Module, inp: Dict, step_fn: Callable, dtype=torch.float):
|
||||
assert isinstance(inp, dict), 'the example input should be a dictionary'
|
||||
print_rank_0(f'You are profiling cuda memory with dtype `{dtype}`')
|
||||
|
||||
def tensor_trans(t: torch.Tensor):
|
||||
# set dtype for tensors
|
||||
meta_dtype = dtype if t.is_floating_point() else t.dtype
|
||||
meta_t = torch.empty_like(t.data, device='meta', dtype=meta_dtype)
|
||||
# pack parameters
|
||||
if isinstance(t, nn.Parameter):
|
||||
meta_t = nn.Parameter(meta_t)
|
||||
return meta_t
|
||||
|
||||
# first, transform the model into one dtype
|
||||
model = meta_copy(model, tensor_trans)
|
||||
# get the memory figure of the model
|
||||
memo_dict = model_memory_figure(model)
|
||||
# initialize an empty pool for parameters
|
||||
pool = torch.zeros(memo_dict['param_max_numel'], device='cuda', dtype=dtype)
|
||||
|
||||
def tensor_to_cuda(t):
|
||||
if isinstance(t, nn.Parameter):
|
||||
fake_data = pool[:t.numel()].view(t.shape)
|
||||
return nn.Parameter(fake_data)
|
||||
else:
|
||||
fake_data = torch.zeros(t.shape, device='cuda', dtype=t.dtype)
|
||||
return fake_data
|
||||
|
||||
# move all parameters to CUDA and point them to the same underlying storage
|
||||
model = meta_copy(model, tensor_to_cuda)
|
||||
# add hooks to clean gradients
|
||||
for param in model.parameters():
|
||||
param.register_hook(grad_cleaner)
|
||||
|
||||
def input_trans(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
cuda_dtype = dtype if t.is_floating_point() else t.dtype
|
||||
cuda_t = t.data.clone()
|
||||
cuda_t = cuda_t.to(dtype=cuda_dtype, device='cuda')
|
||||
cuda_t.requires_grad = t.requires_grad
|
||||
return MTensor(cuda_t)
|
||||
return t
|
||||
|
||||
inp = tree_map(input_trans, inp)
|
||||
# reset all collected peak memory states
|
||||
MTensor.reset_peak_memory()
|
||||
before_cuda_alc = get_cuda_allocated()
|
||||
|
||||
step_fn(model, inp)
|
||||
|
||||
after_cuda_alc = MTensor.current_peak_memory()
|
||||
activation_occ = after_cuda_alc - before_cuda_alc
|
||||
|
||||
return dict(param_occ=memo_dict['param_occ'],
|
||||
buffer_occ=memo_dict['buffer_occ'],
|
||||
grad_occ=memo_dict['param_occ'],
|
||||
activation_occ=activation_occ)
|
|
@ -0,0 +1,99 @@
|
|||
import contextlib
|
||||
from typing import Iterator
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.tracer.utils import get_cuda_max_allocated
|
||||
|
||||
from .op_cache import wrapped_mm_ops
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
mm_ops_list = [aten.mm.default, aten.addmm.default, aten.bmm.default, aten.addbmm.default, aten.baddbmm.default]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_dispatch() -> Iterator[None]:
|
||||
guard = torch._C._DisableTorchDispatch()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
|
||||
|
||||
def normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
class MTensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
||||
__slots__ = ['elem']
|
||||
|
||||
peak_memory_allocated: int = 0
|
||||
|
||||
@staticmethod
|
||||
def reset_peak_memory():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
MTensor.peak_memory_allocated = 0
|
||||
|
||||
@staticmethod
|
||||
def update_peak_memory(new_peak):
|
||||
MTensor.peak_memory_allocated = max(MTensor.peak_memory_allocated, new_peak)
|
||||
|
||||
@staticmethod
|
||||
def current_peak_memory():
|
||||
cur_peak = get_cuda_max_allocated()
|
||||
return max(MTensor.peak_memory_allocated, cur_peak)
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
# TODO: clone strides and storage aliasing
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
requires_grad=elem.requires_grad)
|
||||
r.elem = elem
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return f'MTensor({self.elem})'
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
|
||||
def print_tensor(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
print(x.shape)
|
||||
|
||||
# tree_map(print_tensor, args)
|
||||
# tree_map(print_tensor, kwargs)
|
||||
|
||||
def unwrap(x):
|
||||
return x.elem if isinstance(x, MTensor) else x
|
||||
|
||||
def wrap(x):
|
||||
return MTensor(x) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
if func in mm_ops_list:
|
||||
res, pre_max = wrapped_mm_ops(func, *tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
MTensor.update_peak_memory(pre_max)
|
||||
else:
|
||||
with no_dispatch():
|
||||
res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
|
||||
outs = normalize_tuple(res)
|
||||
res = tree_map(wrap, outs)
|
||||
|
||||
if len(res) == 1:
|
||||
return res[0]
|
||||
else:
|
||||
return res
|
|
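`MTensor` can also be used on its own to measure the peak CUDA memory of a few ops; again the import path is an assumption and a GPU is required.

```python
# Sketch: wrap CUDA tensors with MTensor to track peak memory across dispatched ops.
import torch

from colossalai.elixir.tracer.memory_tracer import MTensor    # assumed path

MTensor.reset_peak_memory()

a = MTensor(torch.randn(2048, 2048, device='cuda'))
b = MTensor(torch.randn(2048, 2048, device='cuda'))

c = torch.mm(a, b)    # matmuls go through wrapped_mm_ops and update the recorded peak
d = torch.relu(c)     # every other op falls into the no_dispatch branch

print(f'peak CUDA memory: {MTensor.current_peak_memory() / 1024**2:.1f} MiB')
```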
@ -0,0 +1,135 @@
|
|||
import contextlib
|
||||
from typing import Dict, Iterator, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
|
||||
|
||||
from .output_shape import addmm_output, bmm_output, check_cuda_mm, mm_output
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_dispatch() -> Iterator[None]:
|
||||
guard = torch._C._DisableTorchDispatch()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
|
||||
|
||||
def tensor_info(x: torch.Tensor):
|
||||
# returns the meta information used for CUDA kernels
|
||||
return (x.shape, x.stride(), x.layout, x.dtype)
|
||||
|
||||
|
||||
def get_args_info(*args):
|
||||
# returns a tuple containing the meta information of all inputs
|
||||
# every argument is expected to be a tensor
|
||||
info_list = []
|
||||
for x in args:
|
||||
if isinstance(x, torch.Tensor):
|
||||
info_list.append(tensor_info(x))
|
||||
return tuple(info_list)
|
||||
|
||||
|
||||
class OpCache(object):
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.temp_memory: Dict[Tuple, int] = dict()
|
||||
|
||||
def reset(self):
|
||||
self.temp_memory.clear()
|
||||
|
||||
def get(self, info):
|
||||
if info in self.temp_memory:
|
||||
return True, self.temp_memory[info]
|
||||
else:
|
||||
return False, None
|
||||
|
||||
def add(self, info, memo):
|
||||
self.temp_memory[info] = memo
|
||||
|
||||
def print(self):
|
||||
print(f'OpCache {self.name} information:')
|
||||
for k, v in self.temp_memory.items():
|
||||
print(f'key: {k}\ntemp_memo:{v}')
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
addmm_cache = OpCache('aten.addmm.default')
|
||||
bmm_cache = OpCache('aten.bmm.default')
|
||||
mm_cache = OpCache('aten.mm.default')
|
||||
|
||||
op_mapping = {
|
||||
aten.mm.default: {
|
||||
'cache': mm_cache,
|
||||
'output': mm_output
|
||||
},
|
||||
aten.addmm.default: {
|
||||
'cache': addmm_cache,
|
||||
'output': addmm_output
|
||||
},
|
||||
aten.bmm.default: {
|
||||
'cache': bmm_cache,
|
||||
'output': bmm_output
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def reset_caches():
|
||||
addmm_cache.reset()
|
||||
bmm_cache.reset()
|
||||
mm_cache.reset()
|
||||
|
||||
|
||||
def fake_cuda_output(temp_memo, output_shape, dtype):
|
||||
ret = torch.empty(output_shape, dtype=dtype, device='cuda')
|
||||
sub = temp_memo - ret.numel() * ret.element_size()
|
||||
|
||||
if sub > 0:
|
||||
# allocate a temp empty tensor block to simulate the computation in kernels
|
||||
temp = torch.empty(sub, dtype=torch.int8, device='cuda')
|
||||
# release this tensor block
|
||||
del temp
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def real_cuda_output(func, *args, **kwargs):
|
||||
cur_alc = get_cuda_allocated()
|
||||
# save the peak memory usage
|
||||
pre_max_alc = get_cuda_max_allocated()
|
||||
# the peak memory history is cleared here
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
with no_dispatch():
|
||||
ret = func(*args, **kwargs)
|
||||
|
||||
max_alc = get_cuda_max_allocated()
|
||||
# calculate the temporary memory allocation
|
||||
temp_memo = max_alc - cur_alc
|
||||
|
||||
return ret, temp_memo, pre_max_alc
|
||||
|
||||
|
||||
def wrapped_mm_ops(func, *args, **kwargs):
|
||||
check_cuda_mm(*args)
|
||||
|
||||
if func not in op_mapping:
|
||||
raise RuntimeError(f'Unsupported mm operation {func}')
|
||||
|
||||
args_info = get_args_info(*args)
|
||||
cache = op_mapping[func]['cache']
|
||||
cached_flag, temp_memo = cache.get(args_info)
|
||||
|
||||
if cached_flag:
|
||||
output_fn = op_mapping[func]['output']
|
||||
out_shape = output_fn(*args)
|
||||
ret = fake_cuda_output(temp_memo=temp_memo, output_shape=out_shape, dtype=args[0].dtype)
|
||||
return ret, 0
|
||||
else:
|
||||
ret, temp_memo, pre_max_alc = real_cuda_output(func, *args, **kwargs)
|
||||
cache.add(args_info, temp_memo)
|
||||
return ret, pre_max_alc
|
|
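The cache keys each matmul by the meta information of its arguments, so a repeated shape is simulated with `fake_cuda_output` instead of re-running the kernel. A hedged sketch of that behaviour (import path assumed, GPU required):

```python
import torch

from colossalai.elixir.tracer.memory_tracer.op_cache import wrapped_mm_ops    # assumed path

aten = torch.ops.aten
a = torch.randn(512, 512, device='cuda')
b = torch.randn(512, 512, device='cuda')

# the first call runs the real kernel, measures its temporary memory and fills the cache
out1, peak1 = wrapped_mm_ops(aten.mm.default, a, b)
# a second call with identical argument meta-info is simulated from the cache
out2, peak2 = wrapped_mm_ops(aten.mm.default, a, b)
assert peak2 == 0    # cached calls report a zero previous peak
```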
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
|
||||
|
||||
# Output functions come from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
|
||||
def check_cuda_mm(*args):
|
||||
for x in args:
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert x.device.type == 'cuda'
|
||||
|
||||
|
||||
def mm_output(a, b):
|
||||
assert a.dim() == 2, 'a must be 2D'
|
||||
assert b.dim() == 2, 'b must be 2D'
|
||||
N, M1 = a.shape
|
||||
M2, P = b.shape
|
||||
assert M1 == M2, 'a and b must have same reduction dim'
|
||||
return (N, P)
|
||||
|
||||
|
||||
def addmm_output(bias, x, y):
|
||||
return mm_output(x, y)
|
||||
|
||||
|
||||
def common_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
|
||||
assert batch1.dim() == 3, 'batch1 must be a 3D tensor'
|
||||
assert batch2.dim() == 3, 'batch2 must be a 3D tensor'
|
||||
|
||||
batch1_sizes = batch1.size()
|
||||
batch2_sizes = batch2.size()
|
||||
|
||||
bs = batch1_sizes[0]
|
||||
contraction_size = batch1_sizes[2]
|
||||
res_rows = batch1_sizes[1]
|
||||
res_cols = batch2_sizes[2]
|
||||
output_size = (bs, res_rows, res_cols)
|
||||
|
||||
assert batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size
|
||||
|
||||
if not is_bmm and self_baddbmm is not None:
|
||||
assert self_baddbmm.dim() == 3, 'self must be a 3D tensor'
|
||||
assert self_baddbmm.size() == output_size, \
|
||||
f'Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}'
|
||||
|
||||
return output_size
|
||||
|
||||
|
||||
def bmm_output(mat1, mat2):
|
||||
return common_baddbmm_bmm(mat1, mat2, True)
|
|
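A quick check of what the shape helpers return; CPU tensors suffice since no kernel is launched, and the import path is an assumption.

```python
import torch

from colossalai.elixir.tracer.memory_tracer.output_shape import bmm_output, mm_output    # assumed path

x = torch.randn(4, 8)
y = torch.randn(8, 16)
print(mm_output(x, y))      # (4, 16)

b1 = torch.randn(2, 4, 8)
b2 = torch.randn(2, 8, 16)
print(bmm_output(b1, b2))   # (2, 4, 16)
```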
@ -0,0 +1,76 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
__all__ = [
|
||||
'TorchFactoryMethod', 'TorchOverrideableFactoryMethod', 'TorchNonOverrideableFactoryMethod', 'TensorPropertyMethod',
|
||||
'DistCommMethod', 'AliasATen', 'InplaceATen', 'MaybeInplaceAten', 'SameStorageAten'
|
||||
]
|
||||
|
||||
TorchOverrideableFactoryMethod = [
|
||||
'empty',
|
||||
'eye',
|
||||
'full',
|
||||
'ones',
|
||||
'rand',
|
||||
'randn',
|
||||
'zeros',
|
||||
]
|
||||
|
||||
TorchNonOverrideableFactoryMethod = [
|
||||
'arange',
|
||||
'finfo',
|
||||
'linspace',
|
||||
'logspace',
|
||||
'randint',
|
||||
'randperm',
|
||||
'tensor',
|
||||
]
|
||||
|
||||
TorchFactoryMethod = TorchOverrideableFactoryMethod + TorchNonOverrideableFactoryMethod
|
||||
|
||||
TensorPropertyMethod = ['dtype', 'shape', 'device', 'requires_grad', 'grad', 'grad_fn', 'data']
|
||||
|
||||
DistCommMethod = [
|
||||
'all_gather',
|
||||
'all_reduce',
|
||||
'all_to_all',
|
||||
'broadcast',
|
||||
'gather',
|
||||
'reduce',
|
||||
'reduce_scatter',
|
||||
'scatter',
|
||||
]
|
||||
|
||||
AliasATen = [
|
||||
aten.detach.default,
|
||||
aten.detach_.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
InplaceATen = [
|
||||
aten.add_.Tensor,
|
||||
aten.add_.Scalar,
|
||||
aten.sub_.Tensor,
|
||||
aten.sub_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.mul_.Scalar,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.pow_.Tensor,
|
||||
aten.pow_.Scalar,
|
||||
]
|
||||
|
||||
MaybeInplaceAten = [
|
||||
aten.diagonal.default,
|
||||
aten.select.int,
|
||||
aten.slice.Tensor,
|
||||
aten.as_strided.default,
|
||||
]
|
||||
|
||||
SameStorageAten = AliasATen + InplaceATen + MaybeInplaceAten
|
|
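These lists are consumed by the tracers later in this commit; for instance, membership in `SameStorageAten` tells the dispatch tracer that an op's outputs keep the storage of its first argument. A small sketch:

```python
import torch

from colossalai.elixir.tracer.ops import SameStorageAten

aten = torch.ops.aten


def may_share_input_storage(func) -> bool:
    # ops in SameStorageAten are views of, or write into, their first argument
    return func in SameStorageAten


print(may_share_input_storage(aten.view.default))    # True
print(may_share_input_storage(aten.mm.default))      # False
```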
@ -0,0 +1,3 @@
|
|||
from .fx_order import generate_fx_order
|
||||
from .td_order import generate_td_order
|
||||
from .tf_order import generate_tf_order
|
|
@ -0,0 +1,62 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule, Node, symbolic_trace
|
||||
|
||||
from colossalai.elixir.tracer.utils import meta_copy
|
||||
|
||||
|
||||
def generate_fx_order(model: nn.Module) -> List[Dict[str, nn.Parameter]]:
|
||||
fxf_name_mark = '_fxf_name'
|
||||
fxf_param_mark = '_fxf_param'
|
||||
|
||||
def tensor_trans(t):
|
||||
meta_t = t.data.to('meta')
|
||||
if isinstance(t, nn.Parameter):
|
||||
meta_t = nn.Parameter(meta_t)
|
||||
return meta_t
|
||||
|
||||
meta_model = meta_copy(model, tensor_trans)
|
||||
|
||||
# attach names for parameters
|
||||
for name, param in meta_model.named_parameters():
|
||||
setattr(param, fxf_name_mark, name)
|
||||
|
||||
fx_forward_order: List[Dict[str, nn.Parameter]] = list()
|
||||
|
||||
gm: GraphModule = symbolic_trace(meta_model)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op in ('output', 'placeholder'):
|
||||
continue
|
||||
|
||||
step_dict = None
|
||||
if node.op == 'get_attr':
|
||||
maybe_param = getattr(gm, node.target)
|
||||
# mark this node as a parameter
|
||||
if maybe_param is not None:
|
||||
setattr(node, fxf_param_mark, maybe_param)
|
||||
continue
|
||||
elif node.op == 'call_module':
|
||||
target_module = gm.get_submodule(node.target)
|
||||
step_dict = dict()
|
||||
# collect all parameters in the module
|
||||
for maybe_param in target_module.parameters():
|
||||
if maybe_param is not None:
|
||||
param_name = getattr(maybe_param, fxf_name_mark)
|
||||
step_dict[param_name] = maybe_param
|
||||
elif node.op in ('call_function', 'call_method'):
|
||||
step_dict = dict()
|
||||
for pre in node.args:
|
||||
if hasattr(pre, fxf_param_mark):
|
||||
param = getattr(pre, fxf_param_mark)
|
||||
param_name = getattr(param, fxf_name_mark)
|
||||
step_dict[param_name] = param
|
||||
else:
|
||||
raise RuntimeError(f'Unsupported node op {node.op}!')
|
||||
|
||||
if step_dict is not None and len(step_dict) > 0:
|
||||
fx_forward_order.append(step_dict)
|
||||
|
||||
return fx_forward_order
|
|
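A sketch of what `generate_fx_order` produces for a traceable model; the subpackage import path is an assumption.

```python
import torch.nn as nn

from colossalai.elixir.tracer.param_tracer import generate_fx_order    # assumed path

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
forward_order = generate_fx_order(model)

for step, params in enumerate(forward_order):
    print(step, sorted(params.keys()))
# e.g. step 0 -> ['0.bias', '0.weight'], step 1 -> ['2.bias', '2.weight']
```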
@ -0,0 +1,156 @@
|
|||
import contextlib
|
||||
import uuid
|
||||
from typing import Callable, Dict, Iterator, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.tracer.ops import SameStorageAten
|
||||
from colossalai.elixir.tracer.utils import meta_copy
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_dispatch() -> Iterator[None]:
|
||||
guard = torch._C._DisableTorchDispatch()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
|
||||
|
||||
def normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
def register_storage(x):
|
||||
assert isinstance(x, nn.Parameter)
|
||||
assert x.data_ptr() == 0
|
||||
|
||||
data_ptr = uuid.uuid1()
|
||||
x.data_ptr = lambda: data_ptr
|
||||
|
||||
|
||||
class ATensor(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
||||
__slots__ = ['elem']
|
||||
|
||||
data_ptr_dict: Dict[int, Tuple[str, nn.Parameter]] = None
|
||||
order_list: List[Dict] = None
|
||||
|
||||
@staticmethod
|
||||
def reset():
|
||||
ATensor.data_ptr_dict = dict()
|
||||
ATensor.order_list = list()
|
||||
|
||||
@staticmethod
|
||||
def clear():
|
||||
ATensor.data_ptr_dict = None
|
||||
ATensor.order_list = None
|
||||
|
||||
@staticmethod
|
||||
def add_data_ptr(name: str, param: nn.Parameter):
|
||||
data_ptr = param.data_ptr()
|
||||
if data_ptr not in ATensor.data_ptr_dict:
|
||||
ATensor.data_ptr_dict[data_ptr] = (name, param)
|
||||
else:
|
||||
name_in, param_in = ATensor.data_ptr_dict[data_ptr]
|
||||
if name != name_in or id(param) != id(param_in):
|
||||
raise RuntimeError('Got two different parameters with the same data ptr')
|
||||
|
||||
@staticmethod
|
||||
def get_param(data_ptr: int):
|
||||
if data_ptr in ATensor.data_ptr_dict:
|
||||
return ATensor.data_ptr_dict.get(data_ptr)
|
||||
else:
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
storage_offset=elem.storage_offset(),
|
||||
# TODO: clone strides and storage aliasing
|
||||
dtype=elem.dtype,
|
||||
layout=elem.layout,
|
||||
device=elem.device,
|
||||
requires_grad=elem.requires_grad)
|
||||
r.elem = elem
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
return f'ATensor({self.elem})'
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
step_dict = dict()
|
||||
|
||||
def record_param(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
name, param = ATensor.get_param(x.data_ptr())
|
||||
if name is not None:
|
||||
step_dict[name] = param
|
||||
|
||||
def debug_tensor(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
print(type(x), x.shape, x.data_ptr(), id(x))
|
||||
if x.grad_fn:
|
||||
print(x.grad_fn)
|
||||
|
||||
tree_map(record_param, args)
|
||||
if len(step_dict) > 0:
|
||||
ATensor.order_list.append(step_dict)
|
||||
del step_dict
|
||||
|
||||
def unwrap(x):
|
||||
return x.elem if isinstance(x, ATensor) else x
|
||||
|
||||
def wrap(x):
|
||||
return ATensor(x) if isinstance(x, torch.Tensor) else x
|
||||
|
||||
with no_dispatch():
|
||||
res = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
|
||||
outs = normalize_tuple(res)
|
||||
res = tree_map(wrap, outs)
|
||||
|
||||
if func in SameStorageAten:
|
||||
for x in res:
|
||||
if isinstance(x, torch.Tensor):
|
||||
x.data_ptr = args[0].data_ptr
|
||||
|
||||
if len(res) == 1:
|
||||
return res[0]
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def generate_td_order(model: nn.Module, inp: Union[torch.Tensor, Tuple], step_fn: Callable):
|
||||
ATensor.reset()
|
||||
|
||||
def tensor_trans(t):
|
||||
meta_t = ATensor(t.data.to('meta'))
|
||||
if isinstance(t, nn.Parameter):
|
||||
meta_t = nn.Parameter(meta_t)
|
||||
return meta_t
|
||||
|
||||
model = meta_copy(model, tensor_trans)
|
||||
for name, param in model.named_parameters():
|
||||
register_storage(param)
|
||||
ATensor.add_data_ptr(name, param)
|
||||
|
||||
# convert all input data to meta_tensor
|
||||
if not isinstance(inp, tuple):
|
||||
inp = (inp,)
|
||||
inp = tree_map(lambda t: ATensor(torch.empty_like(t, device='meta', requires_grad=t.requires_grad)), inp)
|
||||
|
||||
step_fn(model, inp)
|
||||
|
||||
ret = ATensor.order_list
|
||||
ATensor.clear()
|
||||
|
||||
return ret
|
|
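A hedged sketch of `generate_td_order`; the import path is an assumption, and the example assumes the tracer can replay a full forward and backward pass on meta tensors, as the code above intends.

```python
import torch
import torch.nn as nn

from colossalai.elixir.tracer.param_tracer import generate_td_order    # assumed path


def step_fn(model, inp):
    # inp arrives as a tuple of ATensor meta tensors
    model(*inp).sum().backward()


model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
data = torch.randn(4, 16)

order = generate_td_order(model, data, step_fn)
print(len(order))    # one {param_name: param} dict per dispatched op that touched a parameter
```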
@ -0,0 +1,178 @@
|
|||
from typing import Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.tensor import is_no_hook_op
|
||||
from colossalai.elixir.tracer.utils import meta_copy
|
||||
from colossalai.elixir.utils import no_dispatch, normalize_tuple
|
||||
|
||||
torch_checkpoint_function = torch.utils.checkpoint.checkpoint
|
||||
|
||||
|
||||
def attach_checkpoint():
|
||||
default_in_checkpoint = False
|
||||
|
||||
def inner_checkpoint_function(function, *args, use_reentrant: bool = True, **kwargs):
|
||||
nonlocal default_in_checkpoint
|
||||
prev_in_checkpoint = default_in_checkpoint
|
||||
default_in_checkpoint = True
|
||||
# record the step where going into checkpoint
|
||||
if not prev_in_checkpoint:
|
||||
Record.record_in_checkpoint()
|
||||
# use original torch checkpoint function
|
||||
global torch_checkpoint_function
|
||||
ret = torch_checkpoint_function(function, *args, use_reentrant=use_reentrant, **kwargs)
|
||||
# roll back
|
||||
default_in_checkpoint = prev_in_checkpoint
|
||||
if not default_in_checkpoint:
|
||||
Record.record_out_checkpoint()
|
||||
return ret
|
||||
|
||||
torch.utils.checkpoint.checkpoint = inner_checkpoint_function
|
||||
|
||||
|
||||
def release_checkpoint():
|
||||
global torch_checkpoint_function
|
||||
torch.utils.checkpoint.checkpoint = torch_checkpoint_function
|
||||
|
||||
|
||||
class PostFwdPreBwd(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, params, *args):
|
||||
ctx.params = params
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grads):
|
||||
Record.record_params(ctx.params)
|
||||
return (None, *grads)
|
||||
|
||||
|
||||
class Record(nn.Parameter):
|
||||
record_steps: List = None
|
||||
checkpoint_info: List = None
|
||||
in_checkpoint_step: int = -1
|
||||
|
||||
def __new__(cls, elem):
|
||||
assert elem.device.type == 'meta', f'device type: {elem.device.type}'
|
||||
r = torch.Tensor._make_subclass(cls, elem)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if is_no_hook_op(func):
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
params = list()
|
||||
|
||||
def append_param(x):
|
||||
if isinstance(x, nn.Parameter):
|
||||
assert isinstance(x, Record)
|
||||
params.append(x)
|
||||
|
||||
tree_map(append_param, args)
|
||||
tree_map(append_param, kwargs)
|
||||
Record.record_params(params)
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
ret = normalize_tuple(func(*args, **kwargs))
|
||||
ret = PostFwdPreBwd.apply(params, *ret)
|
||||
|
||||
def clone(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
t = t.clone()
|
||||
return t
|
||||
|
||||
ret = tree_map(clone, ret)
|
||||
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
else:
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
# notice: we should disable __torch_function__ here
|
||||
# otherwise, unexpected operations are called inside meta kernels
|
||||
with torch._C.DisableTorchFunction():
|
||||
with no_dispatch():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def reset():
|
||||
Record.record_steps = list()
|
||||
Record.checkpoint_info = list()
|
||||
Record.in_checkpoint_step = -1
|
||||
|
||||
@staticmethod
|
||||
def record_in_checkpoint():
|
||||
assert Record.in_checkpoint_step == -1
|
||||
Record.in_checkpoint_step = len(Record.record_steps)
|
||||
|
||||
@staticmethod
|
||||
def record_out_checkpoint():
|
||||
assert Record.in_checkpoint_step != -1
|
||||
value_pair = (Record.in_checkpoint_step, len(Record.record_steps))
|
||||
Record.checkpoint_info.append(value_pair)
|
||||
Record.in_checkpoint_step = -1
|
||||
|
||||
@staticmethod
|
||||
def steps():
|
||||
ret = dict(params_per_step=Record.record_steps, checkpoint_info=Record.checkpoint_info)
|
||||
Record.record_steps = None
|
||||
Record.checkpoint_info = None
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def record_params(params):
|
||||
record_dict = {p.param_name for p in params}
|
||||
Record.record_steps.append(record_dict)
|
||||
|
||||
|
||||
def generate_tf_order(model: nn.Module, inp: Dict, step_fn: Callable, dtype: torch.dtype = torch.float):
|
||||
assert isinstance(inp, dict), 'The example input should be a dictionary'
|
||||
|
||||
Record.reset()
|
||||
|
||||
def mtensor_trans(t: torch.Tensor):
|
||||
if t.is_floating_point():
|
||||
meta_dtype = dtype
|
||||
else:
|
||||
meta_dtype = t.dtype
|
||||
|
||||
meta_t = torch.empty_like(t, dtype=meta_dtype, device='meta')
|
||||
if isinstance(t, nn.Parameter):
|
||||
meta_t = Record(meta_t)
|
||||
meta_t.requires_grad = t.requires_grad
|
||||
return meta_t
|
||||
|
||||
model = meta_copy(model, mtensor_trans)
|
||||
for name, param in model.named_parameters():
|
||||
param.param_name = name
|
||||
|
||||
def input_trans(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
if t.is_floating_point():
|
||||
meta_dtype = dtype
|
||||
else:
|
||||
meta_dtype = t.dtype
|
||||
|
||||
meta_t = torch.empty_like(t, dtype=meta_dtype, device='meta', requires_grad=t.requires_grad)
|
||||
return meta_t
|
||||
return t
|
||||
|
||||
inp = tree_map(input_trans, inp)
|
||||
attach_checkpoint()
|
||||
step_fn(model, inp)
|
||||
release_checkpoint()
|
||||
ret = Record.steps()
|
||||
return ret
|
|
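A hedged sketch of `generate_tf_order`; the import path is an assumption, and the returned keys come from `Record.steps` above.

```python
import torch
import torch.nn as nn

from colossalai.elixir.tracer.param_tracer import generate_tf_order    # assumed path


class ToyNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def step_fn(model, inp):
    model(**inp).sum().backward()


model = ToyNet()
data = dict(x=torch.randn(4, 16))

ret = generate_tf_order(model, data, step_fn, dtype=torch.float16)
print(ret['params_per_step'])    # a set of parameter names per recorded step
print(ret['checkpoint_info'])    # (enter_step, exit_step) pairs for checkpointed regions
```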
@ -0,0 +1,96 @@
|
|||
from collections import OrderedDict
|
||||
from copy import copy
|
||||
from typing import Optional, Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''):
|
||||
"""Get a dfs module list of the given module. Its order is same as the order of creations of modules.
|
||||
"""
|
||||
if memo is None:
|
||||
memo = set()
|
||||
if module not in memo:
|
||||
for name, submodule in module._modules.items():
|
||||
if submodule is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ('.' if prefix else '') + name
|
||||
for m in _get_dfs_module_list(submodule, memo, submodule_prefix):
|
||||
yield m
|
||||
|
||||
memo.add(module)
|
||||
yield prefix, module
|
||||
|
||||
|
||||
def _get_shallow_copy_model(model: nn.Module):
|
||||
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
|
||||
But the new submodule and the old submodule share all attributes.
|
||||
"""
|
||||
old_to_new = dict()
|
||||
for name, module in _get_dfs_module_list(model):
|
||||
new_module = copy(module)
|
||||
new_module._modules = OrderedDict()
|
||||
for subname, submodule in module._modules.items():
|
||||
if submodule is None:
|
||||
continue
|
||||
setattr(new_module, subname, old_to_new[submodule])
|
||||
old_to_new[module] = new_module
|
||||
return old_to_new[model]
|
||||
|
||||
|
||||
def meta_copy(model: nn.Module, meta_fn: callable):
|
||||
new_model = _get_shallow_copy_model(model)
|
||||
old_parameters = dict()
|
||||
old_buffers = dict()
|
||||
|
||||
for (_, old_module), (_, new_module) in \
|
||||
zip(_get_dfs_module_list(model), _get_dfs_module_list(new_model)):
|
||||
|
||||
new_module._parameters = OrderedDict()
|
||||
for name, param in old_module._parameters.items():
|
||||
new_param = None
|
||||
if param is not None:
|
||||
param_id = id(param)
|
||||
if param_id in old_parameters:
|
||||
new_param = old_parameters.get(param_id)
|
||||
else:
|
||||
new_param = meta_fn(param)
|
||||
old_parameters[param_id] = new_param
|
||||
setattr(new_module, name, new_param)
|
||||
|
||||
new_module._buffers = OrderedDict()
|
||||
for name, buffer in old_module._buffers.items():
|
||||
new_buffer = None
|
||||
if buffer is not None:
|
||||
buffer_id = id(buffer)
|
||||
if buffer_id in old_buffers:
|
||||
new_buffer = old_buffers.get(buffer_id)
|
||||
else:
|
||||
new_buffer = meta_fn(buffer)
|
||||
old_buffers[buffer_id] = new_buffer
|
||||
new_module.register_buffer(name, new_buffer)
|
||||
|
||||
return new_model
|
||||
|
||||
|
||||
def get_cuda_allocated():
|
||||
return torch.cuda.memory_allocated()
|
||||
|
||||
|
||||
def get_cuda_max_allocated():
|
||||
return torch.cuda.max_memory_allocated()
|
||||
|
||||
|
||||
def model_memory_figure(model: nn.Module):
|
||||
param_occ = 0
|
||||
max_numel = 0
|
||||
for name, param in model.named_parameters():
|
||||
param_occ += param.numel() * param.element_size()
|
||||
max_numel = max(max_numel, param.numel())
|
||||
|
||||
buffer_occ = 0
|
||||
for name, buffer in model.named_buffers():
|
||||
buffer_occ += buffer.numel() * buffer.element_size()
|
||||
|
||||
return dict(param_occ=param_occ, param_max_numel=max_numel, buffer_occ=buffer_occ)
|
|
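A short sketch built on these helpers; `colossalai.elixir.tracer.utils` is the import path used verbatim by the tracer files in this commit.

```python
import torch
import torch.nn as nn

from colossalai.elixir.tracer.utils import meta_copy, model_memory_figure


def to_meta(t: torch.Tensor):
    meta_t = t.data.to('meta')
    if isinstance(t, nn.Parameter):
        meta_t = nn.Parameter(meta_t)
    return meta_t


model = nn.Sequential(nn.Linear(256, 256), nn.BatchNorm1d(256))
meta_model = meta_copy(model, to_meta)    # the original model is left untouched

figure = model_memory_figure(model)
print(figure['param_occ'], figure['param_max_numel'], figure['buffer_occ'])
```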
@ -0,0 +1,119 @@
|
|||
import contextlib
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_dispatch():
|
||||
guard = torch._C._DisableTorchDispatch()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
|
||||
|
||||
def normalize_tuple(x):
|
||||
if not isinstance(x, tuple):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
|
||||
def seed_all(seed, cuda_deterministic=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if cuda_deterministic: # slower, more reproducible
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
else:
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def init_distributed():
|
||||
rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
world_size = int(os.environ['WORLD_SIZE'])
|
||||
host = os.environ['MASTER_ADDR']
|
||||
port = int(os.environ['MASTER_PORT'])
|
||||
|
||||
init_method = f'tcp://[{host}]:{port}'
|
||||
dist.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
|
||||
|
||||
# set cuda device
|
||||
if torch.cuda.is_available():
|
||||
# bind this process to its local GPU
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
seed_all(1024)
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs):
|
||||
if dist.is_initialized():
|
||||
if dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
dist.barrier()
|
||||
else:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def get_model_size(model: nn.Module):
|
||||
total_numel = 0
|
||||
for module in model.modules():
|
||||
for p in module.parameters(recurse=False):
|
||||
total_numel += p.numel()
|
||||
return total_numel
|
||||
|
||||
|
||||
def model_size_formatter(numel: int) -> str:
|
||||
GB_SIZE = 10**9
|
||||
MB_SIZE = 10**6
|
||||
KB_SIZE = 10**3
|
||||
if numel >= GB_SIZE:
|
||||
return f'{numel / GB_SIZE:.1f}B'
|
||||
elif numel >= MB_SIZE:
|
||||
return f'{numel / MB_SIZE:.1f}M'
|
||||
elif numel >= KB_SIZE:
|
||||
return f'{numel / KB_SIZE:.1f}K'
|
||||
else:
|
||||
return str(numel)
|
||||
|
||||
|
||||
def calc_buffer_size(m: nn.Module, test_dtype: torch.dtype = torch.float):
|
||||
max_sum_size = 0
|
||||
for module in m.modules():
|
||||
sum_p_size = 0
|
||||
for param in module.parameters(recurse=False):
|
||||
assert param.dtype == test_dtype
|
||||
sum_p_size += param.numel()
|
||||
max_sum_size = max(max_sum_size, sum_p_size)
|
||||
return max_sum_size
|
||||
|
||||
|
||||
def calc_block_usage():
|
||||
snap_shot = torch.cuda.memory_snapshot()
|
||||
|
||||
total_sum = 0
|
||||
active_sum = 0
|
||||
for info_dict in snap_shot:
|
||||
blocks = info_dict.get('blocks')
|
||||
for b in blocks:
|
||||
size = b.get('size')
|
||||
state = b.get('state')
|
||||
total_sum += size
|
||||
if state == 'active_allocated':
|
||||
active_sum += size
|
||||
|
||||
active_ratio = 1
|
||||
if total_sum > 0:
|
||||
active_ratio = active_sum / total_sum
|
||||
|
||||
print(f'memory snapshot: active ratio {active_ratio:.2f}')
|
|
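A small sketch of the helpers above; `colossalai.elixir.utils` is the import path used elsewhere in this commit.

```python
import torch.nn as nn

from colossalai.elixir.utils import get_model_size, model_size_formatter, seed_all

seed_all(42)                          # seeds python, numpy and torch (including CUDA when available)

model = nn.Linear(4096, 4096)
numel = get_model_size(model)
print(model_size_formatter(numel))    # '16.8M' for this layer
```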
@ -0,0 +1,2 @@
|
|||
from .module import ElixirModule
|
||||
from .optimizer import ElixirOptimizer
|
|
@ -0,0 +1,344 @@
|
|||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import Any, Iterable, Mapping
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.elixir.chunk import Chunk, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
|
||||
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.hook import BufferStore, HookParam
|
||||
from colossalai.elixir.search import SearchResult
|
||||
from colossalai.elixir.tensor import OutplaceTensor
|
||||
from colossalai.utils.model.experimental import LazyTensor
|
||||
|
||||
|
||||
def get_param_optim_data(param_data: torch.Tensor, param_dtype: torch.dtype):
|
||||
param_data = param_data.to(gpu_device())
|
||||
optim_data = param_data.clone() if param_data.dtype == torch.float else param_data.float()
|
||||
param_data = param_data.to(param_dtype)
|
||||
return param_data, optim_data
|
||||
|
||||
|
||||
class ElixirModule(nn.Module):
|
||||
"""Use this class to wrap your model when using Elixir. Don't know what should be written here.
|
||||
that are gathered on demand during the forward and backward passes, so the model can be trained on limited GPU memory.
|
||||
|
||||
Args:
|
||||
module: training module
|
||||
search_result: a SearchResult generated from a search algorithm in `elixir.search`
|
||||
process_group: the communication group, usually the data parallel group
|
||||
prefetch: whether to prefetch chunks, overlapping communication with computation
|
||||
dtype: the dtype used in training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
search_result: SearchResult,
|
||||
process_group: ProcessGroup,
|
||||
prefetch: bool = False,
|
||||
dtype: torch.dtype = torch.float,
|
||||
reduce_always_fp32: bool = False,
|
||||
output_fp32: bool = False,
|
||||
use_fused_kernels: bool = False) -> None:
|
||||
super().__init__()
|
||||
|
||||
assert dtype in {torch.float, torch.float16}
|
||||
|
||||
self._set_module_outplace(module)
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.use_amp = (dtype == torch.float16)
|
||||
self.process_group = process_group
|
||||
self.prefetch_flag = prefetch
|
||||
self.reduce_always_fp32 = reduce_always_fp32
|
||||
self.output_fp32 = output_fp32
|
||||
self.use_fused_kernels = use_fused_kernels
|
||||
|
||||
self.no_grad_state_dict = dict()
|
||||
self.grad_state_dict = dict()
|
||||
self.__init_chunk_group(search_result)
|
||||
self.__init_chunk_fetcher(search_result, prefetch)
|
||||
self.__init_buffer_storage()
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
if not param.requires_grad:
|
||||
assert name in self.no_grad_state_dict
|
||||
continue
|
||||
assert name in self.grad_state_dict
|
||||
param.register_hook(partial(self._gradient_handler, param=param))
|
||||
param.__class__ = HookParam
|
||||
|
||||
def __init_chunk_group(self, sr: SearchResult):
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = self.module.state_dict(keep_vars=True)
|
||||
for name, tensor in state_dict.items():
|
||||
if isinstance(tensor, nn.Parameter):
|
||||
assert tensor.is_floating_point(), 'the dtypes of parameters should be float dtypes'
|
||||
# deal with parameters
|
||||
if tensor.requires_grad:
|
||||
self.grad_state_dict[name] = tensor
|
||||
else:
|
||||
self.no_grad_state_dict[name] = tensor
|
||||
# polish no-grad parameters
|
||||
tensor.data = tensor.data.to(dtype=self.dtype, device=gpu_device())
|
||||
else:
|
||||
# deal with buffers
|
||||
self._lazy_init_check(tensor)
|
||||
to_dtype = self.dtype if tensor.is_floating_point() else tensor.dtype
|
||||
tensor.data = tensor.data.to(dtype=to_dtype, device=gpu_device())
|
||||
|
||||
empty_mp = MemoryPool('cuda')
|
||||
empty_mp.allocate()
|
||||
|
||||
self.param_chunk_group = sr.chunk_group
|
||||
self.optim_chunk_group = ChunkGroup(empty_mp)
|
||||
self.param_to_optim = dict()
|
||||
vis_set = set()
|
||||
|
||||
for plan in sr.param_chunk_plans:
|
||||
assert plan.chunk_dtype == self.dtype
|
||||
# optimizer chunks should not be gathered
|
||||
optim_kwargs = copy(plan.kwargs)
|
||||
if 'rcache_fused' in optim_kwargs:
|
||||
optim_kwargs['rcache_fused'] = False
|
||||
|
||||
p_chunk = self.param_chunk_group.open_chunk(chunk_size=plan.chunk_size,
|
||||
chunk_dtype=plan.chunk_dtype,
|
||||
process_group=self.process_group,
|
||||
chunk_config=plan.kwargs)
|
||||
o_chunk = self.optim_chunk_group.open_chunk(chunk_size=plan.chunk_size,
|
||||
chunk_dtype=torch.float,
|
||||
process_group=self.process_group,
|
||||
chunk_config=optim_kwargs)
|
||||
|
||||
for name in plan.name_list:
|
||||
param = self.grad_state_dict[name]
|
||||
self._lazy_init_check(param)
|
||||
param_data, optim_data = get_param_optim_data(param.data, self.dtype)
|
||||
param.data = param_data
|
||||
p_chunk.append_tensor(param)
|
||||
o_chunk.append_tensor(optim_data)
|
||||
self.param_to_optim[param] = optim_data
|
||||
|
||||
vis_set.add(param)
|
||||
|
||||
self.param_chunk_group.close_chunk(p_chunk)
|
||||
self.optim_chunk_group.close_chunk(o_chunk)
|
||||
p_chunk.init_pair(o_chunk)
|
||||
|
||||
# sanity check: every parameter that requires a gradient has been initialized
|
||||
for param in self.module.parameters():
|
||||
if param.requires_grad:
|
||||
assert param in vis_set
|
||||
|
||||
def __init_chunk_fetcher(self, sr: SearchResult, prefetch: bool):
|
||||
scheduler = None
|
||||
if prefetch:
|
||||
assert sr.param_called_per_step is not None
|
||||
|
||||
chunk_called_per_step = list()
|
||||
for step in sr.param_called_per_step:
|
||||
step_set = set()
|
||||
for name in step:
|
||||
param = self.grad_state_dict[name]
|
||||
chunk = self.param_chunk_group.ten_to_chunk[param]
|
||||
step_set.add(chunk)
|
||||
chunk_called_per_step.append(step_set)
|
||||
|
||||
scheduler = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
|
||||
else:
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
self.fetcher = ChunkFetcher(scheduler,
|
||||
self.param_chunk_group,
|
||||
overlap=prefetch,
|
||||
reduce_always_fp32=self.reduce_always_fp32)
|
||||
self.fetcher.reset()
|
||||
|
||||
def __init_buffer_storage(self):
|
||||
buffer_size = 0
|
||||
for submodule in self.modules():
|
||||
sum_param_size = 0
|
||||
for param in submodule.parameters(recurse=False):
|
||||
if not param.requires_grad or self.fetcher.is_in_fused(param):
|
||||
continue
|
||||
assert param.dtype == self.dtype
|
||||
sum_param_size += param.numel()
|
||||
buffer_size = max(buffer_size, sum_param_size)
|
||||
self.buffer = BufferStore(buffer_size, self.dtype)
|
||||
print('module buffer', self.buffer)
|
||||
|
||||
def _gradient_handler(self, grad: torch.Tensor, param: nn.Parameter):
|
||||
# create an empty tensor
|
||||
fake_grad = self.buffer.empty_like(grad)
|
||||
|
||||
with torch._C.DisableTorchFunction():
|
||||
chunk = self.fetcher.get_one_chunk(param)
|
||||
assert self.fetcher.group.is_accessed(chunk)
|
||||
if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
|
||||
raise RuntimeError()
|
||||
self.fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
|
||||
chunk.copy_tensor_to_chunk_slice(param, grad)
|
||||
self.fetcher.reduce_chunk(chunk)
|
||||
|
||||
return fake_grad
|
||||
|
||||
def _lazy_init_check(self, tensor: torch.Tensor) -> None:
|
||||
if isinstance(tensor, LazyTensor):
|
||||
tensor.materialize()
|
||||
|
||||
def _set_module_outplace(self, m: nn.Module):
|
||||
# set inplace to False for all modules
|
||||
for module in m.modules():
|
||||
if hasattr(module, 'inplace'):
|
||||
module.inplace = False
|
||||
|
||||
def _deattach_fetcher(self):
|
||||
self.fetcher.clear()
|
||||
HookParam.release_fetcher()
|
||||
HookParam.disable_fused_kernel()
|
||||
|
||||
def _release_for_inference(self):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
scheduler = self.fetcher.scheduler
|
||||
param_group = self.param_chunk_group
|
||||
while True:
|
||||
maybe_chunk = scheduler.top()
|
||||
if maybe_chunk is None:
|
||||
break
|
||||
scheduler.remove(maybe_chunk)
|
||||
param_group.release_chunk(maybe_chunk)
|
||||
self._deattach_fetcher()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if torch.is_grad_enabled():
|
||||
inference_mode = False
|
||||
else:
|
||||
inference_mode = True
|
||||
|
||||
# reset the fetcher in this step
|
||||
self.fetcher.reset()
|
||||
HookParam.attach_fetcher(self.fetcher, self.buffer)
|
||||
if self.use_fused_kernels:
|
||||
HookParam.enable_fused_kernel()
|
||||
|
||||
def to_outplace_tensor(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
if t.is_floating_point():
|
||||
t = t.to(self.dtype)
|
||||
t = OutplaceTensor(t)
|
||||
return t
|
||||
|
||||
args = tree_map(to_outplace_tensor, args)
|
||||
kwargs = tree_map(to_outplace_tensor, kwargs)
|
||||
|
||||
outputs = self.module(*args, **kwargs)
|
||||
if self.output_fp32:
|
||||
outputs = outputs.float()
|
||||
|
||||
if inference_mode:
|
||||
self._release_for_inference()
|
||||
|
||||
return outputs
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss.backward()
|
||||
# reset the fetcher for the next step
|
||||
self._deattach_fetcher()
|
||||
# reset all attributes
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def state_dict(self,
|
||||
destination=None,
|
||||
prefix='',
|
||||
keep_vars=False,
|
||||
only_rank_0: bool = False,
|
||||
from_param: bool = False):
|
||||
assert keep_vars is False, 'state_dict can not keep variables in ElixirModule'
|
||||
# make sure that the variables are kept, we shall detach them later
|
||||
module_state_dict = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=True)
|
||||
|
||||
tensor_to_names = defaultdict(list)
|
||||
for name, tensor in module_state_dict.items():
|
||||
if isinstance(tensor, nn.Parameter) and tensor.requires_grad:
|
||||
used_tensor = self.grad_state_dict[name]
|
||||
if not from_param:
|
||||
used_tensor = self.param_to_optim.get(used_tensor)
|
||||
tensor_to_names[used_tensor].append(name)
|
||||
else:
|
||||
module_state_dict[name] = tensor.detach()
|
||||
|
||||
def update_state_dict(chunks: Iterable[Chunk]):
|
||||
for c in chunks:
|
||||
for op, cp in zip(c.get_tensors(), c.get_cpu_copy(only_rank_0)):
|
||||
for name in tensor_to_names.get(op):
|
||||
module_state_dict[name] = cp
|
||||
|
||||
if from_param:
|
||||
used_group = self.param_chunk_group
|
||||
else:
|
||||
used_group = self.optim_chunk_group
|
||||
|
||||
update_state_dict(used_group.fused_chunks)
|
||||
update_state_dict(used_group.float_chunks)
|
||||
|
||||
return module_state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], only_rank_0: bool = False):
|
||||
load_flag = not only_rank_0 or dist.get_rank() == 0
|
||||
if not load_flag:
|
||||
# only rank 0 loads the state dict
|
||||
assert state_dict is None
|
||||
|
||||
if only_rank_0:
|
||||
# broadcast the length of the state dict
|
||||
state_length = len(state_dict) if load_flag else None
|
||||
comm_list = [state_length]
|
||||
dist.broadcast_object_list(comm_list)
|
||||
state_length = comm_list[0]
|
||||
# broadcast the keys of the state dict
|
||||
state_keys = list(state_dict.keys()) if load_flag else [None] * state_length
|
||||
dist.broadcast_object_list(state_keys)
|
||||
# update the state dict
|
||||
if not load_flag:
|
||||
state_dict = {k: None for k in state_keys}
|
||||
|
||||
# init the mapping from optim tensor to load tensor
|
||||
optim_to_load = dict()
|
||||
for name, maybe_tensor in state_dict.items():
|
||||
if name in self.no_grad_state_dict:
|
||||
no_grad_param = self.no_grad_state_dict.get(name)
|
||||
if load_flag:
|
||||
no_grad_param.copy_(maybe_tensor)
|
||||
if only_rank_0:
|
||||
dist.broadcast(no_grad_param, src=0)
|
||||
elif name in self.grad_state_dict:
|
||||
grad_param = self.grad_state_dict.get(name)
|
||||
optim_tensor = self.param_to_optim.get(grad_param)
|
||||
optim_to_load[optim_tensor] = maybe_tensor
|
||||
|
||||
def use_state_dict(chunks: Iterable[Chunk]):
|
||||
for c in chunks:
|
||||
load_tensor_list = list()
|
||||
has_load = False
|
||||
for chunk_tensor in c.tensors_info.keys():
|
||||
if chunk_tensor in optim_to_load:
|
||||
has_load = True
|
||||
load_tensor_list.append(optim_to_load[chunk_tensor])
|
||||
else:
|
||||
load_tensor_list.append(None)
|
||||
if has_load:
|
||||
c.load_tensors(load_tensor_list, only_rank_0)
|
||||
c.paired_chunk.optim_update()
|
||||
|
||||
use_state_dict(self.optim_chunk_group.fused_chunks)
|
||||
use_state_dict(self.optim_chunk_group.float_chunks)
|
||||
|
||||
return
|
|
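A hedged sketch of how a wrapped module is typically driven for one training step and then checkpointed; `model` (an ElixirModule), `optimizer` (an ElixirOptimizer) and `loader` are assumed to exist already, and the model is assumed to return a loss-like tensor.

```python
# Sketch only: `model`, `optimizer` and `loader` are assumed to be built beforehand,
# with `loader` yielding dictionaries of tensors.
import torch
import torch.distributed as dist

for batch in loader:
    batch = {k: v.cuda() for k, v in batch.items()}
    loss = model(**batch).mean()
    optimizer.backward(loss)    # scales the loss and calls ElixirModule.backward
    optimizer.step()

# gather a full state dict from the fp32 optimizer copies, materialized on rank 0 only
state = model.state_dict(only_rank_0=True, from_param=False)
if dist.get_rank() == 0:
    torch.save(state, 'checkpoint.pt')
```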
@ -0,0 +1,265 @@
|
|||
import math
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn import Parameter
|
||||
|
||||
import colossalai.nn.optimizer as colo_optim
|
||||
from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler, ConstantGradScaler, DynamicGradScaler
|
||||
from colossalai.elixir.chunk import Chunk
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.hook.storage import BufferStore
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .module import ElixirModule
|
||||
|
||||
_AVAIL_OPTIM_LIST = {colo_optim.FusedAdam, colo_optim.CPUAdam, colo_optim.HybridAdam}
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
SCALED = 0
|
||||
UNSCALED = 1
|
||||
|
||||
|
||||
class ElixirOptimizer(colo_optim.ColossalaiOptimizer):
|
||||
"""A wrapper for optimizers. Users should notice that one specific ElixirOptimizer is strictly
|
||||
bound to one ElixirModule. Currently, only a limited set of optimizers is supported by ElixirOptimizer.
|
||||
The reason is that ElixirOptimizer only supports element-wise optimizers for now.
|
||||
We may enlarge the group of supported optimizers later.
|
||||
|
||||
Args:
|
||||
optim: The torch optimizer instance.
|
||||
module: The nn.Module instance wrapped as an ElixirModule.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: ElixirModule,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
initial_scale: float = 32768,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**24,
|
||||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0,
|
||||
init_step=False):
|
||||
|
||||
super().__init__(optimizer)
|
||||
assert isinstance(module, ElixirModule)
|
||||
self.scaled_optimizer = False
|
||||
if type(optimizer) in _AVAIL_OPTIM_LIST:
|
||||
self.scaled_optimizer = True
|
||||
|
||||
self.module = module
|
||||
self.param_chunk_group = module.param_chunk_group
|
||||
self.optim_chunk_group = module.optim_chunk_group
|
||||
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_optim_chunk: Dict[Parameter, Chunk] = dict()
|
||||
self.param_chunk_set: Set[Chunk] = self.param_chunk_group.fused_chunks.union(
|
||||
self.param_chunk_group.float_chunks)
|
||||
|
||||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
if self.clipping_flag:
|
||||
assert norm_type == 2.0, 'ElixirOptimizer only supports L2 norm now'
|
||||
|
||||
self.__init__optimizer()
|
||||
|
||||
# Grad scaler
|
||||
self.grad_scaler: BaseGradScaler = None
|
||||
if module.use_amp:
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
else:
|
||||
self.grad_scaler = ConstantGradScaler(1.0, verbose=False)
|
||||
self._comm_buffer: torch.Tensor = torch.zeros(1, dtype=torch.float, device=gpu_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
if init_step:
|
||||
# allocate memory before training
|
||||
self.__zero_step()
|
||||
|
||||
def __zero_step(self):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
cpu_buffer = BufferStore(self.module.buffer.buffer_size, self.module.buffer.buffer_dtype, 'cpu')
|
||||
buffer_dict = dict(cpu=cpu_buffer, cuda=self.module.buffer)
|
||||
for _, zero_buffer in buffer_dict.items():
|
||||
zero_buffer.zeros()
|
||||
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
optim_chunk = self.param_to_optim_chunk[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
|
||||
fake_param.data = buffer_dict.get(optim_chunk.shard_device.type).empty_1d(end - begin)
|
||||
fake_param.grad = fake_param.data
|
||||
fake_param.data = optim_chunk.shard[begin:end]
|
||||
|
||||
self.optim.step()
|
||||
self.zero_grad()
|
||||
self._update_fp16_params(update_flag=False)
|
||||
|
||||
def _set_grad_ptr(self):
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
optim_chunk = self.param_to_optim_chunk[fake_param]
|
||||
begin, end = self.param_to_range[fake_param]
|
||||
param_chunk = optim_chunk.paired_chunk
|
||||
|
||||
fake_param.data = param_chunk.shard[begin:end]
|
||||
fake_param.grad = fake_param.data
|
||||
fake_param.data = optim_chunk.shard[begin:end]
|
||||
|
||||
def _update_fp16_params(self, update_flag: bool = True):
|
||||
none_tensor = torch.empty([0])
|
||||
for group in self.param_groups:
|
||||
for fake_param in group['params']:
|
||||
assert fake_param.grad is None
|
||||
fake_param.data = none_tensor.to(fake_param.device)
|
||||
|
||||
if update_flag:
|
||||
for param_chunk in self.param_chunk_set:
|
||||
param_chunk.optim_update()
|
||||
|
||||
def _check_overflow(self) -> bool:
|
||||
# calculate the overflow counter
|
||||
overflow_counter = 0
|
||||
for param_chunk in self.param_chunk_set:
|
||||
overflow_counter += int(param_chunk.overflow)
|
||||
return overflow_counter > 0
|
||||
|
||||
def _clear_optim_states(self) -> None:
|
||||
for param_chunk in self.param_chunk_set:
|
||||
param_chunk.overflow = False
|
||||
param_chunk.l2_norm = None
|
||||
|
||||
def _calc_global_norm(self) -> float:
|
||||
group_to_norm = defaultdict(float)
|
||||
for param_chunk in self.param_chunk_set:
|
||||
assert param_chunk.l2_norm is not None
|
||||
assert not param_chunk.is_replica
|
||||
|
||||
group_to_norm[param_chunk.torch_pg] += param_chunk.l2_norm
|
||||
|
||||
norm_sqr = 0.0
|
||||
for group, part_norm in group_to_norm.items():
|
||||
self._comm_buffer.fill_(part_norm)
|
||||
dist.all_reduce(self._comm_buffer, group=group)
|
||||
norm_sqr += self._comm_buffer.item()
|
||||
|
||||
global_norm = math.sqrt(norm_sqr)
|
||||
return global_norm
|
||||
|
||||
def _get_combined_scale(self):
|
||||
loss_scale = 1
|
||||
|
||||
assert self.optim_state == OptimState.SCALED
|
||||
loss_scale = self.loss_scale
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
|
||||
combined_scale = loss_scale
|
||||
if self.clipping_flag:
|
||||
total_norm = self._calc_global_norm()
|
||||
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
|
||||
if clip > 1:
|
||||
combined_scale = clip * loss_scale
|
||||
|
||||
if combined_scale == 1:
|
||||
return -1
|
||||
else:
|
||||
return combined_scale
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale.item()
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self._set_grad_ptr()
|
||||
found_inf = self._check_overflow()
|
||||
|
||||
if found_inf:
|
||||
self.optim_state = OptimState.UNSCALED # no need to unscale grad
|
||||
self.grad_scaler.update(found_inf) # update gradient scaler
|
||||
self._logger.info('Found overflow. Skip this optimization step.')
|
||||
self._clear_optim_states() # clear chunk states used for optimizer update
|
||||
self.zero_grad() # reset all gradients
|
||||
self._update_fp16_params()
|
||||
return
|
||||
|
||||
# get combined scale. combined scale = loss scale * clipping norm
|
||||
# so that gradient = gradient / combined scale
|
||||
combined_scale = self._get_combined_scale()
|
||||
self.grad_scaler.update(found_inf)
|
||||
self._clear_optim_states()
|
||||
|
||||
if not self.scaled_optimizer:
|
||||
assert combined_scale == -1, 'You should use an optimizer in the available list:\n' \
|
||||
f'{_AVAIL_OPTIM_LIST}'
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
else:
|
||||
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
|
||||
|
||||
self.zero_grad()
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
|
||||
raise NotImplementedError
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
self.module.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
|
||||
# This function is called by every pipeline stage except the last one
|
||||
# It receives the scaled grad from the previous rank
|
||||
# No need to scale the grad again
|
||||
# Need to unscale when optimizing
|
||||
self.optim_state = OptimState.SCALED
|
||||
self.module.backward_by_grad(tensor, grad)
|
||||
|
||||
def __init__optimizer(self):
|
||||
|
||||
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
|
||||
param_info = local_chunk.tensors_info[local_param]
|
||||
begin = max(0, param_info.offset - local_chunk.shard_begin)
|
||||
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
|
||||
return begin, end
|
||||
|
||||
for group in self.param_groups:
|
||||
fake_params_list = list()
|
||||
|
||||
for param in group['params']:
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
param_chunk = self.module.fetcher.get_one_chunk(param)
|
||||
range_pair = get_range_pair(param_chunk, param)
|
||||
if range_pair[0] >= range_pair[1]:
|
||||
continue
|
||||
|
||||
grad_device = param_chunk.shard.device
|
||||
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
|
||||
self.param_to_optim_chunk[fake_param] = param_chunk.paired_chunk
|
||||
self.param_to_range[fake_param] = range_pair
|
||||
|
||||
fake_params_list.append(fake_param)
|
||||
|
||||
group['params'] = fake_params_list
|
|
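A numeric sketch of `_get_combined_scale` above: loss scaling and gradient clipping are folded into a single division of the gradients.

```python
# Each gradient is later divided by `combined_scale` exactly once.
loss_scale = 1024.0      # current AMP loss scale
max_norm = 1.0           # requested clipping norm
total_norm = 4096.0      # L2 norm of the *scaled* gradients across all chunks

clip = ((total_norm / loss_scale) + 1e-6) / max_norm
combined_scale = clip * loss_scale if clip > 1 else loss_scale
print(combined_scale)    # ~4096: dividing the scaled grads by it leaves a norm of about max_norm
```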
@ -10,3 +10,5 @@ contexttimer
|
|||
ninja
|
||||
torch>=1.11
|
||||
safetensors
|
||||
sortedcontainers
|
||||
einops
|
||||
|
|
setup.py (8 changes)
|
@ -16,7 +16,7 @@ from op_builder.utils import (
|
|||
|
||||
try:
|
||||
import torch
|
||||
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension
|
||||
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CppExtension
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
|
@ -30,7 +30,11 @@ BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1
|
|||
IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1
|
||||
|
||||
# a variable to store the op builder
|
||||
ext_modules = []
|
||||
ext_modules = [
|
||||
CppExtension(name='colossalai.elixir.simulator',
|
||||
sources=['colossalai/elixir/simulator.cpp'],
|
||||
extra_compile_args=['-O3', '-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'])
|
||||
]
|
||||
|
||||
# we do not support windows currently
|
||||
if sys.platform == 'win32':
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.elixir import ElixirModule, ElixirOptimizer
|
||||
from colossalai.elixir.search import minimum_waste_search
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
def check_elixir_compatibility(early_stop: bool = True):
|
||||
"""check gemini plugin over model zoo
|
||||
|
||||
Args:
|
||||
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
|
||||
"""
|
||||
passed_models = []
|
||||
failed_info = {} # (model_name, error) pair
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
||||
# These models lead to CUDA error
|
||||
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
|
||||
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext',
|
||||
'torchaudio_wav2vec2_base', 'torchaudio_hubert_base', 'torchvision_convnext_base'):
|
||||
continue
|
||||
|
||||
try:
|
||||
print(name)
|
||||
global_size = dist.get_world_size()
|
||||
global_group = dist.GroupMember.WORLD
|
||||
|
||||
model = model_fn()
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
criterion = lambda x: x.mean()
|
||||
data = data_gen_fn()
|
||||
|
||||
data = {
|
||||
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
|
||||
for k, v in data.items()
|
||||
}
|
||||
|
||||
sr = minimum_waste_search(
|
||||
# pre-commit: do not rearrange
|
||||
m=model,
|
||||
group_size=global_size,
|
||||
unified_dtype=torch.float16,
|
||||
prefetch=False,
|
||||
verbose=True)
|
||||
|
||||
model = ElixirModule(model, sr, global_group, prefetch=False, dtype=torch.float16)
|
||||
optimizer = ElixirOptimizer(model, optimizer, initial_scale=32)
|
||||
|
||||
output = model(**data)
|
||||
output = output_transform_fn(output)
|
||||
output_key = list(output.keys())[0]
|
||||
loss = criterion(output[output_key])
|
||||
|
||||
optimizer.backward(loss)
|
||||
optimizer.step()
|
||||
passed_models.append(name)
|
||||
|
||||
del model, optimizer, criterion, data, output, loss
|
||||
except Exception as e:
|
||||
failed_info[name] = e
|
||||
if early_stop:
|
||||
raise e
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print(f'Passed models({len(passed_models)}): {passed_models}\n\n')
|
||||
print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n')
|
||||
assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()])
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, early_stop: bool = True):
|
||||
# init dist env
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_elixir_compatibility(early_stop=early_stop)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def exam_compatibility(early_stop: bool = True):
|
||||
spawn(run_dist, 2, early_stop=early_stop)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
exam_compatibility(early_stop=False)
|
|
@ -0,0 +1,72 @@
|
|||
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor


def to_divide(a: int, b: int):
    # round `a` up to the nearest multiple of `b`, e.g. to_divide(10, 4) == 12
    return a + (-a % b)


def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):
    empty_grad = torch.empty_like(grad)
    empty_grad.storage().resize_(0)

    with torch._C.DisableTorchFunction():
        chunk = fetcher.get_one_chunk(param)
        if chunk.tensors_info[param].state != TensorState.HOLD_AFTER_BWD:
            raise RuntimeError(f'unexpected tensor state: {chunk.tensors_info[param].state}')
        fetcher.group.tensor_trans_state(param, TensorState.READY_FOR_REDUCE)
        chunk.copy_tensor_to_chunk_slice(param, grad)
        fetcher.reduce_chunk(chunk)

    return empty_grad


def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
    pg_size = dist.get_world_size(process_group)

    private_list = list()
    for param in model.parameters():
        block_size = to_divide(param.numel(), pg_size)
        private_list.append(BlockRequire(block_size, param.dtype))

    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private_list)
    cg = ChunkGroup(rcache=mp)
    # allocate chunk group
    fused_config = dict(rcache_fused=True)
    for param in model.parameters():
        cg.allocate_chunk([param], to_divide(param.numel(), pg_size), param.dtype, process_group, fused_config)
    # initialize chunk fetcher
    scheduler = FIFOScheduler()
    fetcher = ChunkFetcher(scheduler, cg)
    buffer = BufferStore(0, torch.float32)
    # register fetcher and gradient handler
    HookParam.attach_fetcher(fetcher, buffer)
    for param in model.parameters():
        param.register_hook(partial(grad_handler, param=param, fetcher=fetcher))
        param.__class__ = HookParam
    # set inplace to False for all modules
    for module in model.modules():
        if hasattr(module, 'inplace'):
            module.inplace = False

    def transform_input(self_module, inputs):
        fetcher.reset()
        input_list = list()
        for t in inputs:
            if isinstance(t, torch.Tensor):
                t = OutplaceTensor(t)
            input_list.append(t)
        return tuple(input_list)

    model.register_forward_pre_hook(transform_input)

    return model, cg
@ -0,0 +1,63 @@
|
|||
import torch

from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag


@run_on_environment_flag('ELX')
def test_block():
    b = PublicBlock(123, torch.float16, 'cuda')
    payload_b = b.payload

    assert payload_b.numel() == 123
    assert payload_b.dtype == torch.float16
    assert payload_b.device.type == 'cuda'
    assert payload_b.numel() * payload_b.element_size() == b.memo_occ

    c = PrivateBlock(77, torch.float, 'cpu')
    payload_c = c.payload

    assert payload_c.numel() == 77
    assert payload_c.dtype == torch.float
    assert payload_c.device.type == 'cpu'
    assert payload_c.numel() * payload_c.element_size() == c.memo_occ

    print('test_block: ok')


@run_on_environment_flag('ELX')
def test_memory_pool():
    mp = MemoryPool(device_type='cuda')
    private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
    mp.allocate(public_block_number=4, private_block_list=private_list)

    block0 = mp.get_public_block()

    assert block0 in mp.public_used_blocks
    assert mp.public_used_cnt == 1
    assert mp.public_free_cnt == 3

    block1 = mp.get_public_block()

    assert block1 in mp.public_used_blocks
    assert mp.public_used_cnt == 2
    assert mp.public_free_cnt == 2

    mp.free_public_block(block0)
    mp.free_public_block(block1)

    assert block0 in mp.public_free_blocks
    assert block1 in mp.public_free_blocks
    assert mp.public_used_cnt == 0
    assert mp.public_free_cnt == 4

    block0 = mp.get_private_block(5, torch.float)
    assert block0.numel == 5
    assert block0.dtype == torch.float

    print('test_memory_pool: ok')


if __name__ == '__main__':
    test_block()
    test_memory_pool()
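# In the memory pool checks above, `public_used_cnt + public_free_cnt` always equals
# the allocated `public_block_number` (4), while private blocks are reserved
# per (size, dtype) via `BlockRequire` and fetched with `get_private_block`.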
@ -0,0 +1,155 @@
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
|
||||
from colossalai.elixir.utils import init_distributed
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
|
||||
|
||||
def exam_chunk_functions(nproc, group):
|
||||
a = torch.randn(2, 64, device='cuda')
|
||||
copy_a = a.clone()
|
||||
b = torch.randn(2, 2, 128, device='cuda')
|
||||
copy_b = b.clone()
|
||||
c = torch.randn(128, device='cuda')
|
||||
copy_c = c.clone()
|
||||
d = torch.randn(4, 32, device='cuda')
|
||||
copy_d = d.clone()
|
||||
|
||||
mp = MemoryPool('cuda')
|
||||
mp.allocate(public_block_number=1)
|
||||
|
||||
chunk = Chunk(mp, 1024, torch.float, group)
|
||||
chunk.l2_norm_flag = True
|
||||
assert chunk.chunk_size == 1024
|
||||
assert chunk.chunk_dtype == torch.float
|
||||
assert chunk.shard_size == 1024 // nproc
|
||||
|
||||
def check_tensors():
|
||||
assert torch.equal(a, copy_a)
|
||||
assert torch.equal(b, copy_b)
|
||||
assert torch.equal(c, copy_c)
|
||||
assert torch.equal(d, copy_d)
|
||||
|
||||
chunk.append_tensor(a)
|
||||
chunk.append_tensor(b)
|
||||
chunk.append_tensor(c)
|
||||
chunk.append_tensor(d)
|
||||
check_tensors()
|
||||
|
||||
chunk.close_chunk()
|
||||
assert chunk.is_replica is False
|
||||
# check function: get_cpu_copy
|
||||
cpu_copys = chunk.get_cpu_copy()
|
||||
for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
|
||||
assert t_cpu.device.type == 'cpu'
|
||||
assert torch.equal(t_gpu.cpu(), t_cpu)
|
||||
# check function: access_chunk
|
||||
block = mp.get_public_block()
|
||||
chunk.access_chunk(block)
|
||||
assert chunk.is_replica
|
||||
assert chunk.scatter_check
|
||||
check_tensors()
|
||||
# check function: release_chunk
|
||||
chunk.optim_sync_flag = False
|
||||
block = chunk.release_chunk()
|
||||
assert block in mp.public_used_blocks
|
||||
assert chunk.is_replica is False
|
||||
assert chunk.optim_sync_flag is True
|
||||
# check function: access_chunk after release_chunk
|
||||
chunk.access_chunk(block)
|
||||
check_tensors()
|
||||
# check function: reduce_chunk
|
||||
norm = block.payload.float().norm(2)**2
|
||||
chunk.reduce_chunk()
|
||||
assert chunk.is_replica is False
|
||||
assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
|
||||
|
||||
test_norm = torch.Tensor([chunk.l2_norm]).cuda()
|
||||
dist.all_reduce(test_norm)
|
||||
assert torch.allclose(norm, test_norm)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
print('chunk functions are ok')
|
||||
|
||||
|
||||
def exam_chunk_states(nproc, group):
|
||||
a = torch.randn(2, 64, device='cuda')
|
||||
copy_a = a.clone()
|
||||
b = torch.randn(2, 2, 128, device='cuda')
|
||||
copy_b = b.clone()
|
||||
c = torch.randn(128, device='cuda')
|
||||
copy_c = c.clone()
|
||||
d = torch.randn(4, 32, device='cuda')
|
||||
copy_d = d.clone()
|
||||
|
||||
private = [BlockRequire(1024, torch.float)]
|
||||
mp = MemoryPool('cuda')
|
||||
mp.allocate(private_block_list=private)
|
||||
|
||||
chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
|
||||
assert chunk.chunk_size == 1024
|
||||
assert chunk.chunk_dtype == torch.float
|
||||
assert chunk.shard_size == 1024 // nproc
|
||||
|
||||
def check_tensors():
|
||||
assert torch.equal(a, copy_a)
|
||||
assert torch.equal(b, copy_b)
|
||||
assert torch.equal(c, copy_c)
|
||||
assert torch.equal(d, copy_d)
|
||||
|
||||
chunk.append_tensor(a)
|
||||
chunk.append_tensor(b)
|
||||
chunk.append_tensor(c)
|
||||
chunk.append_tensor(d)
|
||||
check_tensors()
|
||||
|
||||
chunk.close_chunk()
|
||||
assert chunk.is_replica is False
|
||||
|
||||
chunk.access_chunk()
|
||||
assert chunk.is_replica
|
||||
check_tensors()
|
||||
|
||||
assert chunk.tensor_state_cnter[TensorState.HOLD] == 4
|
||||
chunk.tensor_trans_state(a, TensorState.COMPUTE)
|
||||
assert chunk.tensor_state_cnter[TensorState.HOLD] == 3
|
||||
assert chunk.tensor_state_cnter[TensorState.COMPUTE] == 1
|
||||
|
||||
tensor_list = [a, b, c, d]
|
||||
for t in tensor_list:
|
||||
chunk.tensor_trans_state(t, TensorState.COMPUTE)
|
||||
chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD)
|
||||
chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE)
|
||||
assert chunk.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == 4
|
||||
assert chunk.reduce_check
|
||||
|
||||
torch.cuda.synchronize()
|
||||
print('chunk states are ok')
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_chunk_functions(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chunk_functions(world_size=4)
|
|
@ -0,0 +1,71 @@
|
|||
import copy
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.elixir.chunk import ChunkGroup
|
||||
from colossalai.elixir.utils import init_distributed, seed_all
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
|
||||
from tests.test_elixir.utils import TEST_MODELS, to_cuda
|
||||
|
||||
|
||||
def check_gradient(ddp_model, my_model, cg: ChunkGroup):
|
||||
for chunk in cg.fused_chunks:
|
||||
cg.access_chunk(chunk)
|
||||
|
||||
for (name, p0), p1 in zip(ddp_model.named_parameters(), my_model.parameters()):
|
||||
torch.cuda.synchronize()
|
||||
print(f'checking parameter {name}')
|
||||
assert_close(p0.grad.data, p1.data)
|
||||
|
||||
|
||||
def exam_chunk_fetcher(nproc, group):
|
||||
model_fn, data_fn = TEST_MODELS.get('resnet')
|
||||
torch_model = model_fn().cuda()
|
||||
test_model = copy.deepcopy(torch_model)
|
||||
|
||||
rank = dist.get_rank(group)
|
||||
# get different data
|
||||
seed_all(1001 + rank)
|
||||
data = to_cuda(data_fn())
|
||||
|
||||
seed_all(1001, cuda_deterministic=True)
|
||||
ddp_model = DDP(torch_model)
|
||||
ddp_loss = ddp_model(**data)
|
||||
ddp_loss.backward()
|
||||
|
||||
hook_model, cg = hook_transform(test_model, group)
|
||||
my_loss = hook_model(**data)
|
||||
my_loss.backward()
|
||||
|
||||
assert_close(ddp_loss, my_loss)
|
||||
check_gradient(ddp_model, hook_model, cg)
|
||||
print('private chunk fetcher is ok')
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_chunk_fetcher(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chunk_fetcher(world_size=2)
|
|
@ -0,0 +1,98 @@
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
|
||||
from colossalai.elixir.utils import init_distributed
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
|
||||
|
||||
def exam_chunk_group_functions(nproc, group):
|
||||
a = torch.randn(3, 64, device='cuda')
|
||||
copy_a = a.clone()
|
||||
b = torch.randn(2, 32, device='cuda')
|
||||
copy_b = b.clone()
|
||||
c = torch.randn(256, device='cuda')
|
||||
copy_c = c.clone()
|
||||
d = torch.randn(2, 2, 64, device='cuda')
|
||||
copy_d = d.clone()
|
||||
e = torch.randn(2, 33, device='cuda')
|
||||
copy_e = e.clone()
|
||||
|
||||
mp = MemoryPool('cuda')
|
||||
mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
|
||||
cg = ChunkGroup(rcache=mp)
|
||||
c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
|
||||
c1 = cg.allocate_chunk([c], 256, torch.float, group)
|
||||
c2 = cg.allocate_chunk([d], 256, torch.float, group)
|
||||
|
||||
fused_config = dict(rcache_fused=True)
|
||||
c3 = cg.allocate_chunk([e], 68, torch.float, group, fused_config)
|
||||
|
||||
def check_chunk_0():
|
||||
assert torch.equal(a, copy_a)
|
||||
assert torch.equal(b, copy_b)
|
||||
|
||||
def check_chunk_1():
|
||||
assert torch.equal(c, copy_c)
|
||||
|
||||
def check_chunk_2():
|
||||
assert torch.equal(d, copy_d)
|
||||
|
||||
def check_chunk_3():
|
||||
assert torch.equal(e, copy_e)
|
||||
|
||||
# check tensors_to_chunks
|
||||
chunks = cg.tensors_to_chunks([e, a])
|
||||
assert chunks[0] == c0
|
||||
assert chunks[1] == c3
|
||||
# check access_chunk for unfused chunks
|
||||
cg.access_chunk(c0)
|
||||
cg.access_chunk(c1)
|
||||
check_chunk_0()
|
||||
check_chunk_1()
|
||||
assert not cg.rcache_enough_check(c2)
|
||||
assert cg.rcache_enough_check(c3)
|
||||
# check access_chunk for fused chunks
|
||||
cg.access_chunk(c3)
|
||||
check_chunk_3()
|
||||
# check release_chunk for unfused chunks
|
||||
cg.release_chunk(c1)
|
||||
assert cg.rcache_enough_check(c2)
|
||||
# check access_chunk
|
||||
cg.access_chunk(c2)
|
||||
check_chunk_2()
|
||||
|
||||
cg.tensor_trans_state(e, TensorState.COMPUTE)
|
||||
cg.tensor_trans_state(e, TensorState.HOLD_AFTER_BWD)
|
||||
cg.tensor_trans_state(e, TensorState.READY_FOR_REDUCE)
|
||||
cg.reduce_chunk(c3)
|
||||
assert not c3.is_replica
|
||||
|
||||
torch.cuda.synchronize()
|
||||
print('chunk group functions are ok')
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_chunk_group(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chunk_group(world_size=2)
|
|
@ -0,0 +1,130 @@
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.elixir.chunk import Chunk, MemoryPool
|
||||
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
|
||||
from colossalai.elixir.utils import init_distributed
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
|
||||
|
||||
def exam_fifo(nproc, group):
|
||||
mp = MemoryPool('cuda')
|
||||
mp.allocate(public_block_number=1)
|
||||
c0 = Chunk(mp, 1024, torch.float, group)
|
||||
c1 = Chunk(mp, 1024, torch.float, group)
|
||||
c2 = Chunk(mp, 1024, torch.float, group)
|
||||
|
||||
sdl = FIFOScheduler()
|
||||
sdl.reset()
|
||||
|
||||
sdl.add(c0)
|
||||
sdl.add(c1)
|
||||
sdl.add(c2)
|
||||
sdl.add(c0) # nothing happens here
|
||||
assert sdl.top() == c0
|
||||
|
||||
sdl.remove(c0)
|
||||
assert sdl.top() == c1, f'{sdl.top()}'
|
||||
sdl.remove(c0)
|
||||
assert sdl.top() == c1, f'{sdl.top()}'
|
||||
|
||||
sdl.add(c0)
|
||||
assert sdl.top() == c1
|
||||
sdl.remove(c1)
|
||||
assert sdl.top() == c2
|
||||
sdl.remove(c2)
|
||||
assert sdl.top() == c0
|
||||
|
||||
|
||||
def exam_prefetch(nproc, group):
|
||||
mp = MemoryPool('cuda')
|
||||
mp.allocate()
|
||||
c0 = Chunk(mp, 1024, torch.float, group)
|
||||
c1 = Chunk(mp, 1024, torch.float, group)
|
||||
c2 = Chunk(mp, 1024, torch.float, group)
|
||||
|
||||
chunk_called_per_step = [[c0], [c1], [c2], [c0], [c0], [c1], [c2], [c2], [c1], [c0]]
|
||||
|
||||
sdl = PrefetchScheduler(chunk_called_per_step=chunk_called_per_step)
|
||||
print(sdl.next_step_dict)
|
||||
sdl.reset()
|
||||
|
||||
sdl.step()
|
||||
sdl.add(c0)
|
||||
assert sdl.top() == c0
|
||||
|
||||
sdl.step()
|
||||
sdl.add(c1)
|
||||
assert sdl.top() == c1
|
||||
|
||||
sdl.step()
|
||||
sdl.add(c2)
|
||||
assert sdl.top() == c2
|
||||
|
||||
sdl.remove(c0)
|
||||
sdl.step()
|
||||
sdl.add(c0)
|
||||
assert sdl.top() == c2
|
||||
|
||||
sdl.remove(c0)
|
||||
sdl.step()
|
||||
sdl.add(c0)
|
||||
assert sdl.top() == c0
|
||||
sdl.remove(c0) # notice here
|
||||
|
||||
sdl.remove(c1)
|
||||
sdl.step()
|
||||
sdl.add(c1)
|
||||
assert sdl.top() == c1
|
||||
|
||||
sdl.remove(c2)
|
||||
sdl.step()
|
||||
sdl.add(c2)
|
||||
assert sdl.top() == c1
|
||||
|
||||
sdl.remove(c2)
|
||||
sdl.step()
|
||||
sdl.add(c2)
|
||||
assert sdl.top() == c2
|
||||
sdl.remove(c2) # notice here
|
||||
sdl.add(c0) # notice here
|
||||
|
||||
sdl.remove(c1)
|
||||
sdl.step()
|
||||
sdl.add(c1)
|
||||
assert sdl.top() == c1
|
||||
sdl.remove(c1) # notice here
|
||||
|
||||
sdl.remove(c0)
|
||||
sdl.step()
|
||||
sdl.add(c0)
|
||||
assert sdl.top() == c0
|
||||
|
||||
sdl.remove(c0)
|
||||
sdl.clear()
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_chunk_scheduler(world_size=1):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chunk_scheduler()
|
|
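# The assertions in `exam_prefetch` above are consistent with a Belady-style policy:
# given `chunk_called_per_step`, the scheduler builds `next_step_dict` and its eviction
# candidate (`top`) is the cached chunk whose next use lies furthest in the future,
# with chunks that have no remaining use evicted first.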
@ -0,0 +1,22 @@
|
|||
from colossalai.elixir.ctx import MetaContext
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS


@run_on_environment_flag('ELX')
def test_meta_context():
    builder, *_ = TEST_MODELS.get('resnet')
    with MetaContext():
        model = builder()

    for name, param in model.named_parameters():
        assert param.device.type == 'meta'
        print(name, param)

    for name, buffer in model.named_buffers():
        assert buffer.device.type == 'meta'
        print(name, buffer)


if __name__ == '__main__':
    test_meta_context()
@ -0,0 +1,56 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.elixir.hook import BufferStore, HookParam
|
||||
from colossalai.elixir.tensor import FakeTensor
|
||||
|
||||
|
||||
def test_hook():
|
||||
x = nn.Parameter(torch.randn(4, 4))
|
||||
|
||||
ori_numel = x.numel()
|
||||
ori_size = x.size()
|
||||
ori_stride = x.stride()
|
||||
ori_offset = x.storage_offset()
|
||||
|
||||
fake_data = FakeTensor(x.data)
|
||||
x.data = fake_data
|
||||
x.__class__ = HookParam
|
||||
|
||||
assert x.numel() == ori_numel
|
||||
assert x.size() == ori_size
|
||||
assert x.stride() == ori_stride
|
||||
assert x.storage_offset() == ori_offset
|
||||
|
||||
|
||||
def test_store():
|
||||
buffer = BufferStore(1024, torch.float16)
|
||||
print(buffer)
|
||||
|
||||
x = torch.randn(4, 128, dtype=torch.float16, device='cuda')
|
||||
original_ptr_x = x.data_ptr()
|
||||
copy_x = deepcopy(x)
|
||||
|
||||
y = torch.randn(512, dtype=torch.float16, device='cuda')
|
||||
original_ptr_y = y.data_ptr()
|
||||
copy_y = deepcopy(y)
|
||||
|
||||
offset = 0
|
||||
offset = buffer.insert(x, offset)
|
||||
assert offset == x.numel()
|
||||
assert torch.equal(x, copy_x)
|
||||
|
||||
offset = buffer.insert(y, offset)
|
||||
assert offset == 1024
|
||||
assert torch.equal(y, copy_y)
|
||||
|
||||
buffer.erase(x)
|
||||
buffer.erase(y)
|
||||
assert x.data_ptr() == original_ptr_x
|
||||
assert y.data_ptr() == original_ptr_y
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_store()
|
|
@ -0,0 +1,35 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
from torch.testing import assert_close
|
||||
|
||||
from tests.test_elixir.utils import TEST_MODELS, to_cuda
|
||||
|
||||
|
||||
def exam_one_model(model_fn, data_fn):
|
||||
from colossalai.elixir.kernels.attn_wrapper import wrap_attention
|
||||
|
||||
torch_model = model_fn().cuda()
|
||||
test_model = deepcopy(torch_model)
|
||||
test_model = wrap_attention(test_model)
|
||||
|
||||
data = to_cuda(data_fn())
|
||||
torch_out = torch_model(**data)
|
||||
torch_out.backward()
|
||||
|
||||
test_out = test_model(**data)
|
||||
test_out.backward()
|
||||
|
||||
assert_close(torch_out, test_out)
|
||||
for (name, p_torch), p_test in zip(torch_model.named_parameters(), test_model.parameters()):
|
||||
assert_close(p_torch.grad, p_test.grad)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need to install xformers")
|
||||
def test_gpt_atten_kernel():
|
||||
exam_one_model(*TEST_MODELS.get('gpt2_micro'))
|
||||
exam_one_model(*TEST_MODELS.get('opt_micro'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gpt_atten_kernel()
|
|
@ -0,0 +1,58 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.elixir.search import simple_search
|
||||
from colossalai.elixir.utils import init_distributed
|
||||
from colossalai.elixir.wrapper import ElixirModule
|
||||
|
||||
|
||||
def exam_fused_layernorm(nproc, group):
|
||||
torch_model = nn.LayerNorm(2048)
|
||||
fused_model = deepcopy(torch_model)
|
||||
|
||||
torch_model = torch_model.cuda()
|
||||
sr = simple_search(fused_model, nproc, 1, 1.0, verbose=True)
|
||||
fused_model = ElixirModule(fused_model, sr, group, use_fused_kernels=True)
|
||||
|
||||
data = torch.randn(2, 2048, device='cuda')
|
||||
|
||||
torch_loss = torch_model(data).sum()
|
||||
torch_loss.backward()
|
||||
|
||||
fused_loss = fused_model(data).sum()
|
||||
fused_model.backward(fused_loss)
|
||||
|
||||
assert_close(torch_loss, fused_loss)
|
||||
|
||||
grad_state = fused_model.state_dict(from_param=True)
|
||||
for name, param in torch_model.named_parameters():
|
||||
assert_close(param.grad.cpu(), grad_state[name])
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_fused_layernorm(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1])
|
||||
@pytest.mark.skip(reason='need to install apex')
|
||||
def test_fused_layernorm(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fused_layernorm(world_size=1)
|
|
@ -0,0 +1,38 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.search import minimum_waste_search
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS
|
||||
|
||||
|
||||
def step_fn(model, inp):
|
||||
model(**inp).backward()
|
||||
|
||||
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_mini_waste_search():
|
||||
model_fn, data_fn = TEST_MODELS.get('gpt2_small')
|
||||
model = model_fn()
|
||||
data = data_fn()
|
||||
|
||||
sr = minimum_waste_search(model,
|
||||
1,
|
||||
unified_dtype=torch.float16,
|
||||
cpu_offload=True,
|
||||
prefetch=True,
|
||||
verbose=True,
|
||||
inp=data,
|
||||
step_fn=step_fn)
|
||||
|
||||
chunk_plans = deepcopy(sr.param_chunk_plans)
|
||||
for plan in chunk_plans:
|
||||
assert plan.chunk_dtype == torch.float16
|
||||
assert plan.kwargs.get('shard_device') == torch.device('cpu')
|
||||
assert plan.kwargs.get('cpu_pin_memory') == True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mini_waste_search()
|
|
@ -0,0 +1,30 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.search import optimal_search
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS
|
||||
|
||||
|
||||
def step_fn(model, inp):
|
||||
model(**inp).backward()
|
||||
|
||||
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_optimal_search():
|
||||
model_fn, data_fn = TEST_MODELS.get('gpt2_small')
|
||||
model = model_fn()
|
||||
data = data_fn()
|
||||
|
||||
sr = optimal_search(model, 1, unified_dtype=torch.float16, overlap=True, verbose=True, inp=data, step_fn=step_fn)
|
||||
|
||||
chunk_plans = deepcopy(sr.param_chunk_plans)
|
||||
for plan in chunk_plans:
|
||||
assert plan.chunk_dtype == torch.float16
|
||||
assert plan.kwargs.get('shard_device') == gpu_device()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_optimal_search()
|
|
@ -0,0 +1,48 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.search import simple_search
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS
|
||||
|
||||
|
||||
def step_fn(model, inp):
|
||||
model(**inp).backward()
|
||||
|
||||
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_simple_search():
|
||||
model_fn, data_fn = TEST_MODELS.get('small')
|
||||
model = model_fn()
|
||||
data = data_fn()
|
||||
|
||||
sr = simple_search(model,
|
||||
1,
|
||||
split_number=5,
|
||||
shard_device=gpu_device(),
|
||||
prefetch=True,
|
||||
verbose=True,
|
||||
inp=data,
|
||||
step_fn=step_fn)
|
||||
|
||||
chunk_plans = deepcopy(sr.param_chunk_plans)
|
||||
private_plan = chunk_plans.pop(0)
|
||||
assert private_plan.name_list == ['embed.weight']
|
||||
assert private_plan.chunk_size == 320
|
||||
assert private_plan.kwargs.get('shard_device') == gpu_device()
|
||||
|
||||
assert chunk_plans[0].name_list == ['norm1.weight', 'norm1.bias']
|
||||
assert chunk_plans[1].name_list == ['mlp.proj1.weight', 'mlp.proj1.bias']
|
||||
assert chunk_plans[2].name_list == ['mlp.proj2.weight', 'mlp.proj2.bias']
|
||||
assert chunk_plans[3].name_list == ['norm2.weight']
|
||||
assert chunk_plans[4].name_list == ['norm2.bias']
|
||||
|
||||
for plan in chunk_plans:
|
||||
assert plan.chunk_size == 1088
|
||||
assert plan.kwargs.get('shard_device') == gpu_device()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_simple_search()
|
|
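# The expected sizes above follow from the `small` test model registered in
# tests/test_elixir/utils/small.py: `embed.weight` holds 20 * 16 = 320 elements and
# gets its own private chunk, while 1088 matches the largest remaining parameter
# group (`mlp.proj1`: a 64 x 16 weight plus a 64-element bias), to which the
# smaller groups are padded.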
@ -0,0 +1,13 @@
|
|||
from colossalai.elixir.simulator import move_count
from colossalai.testing import run_on_environment_flag


@run_on_environment_flag('ELX')
def test_move_count():
    steps = [[0], [1, 2], [3], [3], [1, 2], [0]]
    size = 2
    assert move_count(steps, size) == 12


if __name__ == '__main__':
    test_move_count()
@ -0,0 +1,23 @@
|
|||
import pytest
import torch

from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import to_cuda


@run_on_environment_flag('ELX')
def test_registry():
    from tests.test_elixir.utils.registry import TEST_MODELS
    for name, model_tuple in TEST_MODELS:
        torch.cuda.synchronize()
        print(f'model `{name}` is in testing')

        model_fn, data_fn = model_tuple
        model = model_fn().cuda()
        data = to_cuda(data_fn())
        loss = model(**data)
        loss.backward()


if __name__ == '__main__':
    test_registry()
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.tracer.memory_tracer import cuda_memory_profiling
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS, to_cuda
|
||||
|
||||
|
||||
def one_step(model, inp):
|
||||
loss = model(**inp)
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
|
||||
def try_one_model(model_fn, data_fn):
|
||||
model = model_fn().cuda()
|
||||
data = to_cuda(data_fn())
|
||||
one_step(model, data) # generate gradients
|
||||
|
||||
pre_cuda_alc = torch.cuda.memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
one_step(model, data)
|
||||
aft_cuda_alc = torch.cuda.max_memory_allocated()
|
||||
torch_activation_occ = aft_cuda_alc - pre_cuda_alc
|
||||
model.zero_grad(set_to_none=True)
|
||||
print('normal', torch_activation_occ)
|
||||
|
||||
before = torch.cuda.memory_allocated()
|
||||
profiling_dict = cuda_memory_profiling(model, data, one_step)
|
||||
after = torch.cuda.memory_allocated()
|
||||
print('profiling', profiling_dict)
|
||||
assert before == after
|
||||
assert torch_activation_occ == profiling_dict['activation_occ']
|
||||
print('Check is ok.')
|
||||
|
||||
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_cuda_profiler():
|
||||
model_list = ['resnet', 'gpt2_micro']
|
||||
for name in model_list:
|
||||
model_fn, data_fn = TEST_MODELS.get(name)
|
||||
try_one_model(model_fn, data_fn)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cuda_profiler()
|
|
@ -0,0 +1,145 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.elixir.tracer.memory_tracer import MTensor
|
||||
from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache
|
||||
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
|
||||
|
||||
def op_mm(x, y):
|
||||
u = torch.matmul(x, y)
|
||||
return u.shape
|
||||
|
||||
|
||||
def op_addmm(x, y, z):
|
||||
u = torch.addmm(x, y, z)
|
||||
return u.shape
|
||||
|
||||
|
||||
def op_bmm(x, y):
|
||||
u = torch.bmm(x, y)
|
||||
return u.shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_mm(dtype, size0=(4, 256), size1=(256, 1024)):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
assert get_cuda_allocated() == 0
|
||||
|
||||
x = torch.randn(size0, dtype=dtype, device='cuda')
|
||||
y = torch.randn(size1, dtype=dtype, device='cuda')
|
||||
torch_pre_alc = get_cuda_allocated()
|
||||
|
||||
torch_z_size = op_mm(x, y)
|
||||
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
|
||||
|
||||
del x
|
||||
del y
|
||||
|
||||
assert get_cuda_allocated() == 0
|
||||
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
|
||||
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
|
||||
op1_pre_alc = get_cuda_allocated()
|
||||
|
||||
MTensor.reset_peak_memory()
|
||||
op1_z_size = op_mm(x, y)
|
||||
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
|
||||
|
||||
assert torch_z_size == op1_z_size
|
||||
assert torch_pre_alc == op1_pre_alc
|
||||
assert torch_temp_alc == op1_temp_alc
|
||||
assert len(mm_cache.temp_memory) > 0
|
||||
|
||||
MTensor.reset_peak_memory()
|
||||
op2_z_size = op_mm(x, y)
|
||||
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
|
||||
|
||||
assert torch_z_size == op2_z_size
|
||||
assert torch_temp_alc == op2_temp_alc
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_addmm(dtype, size0=(4, 16), size1=(16, 64)):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
assert get_cuda_allocated() == 0
|
||||
|
||||
x = torch.randn(size0, dtype=dtype, device='cuda')
|
||||
y = torch.randn(size1, dtype=dtype, device='cuda')
|
||||
u = torch.randn(size1[-1], dtype=dtype, device='cuda')
|
||||
torch_pre_alc = get_cuda_allocated()
|
||||
|
||||
torch_z_size = op_addmm(u, x, y)
|
||||
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
|
||||
|
||||
del x
|
||||
del y
|
||||
del u
|
||||
|
||||
assert get_cuda_allocated() == 0
|
||||
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
|
||||
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
|
||||
u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda'))
|
||||
op1_pre_alc = get_cuda_allocated()
|
||||
|
||||
MTensor.reset_peak_memory()
|
||||
op1_z_size = op_addmm(u, x, y)
|
||||
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
|
||||
|
||||
assert torch_z_size == op1_z_size
|
||||
assert torch_pre_alc == op1_pre_alc
|
||||
assert torch_temp_alc == op1_temp_alc
|
||||
assert len(addmm_cache.temp_memory) > 0
|
||||
|
||||
MTensor.reset_peak_memory()
|
||||
op2_z_size = op_addmm(u, x, y)
|
||||
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
|
||||
|
||||
assert torch_z_size == op2_z_size
|
||||
assert torch_temp_alc == op2_temp_alc
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
assert get_cuda_allocated() == 0
|
||||
|
||||
x = torch.randn(size0, dtype=dtype, device='cuda')
|
||||
y = torch.randn(size1, dtype=dtype, device='cuda')
|
||||
torch_pre_alc = get_cuda_allocated()
|
||||
|
||||
torch_z_size = op_bmm(x, y)
|
||||
torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
|
||||
|
||||
del x
|
||||
del y
|
||||
|
||||
assert get_cuda_allocated() == 0
|
||||
x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
|
||||
y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
|
||||
op1_pre_alc = get_cuda_allocated()
|
||||
|
||||
MTensor.reset_peak_memory()
|
||||
op1_z_size = op_bmm(x, y)
|
||||
op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
|
||||
|
||||
assert torch_z_size == op1_z_size
|
||||
assert torch_pre_alc == op1_pre_alc
|
||||
assert torch_temp_alc == op1_temp_alc
|
||||
assert len(bmm_cache.temp_memory) > 0
|
||||
|
||||
bmm_cache.print()
|
||||
|
||||
MTensor.reset_peak_memory()
|
||||
op2_z_size = op_bmm(x, y)
|
||||
op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
|
||||
|
||||
assert torch_z_size == op2_z_size
|
||||
assert torch_temp_alc == op2_temp_alc
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_addmm(dtype=torch.float)
|
|
@ -0,0 +1,37 @@
|
|||
from colossalai.elixir.tracer.param_tracer import generate_tf_order
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS
|
||||
|
||||
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_tf_forward_backward():
|
||||
model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
|
||||
model = model_fn()
|
||||
data = data_fn()
|
||||
|
||||
def forward_backward_fn(local_model, local_input):
|
||||
local_model(**local_input).backward()
|
||||
|
||||
# model.gradient_checkpointing_enable()
|
||||
tf_order = generate_tf_order(model, data, forward_backward_fn)
|
||||
params_per_step = tf_order['params_per_step']
|
||||
assert len(params_per_step) == 32
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
tf_order = generate_tf_order(model, data, forward_backward_fn)
|
||||
params_per_step = tf_order['params_per_step']
|
||||
checkpoint_info = tf_order['checkpoint_info']
|
||||
for i, step in enumerate(params_per_step):
|
||||
print(f'step {i}: {step}')
|
||||
for c in checkpoint_info:
|
||||
print(f'checkpoint info: {c}')
|
||||
assert len(params_per_step) == 44
|
||||
|
||||
assert data['input_ids'].device.type == 'cpu'
|
||||
assert data['attention_mask'].device.type == 'cpu'
|
||||
for param in model.parameters():
|
||||
assert param.device.type == 'cpu'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_tf_forward_backward()
|
|
@ -0,0 +1,94 @@
|
|||
import copy
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.search import simple_search
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS, to_cuda


def amp_check_model_states(ddp_optim, test_model):
    test_states = test_model.state_dict()
    for (name, _), p in zip(test_model.module.named_parameters(), amp.master_params(ddp_optim)):
        test_p = test_states[name]
        copy_p = p.to(test_p.device)
        print(f'checking parameter `{name}`: {test_p.dtype} {copy_p.dtype}')
        assert_close(test_p.data, copy_p.data)


def exam_amp_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
    ddp_model = model_fn().cuda()
    test_model = copy.deepcopy(ddp_model)
    # important here, since apex has a lazy fp32 init after the first optimizer step
    test_model = test_model.half()

    ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
    ddp_model, ddp_optim = amp.initialize(ddp_model,
                                          ddp_optim,
                                          opt_level='O2',
                                          loss_scale=1.0,
                                          keep_batchnorm_fp32=False)
    ddp_model = DDP(ddp_model, message_size=0, allreduce_always_fp32=True)

    test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
    sr = simple_search(test_model, nproc, shard_device=gpu_device(), unified_dtype=torch.float16, verbose=True)
    test_model = ElixirModule(test_model, sr, group, dtype=torch.float16, reduce_always_fp32=True, output_fp32=True)
    test_optim = ElixirOptimizer(test_model, test_optim, initial_scale=1.0)

    # get different data
    seed_all(exam_seed + dist.get_rank(group), cuda_deterministic=True)
    for _ in range(2):
        data = to_cuda(data_fn())

        ddp_optim.zero_grad()
        ddp_loss = ddp_model(**data)
        with amp.scale_loss(ddp_loss, ddp_optim) as scaled_loss:
            scaled_loss.backward()
        ddp_optim.step()

        test_optim.zero_grad()
        test_loss = test_model(**data)
        test_optim.backward(test_loss)
        test_optim.step()

        assert_close(ddp_loss, test_loss)
        amp_check_model_states(ddp_optim, test_model)


def exam_amp_in_models(nproc, group):
    model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
    exam_amp_one_model(model_fn, data_fn, nproc, group)


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_amp_in_models(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_elixir_amp(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_elixir_amp(world_size=2)
@ -0,0 +1,95 @@
|
|||
import copy
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.elixir.search import simple_search
|
||||
from colossalai.elixir.utils import init_distributed, seed_all
|
||||
from colossalai.elixir.wrapper import ElixirModule
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS, assert_dict_values, to_cuda
|
||||
|
||||
|
||||
def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
|
||||
grad_state = test_model.state_dict(from_param=True)
|
||||
for name, param in ddp_model.named_parameters():
|
||||
assert_close(param.grad.cpu(), grad_state[name])
|
||||
|
||||
|
||||
def exam_module_init(nproc, group, grad_flag):
|
||||
model_fn, data_fn = TEST_MODELS.get('resnet')
|
||||
torch_model = model_fn().cuda()
|
||||
test_model = model_fn().cuda()
|
||||
|
||||
for p1, p2 in zip(torch_model.parameters(), test_model.parameters()):
|
||||
p1.requires_grad = p2.requires_grad = grad_flag
|
||||
|
||||
sr = simple_search(test_model, nproc)
|
||||
model = ElixirModule(test_model, sr, group)
|
||||
# check function: ElixirModule.load_state_dict after ElixirModule.__init__
|
||||
torch_st = torch_model.state_dict()
|
||||
if dist.get_rank() != 0:
|
||||
torch_st = None
|
||||
test_st = model.load_state_dict(torch_st, only_rank_0=True)
|
||||
# check function: ElixirModule.state_dict after ElixirModule.__init__
|
||||
torch_st = torch_model.state_dict()
|
||||
test_st = model.state_dict()
|
||||
assert_dict_values(torch_st, test_st, fn=torch.equal)
|
||||
|
||||
|
||||
def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2261):
|
||||
ddp_model = model_fn().cuda()
|
||||
test_model = copy.deepcopy(ddp_model)
|
||||
sr = simple_search(test_model, nproc, allocate_factor=0.6)
|
||||
test_model = ElixirModule(test_model, sr, group)
|
||||
|
||||
# get different data
|
||||
seed_all(exam_seed + dist.get_rank(group))
|
||||
data = data_fn()
|
||||
data = to_cuda(data)
|
||||
|
||||
seed_all(exam_seed, cuda_deterministic=True)
|
||||
ddp_model = DDP(ddp_model)
|
||||
ddp_loss = ddp_model(**data)
|
||||
ddp_loss.backward()
|
||||
|
||||
test_loss = test_model(**data)
|
||||
test_model.backward(test_loss)
|
||||
|
||||
assert_close(ddp_loss, test_loss)
|
||||
check_gradient(ddp_model.module, test_model)
|
||||
|
||||
|
||||
def exam_modules_fwd_bwd(nproc, group):
|
||||
model_fn, data_fn = TEST_MODELS.get('resnet')
|
||||
exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=False)
|
||||
exam_module_init(nproc=world_size, group=dist.GroupMember.WORLD, grad_flag=True)
|
||||
exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_elixir_module(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_elixir_module(world_size=2)
|
|
@ -0,0 +1,77 @@
|
|||
import copy
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.search import simple_search
|
||||
from colossalai.elixir.utils import init_distributed, seed_all
|
||||
from colossalai.elixir.wrapper import ElixirModule, ElixirOptimizer
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS, allclose, assert_dict_values, to_cuda
|
||||
|
||||
|
||||
def exam_optimizer_one_model(model_fn, data_fn, nproc, group, exam_seed=2261):
|
||||
ddp_model = model_fn().cuda()
|
||||
test_model = copy.deepcopy(ddp_model)
|
||||
|
||||
ddp_model = DDP(ddp_model)
|
||||
ddp_optim = HybridAdam(ddp_model.parameters(), lr=1e-1, weight_decay=0)
|
||||
|
||||
test_optim = HybridAdam(test_model.parameters(), lr=1e-1, weight_decay=0)
|
||||
sr = simple_search(test_model, nproc, shard_device=gpu_device())
|
||||
test_model = ElixirModule(test_model, sr, group)
|
||||
test_optim = ElixirOptimizer(test_model, test_optim)
|
||||
|
||||
# get different data
|
||||
seed_all(exam_seed + dist.get_rank(group))
|
||||
data = to_cuda(data_fn())
|
||||
|
||||
seed_all(exam_seed, cuda_deterministic=True)
|
||||
ddp_optim.zero_grad()
|
||||
ddp_loss = ddp_model(**data)
|
||||
ddp_loss.backward()
|
||||
ddp_optim.step()
|
||||
|
||||
test_optim.zero_grad()
|
||||
test_loss = test_model(**data)
|
||||
test_optim.backward(test_loss)
|
||||
test_optim.step()
|
||||
|
||||
assert_close(ddp_loss, test_loss)
|
||||
torch_st = ddp_model.module.state_dict()
|
||||
test_st = test_model.state_dict()
|
||||
assert_dict_values(torch_st, test_st, fn=partial(allclose, rtol=2e-6, atol=2e-5))
|
||||
|
||||
|
||||
def exam_optimizer_in_models(nproc, group):
|
||||
model_fn, data_fn = TEST_MODELS.get('resnet')
|
||||
exam_optimizer_one_model(model_fn, data_fn, nproc, group)
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_optimizer_in_models(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_elixir_optimizer(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_elixir_optimizer(world_size=4)
|
|
@ -0,0 +1,89 @@
|
|||
import copy
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
from colossalai.elixir.cuda import gpu_device
|
||||
from colossalai.elixir.search import simple_search
|
||||
from colossalai.elixir.utils import init_distributed, seed_all
|
||||
from colossalai.elixir.wrapper import ElixirModule
|
||||
from colossalai.testing import run_on_environment_flag
|
||||
from tests.test_elixir.utils import TEST_MODELS, to_cuda
|
||||
|
||||
|
||||
def check_gradient(ddp_model: nn.Module, test_model: ElixirModule):
|
||||
grad_state = test_model.state_dict(from_param=True)
|
||||
for name, param in ddp_model.named_parameters():
|
||||
assert_close(param.grad.cpu(), grad_state[name])
|
||||
|
||||
|
||||
def exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group, exam_seed=2263):
|
||||
|
||||
def one_step(local_model, local_input):
|
||||
loss = local_model(**local_input)
|
||||
loss.backward()
|
||||
return loss
|
||||
|
||||
ddp_model = model_fn().cuda()
|
||||
test_model = copy.deepcopy(ddp_model)
|
||||
|
||||
# get different data
|
||||
seed_all(exam_seed + dist.get_rank(group))
|
||||
data = to_cuda(data_fn())
|
||||
|
||||
# wrap as DDP model
|
||||
ddp_model = DDP(ddp_model)
|
||||
# search how to initialize chunks
|
||||
sr = simple_search(test_model,
|
||||
nproc,
|
||||
shard_device=gpu_device(),
|
||||
prefetch=True,
|
||||
verbose=True,
|
||||
inp=data,
|
||||
step_fn=one_step)
|
||||
test_model = ElixirModule(test_model, sr, group, prefetch=True)
|
||||
|
||||
seed_all(exam_seed, cuda_deterministic=True)
|
||||
ddp_loss = one_step(ddp_model, data)
|
||||
|
||||
with torch.no_grad():
|
||||
test_loss = test_model(**data)
|
||||
assert_close(ddp_loss, test_loss)
|
||||
|
||||
test_loss = test_model(**data)
|
||||
test_model.backward(test_loss)
|
||||
assert_close(ddp_loss, test_loss)
|
||||
check_gradient(ddp_model.module, test_model)
|
||||
|
||||
|
||||
def exam_modules_fwd_bwd(nproc, group):
|
||||
model_fn, data_fn = TEST_MODELS.get('resnet')
|
||||
exam_one_module_fwd_bwd(model_fn, data_fn, nproc, group)
|
||||
|
||||
|
||||
def run_dist(rank, world_size):
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(rank)
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
||||
os.environ['MASTER_PORT'] = str(29512)
|
||||
init_distributed()
|
||||
exam_modules_fwd_bwd(nproc=world_size, group=dist.GroupMember.WORLD)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2, 4])
|
||||
@run_on_environment_flag('ELX')
|
||||
def test_module_prefetch(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
torch.multiprocessing.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_module_prefetch(world_size=2)
|
|
@ -0,0 +1,41 @@
|
|||
import torch
from torch.testing import assert_close
from torch.utils._pytree import tree_map

from . import gpt, mlp, opt, resnet, small
from .registry import TEST_MODELS


def to_cuda(input_dict):

    def local_fn(t):
        if isinstance(t, torch.Tensor):
            t = t.cuda()
        return t

    ret = tree_map(local_fn, input_dict)
    return ret


def allclose(ta, tb, **kwargs):
    assert_close(ta, tb, **kwargs)
    return True


def assert_dict_keys(test_dict, keys):
    assert len(test_dict) == len(keys)
    for k in keys:
        assert k in test_dict


def assert_dict_values(da, db, fn):
    assert len(da) == len(db)
    for k, v in da.items():
        assert k in db
        if not torch.is_tensor(v):
            continue
        u = db.get(k)
        if u.device != v.device:
            v = v.to(u.device)
        # print(f"checking key {k}: {u.shape} vs {v.shape}")
        assert fn(u.data, v.data), f'max diff {torch.max(torch.abs(u.data - v.data))}'
@ -0,0 +1,79 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
from tests.test_elixir.utils.registry import TEST_MODELS
|
||||
|
||||
MICRO_VS = 128
|
||||
MICRO_BS = 4
|
||||
MICRO_SL = 64
|
||||
|
||||
MACRO_VS = 50257
|
||||
MACRO_BS = 2
|
||||
MACRO_SL = 1024
|
||||
|
||||
|
||||
def micro_data_fn():
|
||||
input_ids = torch.randint(low=0, high=MICRO_VS, size=(MICRO_BS, MICRO_SL))
|
||||
attn_mask = torch.ones_like(input_ids)
|
||||
return dict(input_ids=input_ids, attention_mask=attn_mask)
|
||||
|
||||
|
||||
def small_data_fn():
|
||||
input_ids = torch.randint(low=0, high=MACRO_VS, size=(MACRO_BS, MACRO_SL))
|
||||
attn_mask = torch.ones_like(input_ids)
|
||||
return dict(input_ids=input_ids, attention_mask=attn_mask)
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
class GPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257):
|
||||
super().__init__()
|
||||
self.enable_gc = False
|
||||
self.config = GPT2Config(
|
||||
# pre-commit: do not rearrange
|
||||
n_embd=hidden_size,
|
||||
n_layer=num_layers,
|
||||
n_head=num_attention_heads,
|
||||
n_positions=max_seq_len,
|
||||
n_ctx=max_seq_len,
|
||||
vocab_size=vocab_size,
|
||||
resid_pdrop=0.0,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0)
|
||||
self.module = GPT2LMHeadModel(config=self.config)
|
||||
self.criterion = GPTLMLoss()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
self.module.gradient_checkpointing_enable()
|
||||
self.enable_gc = True
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
# Only return lm_logits
|
||||
output = self.module(input_ids=input_ids, attention_mask=attention_mask, use_cache=(not self.enable_gc))[0]
|
||||
loss = self.criterion(output, input_ids)
|
||||
return loss
|
||||
|
||||
|
||||
gpt2_micro = partial(GPTLMModel, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128)
|
||||
gpt2_small = GPTLMModel
|
||||
gpt2_base = partial(GPTLMModel, hidden_size=1024, num_layers=24, num_attention_heads=16)
|
||||
|
||||
TEST_MODELS.register('gpt2_micro', gpt2_micro, micro_data_fn)
|
||||
TEST_MODELS.register('gpt2_small', gpt2_small, small_data_fn)
|
||||
TEST_MODELS.register('gpt2_base', gpt2_base, small_data_fn)
|
|
@ -0,0 +1,34 @@
|
|||
import torch
import torch.nn as nn

from tests.test_elixir.utils.registry import TEST_MODELS


def mlp_data_fn():
    return dict(x=torch.randn(4, 16))


class MlpModule(nn.Module):

    def __init__(self, hidden_dim: int = 16) -> None:
        super().__init__()
        self.proj1 = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(4 * hidden_dim, hidden_dim)

    def forward(self, x):
        return x + (self.proj2(self.act(self.proj1(x))))


class MlpModel(nn.Module):

    def __init__(self, hidden_dim: int = 16) -> None:
        super().__init__()
        self.mlp = MlpModule(hidden_dim)

    def forward(self, x):
        output = self.mlp(x)
        return output.sum()


TEST_MODELS.register('mlp', MlpModel, mlp_data_fn)
@ -0,0 +1,46 @@
|
|||
import torch.nn as nn
|
||||
from transformers import OPTConfig, OPTForCausalLM
|
||||
|
||||
from tests.test_elixir.utils.registry import TEST_MODELS
|
||||
|
||||
from .gpt import micro_data_fn
|
||||
|
||||
|
||||
class OPTLMModel(nn.Module):
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.module = OPTForCausalLM(config=config)
|
||||
self.enable_gc = False
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
self.module.gradient_checkpointing_enable()
|
||||
self.enable_gc = True
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
loss = self.module(
|
||||
# pre-commit: do not rearrange
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=input_ids,
|
||||
use_cache=(not self.enable_gc))['loss']
|
||||
return loss
|
||||
|
||||
|
||||
def opt_micro():
|
||||
opt_config = OPTConfig(
|
||||
# pre-commit: do not rearrange
|
||||
vocab_size=128,
|
||||
activation_dropout=0.0,
|
||||
dropout=0,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=4,
|
||||
ffn_dim=128,
|
||||
num_attention_heads=4,
|
||||
word_embed_proj_dim=32,
|
||||
output_projection=True)
|
||||
return OPTLMModel(opt_config)
|
||||
|
||||
|
||||
TEST_MODELS.register('opt_micro', opt_micro, micro_data_fn)
|
|
@ -0,0 +1,26 @@
|
|||
from collections import OrderedDict
from typing import Callable


class Registry(object):

    def __init__(self) -> None:
        super().__init__()
        self._registry_dict = OrderedDict()

    def register(self, name: str, model_fn: Callable, data_fn: Callable):
        assert name not in self._registry_dict

        model_tuple = (model_fn, data_fn)
        self._registry_dict[name] = model_tuple

    def get(self, name: str):
        return self._registry_dict[name]

    def __iter__(self):
        return iter(self._registry_dict.items())


TEST_MODELS = Registry()

__all__ = ['TEST_MODELS']
@ -0,0 +1,23 @@
|
|||
import torch
import torch.nn as nn
from torchvision.models import resnet18

from tests.test_elixir.utils.registry import TEST_MODELS


def resnet_data_fn():
    return dict(x=torch.randn(4, 3, 32, 32))


class ResNetModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.r = resnet18()

    def forward(self, x):
        output = self.r(x)
        return output.sum()


TEST_MODELS.register('resnet', ResNetModel, resnet_data_fn)
@ -0,0 +1,31 @@
|
|||
import torch
import torch.nn as nn

from tests.test_elixir.utils.mlp import MlpModule
from tests.test_elixir.utils.registry import TEST_MODELS


def small_data_fn():
    return dict(x=torch.randint(low=0, high=20, size=(4, 8)))


class SmallModel(nn.Module):

    def __init__(self, num_embeddings: int = 20, hidden_dim: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MlpModule(hidden_dim=hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.proj = nn.Linear(hidden_dim, num_embeddings, bias=False)
        # tie the output projection to the embedding weight
        self.proj.weight = self.embed.weight

    def forward(self, x):
        x = self.embed(x)
        x = x + self.norm1(self.mlp(x))
        x = self.proj(self.norm2(x))
        x = x.mean(dim=-2)
        return x.sum()


TEST_MODELS.register('small', SmallModel, small_data_fn)